matc/
transport.rs

// Simple UDP transport abstraction that multiplexes datagrams by remote address
// into per-connection mpsc channels. Each Connection is a logical association
// identified solely by the peer's socket address string.
4
5use anyhow::{Context, Result};
6use std::{collections::HashMap, sync::Arc, time::Duration};
7use tokio::{net::UdpSocket, sync::Mutex};
8
9#[derive(Debug, Clone)]
10struct ConnectionInfo {
11    sender: tokio::sync::mpsc::Sender<Vec<u8>>,
12}
13
14/// Shared transport holding:
15/// * a single UDP socket
16/// * a map of remote_addr -> channel sender
17/// * a task to read incoming datagrams and dispatch them
18/// * a task to remove connection entries when Connections drop
19pub struct Transport {
20    socket: Arc<UdpSocket>,
21    connections: Mutex<HashMap<String, ConnectionInfo>>,
22    remove_channel_sender: tokio::sync::mpsc::UnboundedSender<String>,
23    stop_receive_token: tokio_util::sync::CancellationToken,
24}
25
26/// Logical connection bound to a remote UDP address. Receiving is done by
27/// reading from an mpsc channel populated by the Transport reader task.
28pub struct Connection {
29    transport: Arc<Transport>,
30    remote_address: String,
31    receiver: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
32}
33
34impl Transport {
35    async fn read_from_socket_loop(
36        socket: Arc<UdpSocket>,
37        stop_receive_token: tokio_util::sync::CancellationToken,
38        self_weak: std::sync::Weak<Transport>,
39    ) -> Result<()> {
40        loop {
41            let mut buf = vec![0u8; 1024];
42            let (n, addr) = {
43                tokio::select! {
44                    recv_resp = socket.recv_from(&mut buf) => recv_resp,
45                    _ = stop_receive_token.cancelled() => break
46                }
47            }?;
48            buf.resize(n, 0);
49            let self_strong = self_weak
50                .upgrade()
51                .context("weakpointer to self is gone - just stop")?;
52            let cons = self_strong.connections.lock().await;
53            if let Some(c) = cons.get(&addr.to_string()) {
54                _ = c.sender.send(buf).await;
55            }
56        }
57        Ok(())
58    }
59
60    async fn read_from_delete_queue_loop(
61        mut remove_channel_receiver: tokio::sync::mpsc::UnboundedReceiver<String>,
62        self_weak: std::sync::Weak<Transport>,
63    ) -> Result<()> {
64        loop {
65            let to_remove = remove_channel_receiver.recv().await;
66            match to_remove {
67                Some(to_remove) => {
68                    if to_remove.is_empty() {
69                        // Empty string used as sentinel to terminate this task.
70                        break;
71                    }
72                    let self_strong = self_weak
73                        .upgrade()
74                        .context("weak to self is gone - just stop")?;
75                    let mut cons = self_strong.connections.lock().await;
76                    _ = cons.remove(&to_remove);
77                }
78                None => break, // Sender dropped => shutdown
79            }
80        }
81        Ok(())
82    }
83
84    /// Bind a UDP socket and spawn background tasks.
85    pub async fn new(local: &str) -> Result<Arc<Self>> {
86        let socket = UdpSocket::bind(local).await?;
87        let (remove_channel_sender, remove_channel_receiver) =
88            tokio::sync::mpsc::unbounded_channel();
89        let stop_receive_token = tokio_util::sync::CancellationToken::new();
90        let stop_receive_token_child = stop_receive_token.child_token();
91        let o = Arc::new(Self {
92            socket: Arc::new(socket),
93            connections: Mutex::new(HashMap::new()),
94            remove_channel_sender,
95            stop_receive_token,
96        });
97        let self_weak = Arc::downgrade(&o.clone());
98        let socket = o.socket.clone();
99        tokio::spawn(async move {
100            _ = Self::read_from_socket_loop(socket, stop_receive_token_child, self_weak).await;
101        });
102        let self_weak = Arc::downgrade(&o.clone());
103        tokio::spawn(async move {
104            _ = Self::read_from_delete_queue_loop(remove_channel_receiver, self_weak).await;
105        });
106        Ok(o)
107    }
108
109    /// Create (or replace) a logical connection entry for the given remote address.
110    pub async fn create_connection(self: &Arc<Self>, remote: &str) -> Arc<Connection> {
111        let mut clock = self.connections.lock().await;
112        let (sender, receiver) = tokio::sync::mpsc::channel(32);
113        clock.insert(remote.to_owned(), ConnectionInfo { sender });
114        Arc::new(Connection {
115            transport: self.clone(),
116            remote_address: remote.to_owned(),
117            receiver: Mutex::new(receiver),
118        })
119    }
120}
121
122impl Connection {
123    /// Send a datagram to the remote address.
124    pub async fn send(&self, data: &[u8]) -> Result<()> {
125        self.transport
126            .socket
127            .send_to(data, &self.remote_address)
128            .await?;
129        Ok(())
130    }
131    /// Receive the next datagram for this connection (with timeout).
132    pub async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
133        let mut ch = self.receiver.lock().await;
134        let rec_future = ch.recv();
135        let with_timeout = tokio::time::timeout(timeout, rec_future);
136        with_timeout.await?.context("eof")
137    }
138}
139
140impl Drop for Transport {
141    fn drop(&mut self) {
142        _ = self.remove_channel_sender.send("".to_owned());
143        self.stop_receive_token.cancel();
144    }
145}
146
147impl Drop for Connection {
148    fn drop(&mut self) {
149        _ = self
150            .transport
151            .remove_channel_sender
152            .send(self.remote_address.clone());
153    }
154}