matc/
transport.rs

1// Simple UDP transport abstraction that multiplexes datagrams by remote address
2// into per-connection mpsc channels. Each Connection is a logical association
3// identified solely by the peer's socket address string.
4
5use anyhow::{Context, Result};
6use std::{collections::HashMap, sync::Arc, time::Duration};
7use tokio::{net::UdpSocket, sync::Mutex};
8
9/// Transport-agnostic connection: send and receive raw Matter messages.
10///
11/// Implement this for UDP ([`Connection`]) and BTP ([`crate::btp::BtpConnection`]).
12#[async_trait::async_trait]
13pub trait ConnectionTrait: Send + Sync {
14    async fn send(&self, data: &[u8]) -> Result<()>;
15    async fn receive(&self, timeout: Duration) -> Result<Vec<u8>>;
16    /// True for transports (BTP) that guarantee delivery so Matter-layer MRP
17    /// retransmit should be suppressed.  Default: false (UDP).
18    fn is_reliable(&self) -> bool { false }
19}
20
21#[derive(Debug, Clone)]
22struct ConnectionInfo {
23    sender: tokio::sync::mpsc::Sender<Vec<u8>>,
24}
25
26/// Shared transport holding:
27/// * a single UDP socket
28/// * a map of remote_addr -> channel sender
29/// * a task to read incoming datagrams and dispatch them
30/// * a task to remove connection entries when Connections drop
31pub struct Transport {
32    socket: Arc<UdpSocket>,
33    connections: Mutex<HashMap<String, ConnectionInfo>>,
34    remove_channel_sender: tokio::sync::mpsc::UnboundedSender<String>,
35    stop_receive_token: tokio_util::sync::CancellationToken,
36}
37
38/// Logical connection bound to a remote UDP address. Receiving is done by
39/// reading from an mpsc channel populated by the Transport reader task.
40pub struct Connection {
41    transport: Arc<Transport>,
42    remote_address: String,
43    receiver: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
44}
45
46impl Transport {
47    async fn read_from_socket_loop(
48        socket: Arc<UdpSocket>,
49        stop_receive_token: tokio_util::sync::CancellationToken,
50        self_weak: std::sync::Weak<Transport>,
51    ) -> Result<()> {
52        loop {
53            let mut buf = vec![0u8; 1024];
54            let (n, addr) = {
55                tokio::select! {
56                    recv_resp = socket.recv_from(&mut buf) => recv_resp,
57                    _ = stop_receive_token.cancelled() => break
58                }
59            }?;
60            buf.resize(n, 0);
61            let self_strong = self_weak
62                .upgrade()
63                .context("weakpointer to self is gone - just stop")?;
64            let cons = self_strong.connections.lock().await;
65            if let Some(c) = cons.get(&addr.to_string()) {
66                _ = c.sender.send(buf).await;
67            }
68        }
69        Ok(())
70    }
71
72    async fn read_from_delete_queue_loop(
73        mut remove_channel_receiver: tokio::sync::mpsc::UnboundedReceiver<String>,
74        self_weak: std::sync::Weak<Transport>,
75    ) -> Result<()> {
76        loop {
77            let to_remove = remove_channel_receiver.recv().await;
78            match to_remove {
79                Some(to_remove) => {
80                    if to_remove.is_empty() {
81                        // Empty string used as sentinel to terminate this task.
82                        break;
83                    }
84                    let self_strong = self_weak
85                        .upgrade()
86                        .context("weak to self is gone - just stop")?;
87                    let mut cons = self_strong.connections.lock().await;
88                    _ = cons.remove(&to_remove);
89                }
90                None => break, // Sender dropped => shutdown
91            }
92        }
93        Ok(())
94    }
95
96    /// Bind a UDP socket and spawn background tasks.
97    pub async fn new(local: &str) -> Result<Arc<Self>> {
98        let socket = UdpSocket::bind(local).await?;
99        let (remove_channel_sender, remove_channel_receiver) =
100            tokio::sync::mpsc::unbounded_channel();
101        let stop_receive_token = tokio_util::sync::CancellationToken::new();
102        let stop_receive_token_child = stop_receive_token.child_token();
103        let o = Arc::new(Self {
104            socket: Arc::new(socket),
105            connections: Mutex::new(HashMap::new()),
106            remove_channel_sender,
107            stop_receive_token,
108        });
109        let self_weak = Arc::downgrade(&o.clone());
110        let socket = o.socket.clone();
111        tokio::spawn(async move {
112            _ = Self::read_from_socket_loop(socket, stop_receive_token_child, self_weak).await;
113        });
114        let self_weak = Arc::downgrade(&o.clone());
115        tokio::spawn(async move {
116            _ = Self::read_from_delete_queue_loop(remove_channel_receiver, self_weak).await;
117        });
118        Ok(o)
119    }
120
121    /// Create (or replace) a logical connection entry for the given remote address.
122    pub async fn create_connection(self: &Arc<Self>, remote: &str) -> Arc<dyn ConnectionTrait> {
123        let mut clock = self.connections.lock().await;
124        let (sender, receiver) = tokio::sync::mpsc::channel(32);
125        clock.insert(remote.to_owned(), ConnectionInfo { sender });
126        Arc::new(Connection {
127            transport: self.clone(),
128            remote_address: remote.to_owned(),
129            receiver: Mutex::new(receiver),
130        })
131    }
132}
133
134impl Connection {
135    /// Send a datagram to the remote address.
136    pub async fn send(&self, data: &[u8]) -> Result<()> {
137        self.transport
138            .socket
139            .send_to(data, &self.remote_address)
140            .await?;
141        Ok(())
142    }
143    /// Receive the next datagram for this connection (with timeout).
144    pub async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
145        let mut ch = self.receiver.lock().await;
146        let rec_future = ch.recv();
147        let with_timeout = tokio::time::timeout(timeout, rec_future);
148        with_timeout.await?.context("eof")
149    }
150}
151
152impl Drop for Transport {
153    fn drop(&mut self) {
154        _ = self.remove_channel_sender.send("".to_owned());
155        self.stop_receive_token.cancel();
156    }
157}
158
159#[async_trait::async_trait]
160impl ConnectionTrait for Connection {
161    async fn send(&self, data: &[u8]) -> Result<()> {
162        self.send(data).await
163    }
164    async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
165        self.receive(timeout).await
166    }
167}
168
169impl Drop for Connection {
170    fn drop(&mut self) {
171        _ = self
172            .transport
173            .remove_channel_sender
174            .send(self.remote_address.clone());
175    }
176}