Skip to main content

matc/
transport.rs

// Simple UDP transport abstraction that multiplexes datagrams by remote address
// into per-connection mpsc channels. Each Connection is a logical association
// identified solely by the peer's socket address string, normalized to the
// local socket's address family (see normalize_remote_for_socket).
4
5/// Returned by [`Connection::receive`] when the underlying mpsc channel has been
6/// closed (e.g. because the same remote address was re-registered via
7/// [`Transport::create_connection`]). Callers can detect this via
8/// `anyhow::Error::downcast_ref::<ConnectionClosed>()` and bail immediately
9/// instead of spinning on retransmit.
10#[derive(Debug)]
11pub struct ConnectionClosed;
12
13impl std::fmt::Display for ConnectionClosed {
14    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15        write!(f, "connection closed")
16    }
17}
18
19impl std::error::Error for ConnectionClosed {}
20
21use anyhow::{Context, Result};
22use std::{
23    collections::HashMap,
24    net::{IpAddr, SocketAddr},
25    sync::{
26        atomic::{AtomicU64, Ordering},
27        Arc,
28    },
29    time::Duration,
30};
31use tokio::{net::UdpSocket, sync::Mutex};
32
33/// Normalize a peer address string so that its address family matches the
34/// local socket. This canonical form is used both as the connection map key
35/// and as the destination for `send_to`, so inbound and outbound paths agree
36/// without per-packet conversion.
37///
38/// * V6 socket + V4 peer  -> V4-mapped V6 (`[::ffff:a.b.c.d]:port`), needed
39///   because the kernel rejects cross-family `sendto` on AF_INET6.
40/// * V4 socket + V4-mapped V6 peer -> plain V4.
41/// * Other combinations are returned unchanged (real V6 on a V4 socket will
42///   fail at `send_to`, which is the correct behavior).
43fn normalize_remote_for_socket(socket: &UdpSocket, remote: &str) -> String {
44    let Ok(parsed) = remote.parse::<SocketAddr>() else {
45        return remote.to_owned();
46    };
47    let Ok(local) = socket.local_addr() else {
48        return parsed.to_string();
49    };
50    let normalized = match (local.is_ipv6(), parsed) {
51        (true, SocketAddr::V4(v4)) => {
52            let mapped = v4.ip().to_ipv6_mapped();
53            SocketAddr::new(IpAddr::V6(mapped), v4.port())
54        }
55        (false, SocketAddr::V6(v6)) => {
56            if let Some(v4) = v6.ip().to_ipv4_mapped() {
57                SocketAddr::new(IpAddr::V4(v4), v6.port())
58            } else {
59                parsed
60            }
61        }
62        _ => parsed,
63    };
64    normalized.to_string()
65}
66
67/// Transport-agnostic connection: send and receive raw Matter messages.
68///
69/// Implement this for UDP ([`Connection`]) and BTP ([`crate::btp::BtpConnection`]).
70#[async_trait::async_trait]
71pub trait ConnectionTrait: Send + Sync {
72    async fn send(&self, data: &[u8]) -> Result<()>;
73    async fn receive(&self, timeout: Duration) -> Result<Vec<u8>>;
74    /// True for transports (BTP) that guarantee delivery so Matter-layer MRP
75    /// retransmit should be suppressed.  Default: false (UDP).
76    fn is_reliable(&self) -> bool { false }
77    /// Peer MRP intervals used for retransmission timing. Defaults to spec
78    /// defaults unless overridden via [`ConnectionTrait::set_mrp_params`].
79    fn mrp_params(&self) -> crate::mrp::MrpParameters { Default::default() }
80    /// Set peer MRP intervals (typically from its mDNS SII/SAI/SAT TXT records).
81    fn set_mrp_params(&self, _params: crate::mrp::MrpParameters) {}
82    /// Time since the last message was received from the peer, if any.
83    /// Used to select the active vs idle retransmission interval.
84    fn last_received_elapsed(&self) -> Option<Duration> { None }
85}
86
87#[derive(Debug, Clone)]
88struct ConnectionInfo {
89    sender: tokio::sync::mpsc::Sender<Vec<u8>>,
90    generation: u64,
91}
92
93/// Shared transport holding:
94/// * a single UDP socket
95/// * a map of remote_addr -> channel sender
96/// * a task to read incoming datagrams and dispatch them
97/// * a task to remove connection entries when Connections drop
98pub struct Transport {
99    socket: Arc<UdpSocket>,
100    connections: Mutex<HashMap<String, ConnectionInfo>>,
101    remove_channel_sender: tokio::sync::mpsc::UnboundedSender<(String, u64)>,
102    next_generation: AtomicU64,
103    stop_receive_token: tokio_util::sync::CancellationToken,
104}
105
106/// Logical connection bound to a remote UDP address. Receiving is done by
107/// reading from an mpsc channel populated by the Transport reader task.
108pub struct Connection {
109    transport: Arc<Transport>,
110    remote_address: String,
111    receiver: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
112    generation: u64,
113    mrp: std::sync::Mutex<crate::mrp::MrpParameters>,
114    created: tokio::time::Instant,
115    /// Milliseconds since `created` of the last received datagram; u64::MAX = never.
116    last_rx_ms: AtomicU64,
117}
118
119impl Transport {
120    async fn read_from_socket_loop(
121        socket: Arc<UdpSocket>,
122        stop_receive_token: tokio_util::sync::CancellationToken,
123        self_weak: std::sync::Weak<Transport>,
124    ) -> Result<()> {
125        loop {
126            let mut buf = vec![0u8; 2048];
127            let recv_result = {
128                tokio::select! {
129                    recv_resp = socket.recv_from(&mut buf) => recv_resp,
130                    _ = stop_receive_token.cancelled() => break
131                }
132            };
133            let (n, addr) = match recv_result {
134                Ok(r) => r,
135                Err(e) => {
136                    log::debug!("transport recv error (ignored): {:?}", e);
137                    continue;
138                }
139            };
140            buf.resize(n, 0);
141            let self_strong = self_weak
142                .upgrade()
143                .context("weakpointer to self is gone - just stop")?;
144            let cons = self_strong.connections.lock().await;
145            if let Some(c) = cons.get(&addr.to_string()) {
146                _ = c.sender.send(buf).await;
147            }
148        }
149        Ok(())
150    }
151
152    async fn read_from_delete_queue_loop(
153        mut remove_channel_receiver: tokio::sync::mpsc::UnboundedReceiver<(String, u64)>,
154        self_weak: std::sync::Weak<Transport>,
155    ) -> Result<()> {
156        loop {
157            let to_remove = remove_channel_receiver.recv().await;
158            match to_remove {
159                Some((addr, _gen)) if addr.is_empty() => {
160                    // Empty address is the shutdown sentinel.
161                    break;
162                }
163                Some((addr, gen)) => {
164                    let self_strong = self_weak
165                        .upgrade()
166                        .context("weak to self is gone - just stop")?;
167                    let mut cons = self_strong.connections.lock().await;
168                    // Only remove if the entry still belongs to this Connection.
169                    // A concurrent create_connection for the same address inserts a
170                    // newer generation, so the stale remove becomes a no-op.
171                    if cons.get(&addr).map(|c| c.generation) == Some(gen) {
172                        cons.remove(&addr);
173                    }
174                }
175                None => break, // Sender dropped => shutdown
176            }
177        }
178        Ok(())
179    }
180
181    /// Bind a UDP socket and spawn background tasks.
182    pub async fn new(local: &str) -> Result<Arc<Self>> {
183        let socket = UdpSocket::bind(local).await?;
184        let (remove_channel_sender, remove_channel_receiver) =
185            tokio::sync::mpsc::unbounded_channel();
186        let stop_receive_token = tokio_util::sync::CancellationToken::new();
187        let stop_receive_token_child = stop_receive_token.child_token();
188        let o = Arc::new(Self {
189            socket: Arc::new(socket),
190            connections: Mutex::new(HashMap::new()),
191            remove_channel_sender,
192            next_generation: AtomicU64::new(1),
193            stop_receive_token,
194        });
195        let self_weak = Arc::downgrade(&o.clone());
196        let socket = o.socket.clone();
197        tokio::spawn(async move {
198            _ = Self::read_from_socket_loop(socket, stop_receive_token_child, self_weak).await;
199        });
200        let self_weak = Arc::downgrade(&o.clone());
201        tokio::spawn(async move {
202            _ = Self::read_from_delete_queue_loop(remove_channel_receiver, self_weak).await;
203        });
204        Ok(o)
205    }
206
207    /// Create (or replace) a logical connection entry for the given remote address.
208    pub async fn create_connection(self: &Arc<Self>, remote: &str) -> Arc<dyn ConnectionTrait> {
209        let remote = normalize_remote_for_socket(&self.socket, remote);
210        let mut clock = self.connections.lock().await;
211        let generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
212        let (sender, receiver) = tokio::sync::mpsc::channel(32);
213        clock.insert(remote.to_owned(), ConnectionInfo { sender, generation });
214        Arc::new(Connection {
215            transport: self.clone(),
216            remote_address: remote,
217            receiver: Mutex::new(receiver),
218            generation,
219            mrp: std::sync::Mutex::new(Default::default()),
220            created: tokio::time::Instant::now(),
221            last_rx_ms: AtomicU64::new(u64::MAX),
222        })
223    }
224}
225
226impl Connection {
227    /// Send a datagram to the remote address.
228    pub async fn send(&self, data: &[u8]) -> Result<()> {
229        self.transport
230            .socket
231            .send_to(data, &self.remote_address)
232            .await?;
233        Ok(())
234    }
235    /// Receive the next datagram for this connection (with timeout).
236    ///
237    /// Returns `Err(ConnectionClosed)` (detectable via `downcast_ref`) when the
238    /// channel is permanently closed, distinct from a normal receive timeout.
239    pub async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
240        let mut ch = self.receiver.lock().await;
241        let rec_future = ch.recv();
242        let with_timeout = tokio::time::timeout(timeout, rec_future);
243        match with_timeout.await {
244            Err(_elapsed) => Err(anyhow::anyhow!("receive timeout")),
245            Ok(None) => Err(anyhow::Error::new(ConnectionClosed)),
246            Ok(Some(v)) => {
247                self.last_rx_ms
248                    .store(self.created.elapsed().as_millis() as u64, Ordering::Relaxed);
249                Ok(v)
250            }
251        }
252    }
253}
254
255impl Drop for Transport {
256    fn drop(&mut self) {
257        _ = self.remove_channel_sender.send(("".to_owned(), 0));
258        self.stop_receive_token.cancel();
259    }
260}
261
262#[async_trait::async_trait]
263impl ConnectionTrait for Connection {
264    async fn send(&self, data: &[u8]) -> Result<()> {
265        self.send(data).await
266    }
267    async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
268        self.receive(timeout).await
269    }
270    fn mrp_params(&self) -> crate::mrp::MrpParameters {
271        *self.mrp.lock().unwrap()
272    }
273    fn set_mrp_params(&self, params: crate::mrp::MrpParameters) {
274        *self.mrp.lock().unwrap() = params;
275    }
276    fn last_received_elapsed(&self) -> Option<Duration> {
277        let ms = self.last_rx_ms.load(Ordering::Relaxed);
278        if ms == u64::MAX {
279            return None;
280        }
281        Some(self.created.elapsed().saturating_sub(Duration::from_millis(ms)))
282    }
283}
284
285impl Drop for Connection {
286    fn drop(&mut self) {
287        _ = self
288            .transport
289            .remove_channel_sender
290            .send((self.remote_address.clone(), self.generation));
291    }
292}
293