// matc/transport.rs

1use anyhow::{Context, Result};
2use std::{collections::HashMap, sync::Arc, time::Duration};
3use tokio::{net::UdpSocket, sync::Mutex};
4
/// Per-remote entry stored in [`Transport`]'s connection table.
#[derive(Debug, Clone)]
struct ConnectionInfo {
    /// Sending half of the channel on which the socket read loop forwards
    /// datagrams received from this remote; the receiving half is owned by the
    /// matching [`Connection`].
    sender: tokio::sync::mpsc::Sender<Vec<u8>>,
}
9
/// A single UDP socket shared by any number of [`Connection`]s, with incoming
/// datagrams routed to connections by the textual address of their sender.
pub struct Transport {
    /// The bound UDP socket; shared with the background read loop.
    socket: Arc<UdpSocket>,
    /// Routing table: remote address string -> channel of the connection
    /// registered for that address.
    connections: Mutex<HashMap<String, ConnectionInfo>>,
    /// Queue of remote addresses whose table entry should be removed. An empty
    /// string is treated by the removal loop as a request to stop.
    remove_channel_sender: tokio::sync::mpsc::UnboundedSender<String>,
    /// Cancelled when the transport is dropped; a child of this token is what
    /// the socket read loop waits on.
    stop_receive_token: tokio_util::sync::CancellationToken,
}
16
/// One logical peer on a shared [`Transport`]: sends go to `remote_address`,
/// and receives yield datagrams the transport routed to this connection.
pub struct Connection {
    /// Owning transport; keeps the socket and background loops alive for as
    /// long as this connection exists.
    transport: Arc<Transport>,
    /// Remote address used both as the send target and as the key under which
    /// this connection is registered in the transport's routing table.
    remote_address: String,
    /// Receiving half of the per-connection datagram channel. Wrapped in a
    /// mutex so `receive` can be called through a shared reference.
    receiver: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
}
22
23impl Transport {
24    async fn read_from_socket_loop(
25        socket: Arc<UdpSocket>,
26        stop_receive_token: tokio_util::sync::CancellationToken,
27        self_weak: std::sync::Weak<Transport>,
28    ) -> Result<()> {
29        loop {
30            let mut buf = vec![0u8; 1024];
31            let (n, addr) = {
32                tokio::select! {
33                    recv_resp = socket.recv_from(&mut buf) => recv_resp,
34                    _ = stop_receive_token.cancelled() => break
35                }
36            }?;
37            buf.resize(n, 0);
38            let self_strong = self_weak
39                .upgrade()
40                .context("weakpointer to self is gone - just stop")?;
41            let cons = self_strong.connections.lock().await;
42            if let Some(c) = cons.get(&addr.to_string()) {
43                _ = c.sender.send(buf).await;
44            }
45        }
46        Ok(())
47    }
48
49    async fn read_from_delete_queue_loop(
50        mut remove_channel_receiver: tokio::sync::mpsc::UnboundedReceiver<String>,
51        self_weak: std::sync::Weak<Transport>,
52    ) -> Result<()> {
53        loop {
54            let to_remove = remove_channel_receiver.recv().await;
55            match to_remove {
56                Some(to_remove) => {
57                    if to_remove.is_empty() {
58                        break;
59                    }
60                    let self_strong = self_weak
61                        .upgrade()
62                        .context("weak to self is gone - just stop")?;
63                    let mut cons = self_strong.connections.lock().await;
64                    _ = cons.remove(&to_remove);
65                }
66                None => break,
67            }
68        }
69        Ok(())
70    }
71
72    pub async fn new(local: &str) -> Result<Arc<Self>> {
73        let socket = UdpSocket::bind(local).await?;
74        let (remove_channel_sender, remove_channel_receiver) =
75            tokio::sync::mpsc::unbounded_channel();
76        let stop_receive_token = tokio_util::sync::CancellationToken::new();
77        let stop_receive_token_child = stop_receive_token.child_token();
78        let o = Arc::new(Self {
79            socket: Arc::new(socket),
80            connections: Mutex::new(HashMap::new()),
81            remove_channel_sender,
82            stop_receive_token,
83        });
84        let self_weak = Arc::downgrade(&o.clone());
85        let socket = o.socket.clone();
86        tokio::spawn(async move {
87            _ = Self::read_from_socket_loop(socket, stop_receive_token_child, self_weak).await;
88        });
89        let self_weak = Arc::downgrade(&o.clone());
90        tokio::spawn(async move {
91            _ = Self::read_from_delete_queue_loop(remove_channel_receiver, self_weak).await;
92        });
93        Ok(o)
94    }
95
96    pub async fn create_connection(self: &Arc<Self>, remote: &str) -> Arc<Connection> {
97        let mut clock = self.connections.lock().await;
98        let (sender, receiver) = tokio::sync::mpsc::channel(32);
99        clock.insert(remote.to_owned(), ConnectionInfo { sender });
100        Arc::new(Connection {
101            transport: self.clone(),
102            remote_address: remote.to_owned(),
103            receiver: Mutex::new(receiver),
104        })
105    }
106}
107
108impl Connection {
109    pub async fn send(&self, data: &[u8]) -> Result<()> {
110        self.transport
111            .socket
112            .send_to(data, &self.remote_address)
113            .await?;
114        Ok(())
115    }
116    pub async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
117        let mut ch = self.receiver.lock().await;
118        let rec_future = ch.recv();
119        let with_timeout = tokio::time::timeout(timeout, rec_future);
120        with_timeout.await?.context("eof")
121    }
122}
123
124impl Drop for Transport {
125    fn drop(&mut self) {
126        _ = self.remove_channel_sender.send("".to_owned());
127        self.stop_receive_token.cancel();
128    }
129}
130
131impl Drop for Connection {
132    fn drop(&mut self) {
133        _ = self
134            .transport
135            .remove_channel_sender
136            .send(self.remote_address.clone());
137    }
138}