matc/mdns2/
protocol.rs

//! Basic mDNS protocol support: record caching, wire-format encoding, multicast sockets, and the send loop.
2
3use std::collections::HashMap;
4use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use anyhow::Result;
9use byteorder::{BigEndian, WriteBytesExt};
10use socket2::{Domain, Protocol, Type};
11use tokio::net::UdpSocket;
12use tokio::sync::mpsc::UnboundedReceiver;
13use tokio_util::sync::CancellationToken;
14
15use crate::mdns;
16
/// IPv4 mDNS multicast group and port (RFC 6762), in `host:port` string form.
pub(super) const MDNS_ADDR_V4: &str = "224.0.0.251:5353";
/// IPv6 link-local mDNS multicast group and port (RFC 6762), in `[host]:port` string form.
pub(super) const MDNS_ADDR_V6: &str = "[ff02::fb]:5353";
19
20#[derive(Debug, Clone)]
21pub struct CachedRecord {
22    pub rr: mdns::RR,
23    pub received_at: Instant,
24    pub ttl: Duration,
25}
26
27impl CachedRecord {
28    fn is_expired(&self) -> bool {
29        self.received_at.elapsed() > self.ttl
30    }
31}
32
33/// Cache of DNS resource records, keyed by "lowercase name, record type".
34pub struct RecordCache {
35    pub(super) entries: HashMap<(String, u16), Vec<CachedRecord>>,
36}
37
38impl RecordCache {
39    pub fn new() -> Self {
40        Self {
41            entries: HashMap::new(),
42        }
43    }
44
45    /// Insert or update records from a DNS response.
46    /// TTL=0 removes the specific record whose rdata matches (RFC 6762 10.1).
47    pub fn ingest(&mut self, rr: &mdns::RR) -> bool {
48        let key = (rr.name.to_lowercase(), rr.typ);
49        if rr.ttl == 0 {
50            if let Some(vec) = self.entries.get_mut(&key) {
51                vec.retain(|c| c.rr.rdata != rr.rdata);
52                if vec.is_empty() {
53                    self.entries.remove(&key);
54                }
55            }
56            return false;
57        }
58        let cached = CachedRecord {
59            rr: rr.clone(),
60            received_at: Instant::now(),
61            ttl: Duration::from_secs(rr.ttl as u64),
62        };
63        let vec = self.entries.entry(key).or_default();
64        // Replace if same rdata, otherwise add
65        if let Some(existing) = vec.iter_mut().find(|c| c.rr.rdata == rr.rdata) {
66            *existing = cached;
67            false
68        } else {
69            vec.push(cached);
70            true
71        }
72    }
73
74    /// Remove expired entries. Returns list of (name, type) keys that were fully removed.
75    pub fn evict_expired(&mut self) -> Vec<(String, u16)> {
76        let mut expired_keys = Vec::new();
77        self.entries.retain(|key, records| {
78            records.retain(|c| !c.is_expired());
79            if records.is_empty() {
80                expired_keys.push(key.clone());
81                false
82            } else {
83                true
84            }
85        });
86        expired_keys
87    }
88
89    /// Lookup non-expired records by exact (lowercase name, type).
90    pub fn lookup(&self, name: &str, qtype: u16) -> Vec<mdns::RR> {
91        let key = (name.to_lowercase(), qtype);
92        self.entries
93            .get(&key)
94            .map(|v| {
95                v.iter()
96                    .filter(|c| !c.is_expired())
97                    .map(|c| c.rr.clone())
98                    .collect()
99            })
100            .unwrap_or_default()
101    }
102
103    /// Lookup all non-expired records matching a name (any type).
104    pub fn lookup_name(&self, name: &str) -> Vec<mdns::RR> {
105        let lower = name.to_lowercase();
106        self.entries
107            .iter()
108            .filter(|((n, _), _)| *n == lower)
109            .flat_map(|(_, v)| v.iter().filter(|c| !c.is_expired()).map(|c| c.rr.clone()))
110            .collect()
111    }
112}
113
114impl Default for RecordCache {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
/// A datagram queued for transmission by [`send_loop`].
// NOTE(review): `dead_code` is allowed presumably because not every variant is
// constructed yet; confirm, and drop the attribute once both are in use.
#[allow(dead_code)]
pub(super) enum SendCommand {
    /// Send the bytes to each socket's multicast group address, on every socket.
    Multicast(Vec<u8>),
    /// Send the bytes to one specific address (e.g. a unicast response); the
    /// sockets are tried in order until one send succeeds.
    Unicast(Vec<u8>, std::net::SocketAddr),
}
127
128/// Encode a single resource record to wire format.
129pub(super) fn encode_rr(rr: &mdns::RR, out: &mut Vec<u8>) -> Result<()> {
130    mdns::encode_label(&rr.name, out)?;
131    out.write_u16::<BigEndian>(rr.typ)?;
132    out.write_u16::<BigEndian>(rr.class)?;
133    out.write_u32::<BigEndian>(rr.ttl)?;
134
135    if rr.typ == mdns::TYPE_SRV {
136        // SRV: priority(2) + weight(2) + port(2) + target(variable)
137        let mut rdata = Vec::new();
138        // First 6 bytes are priority, weight, port from existing rdata if available
139        if rr.rdata.len() >= 6 {
140            rdata.extend_from_slice(&rr.rdata[..6]);
141        } else {
142            rdata.write_u16::<BigEndian>(0)?; // priority
143            rdata.write_u16::<BigEndian>(0)?; // weight
144            rdata.write_u16::<BigEndian>(0)?; // port
145        }
146        // If there's a target field, re-encode it as a label
147        if let Some(ref target) = rr.target {
148            // Rebuild rdata with the target label
149            let mut srv_rdata = Vec::new();
150            srv_rdata.extend_from_slice(&rdata[..6]);
151            mdns::encode_label(target.trim_end_matches('.'), &mut srv_rdata)?;
152            out.write_u16::<BigEndian>(srv_rdata.len() as u16)?;
153            out.extend_from_slice(&srv_rdata);
154        } else {
155            out.write_u16::<BigEndian>(rr.rdata.len() as u16)?;
156            out.extend_from_slice(&rr.rdata);
157        }
158    } else {
159        out.write_u16::<BigEndian>(rr.rdata.len() as u16)?;
160        out.extend_from_slice(&rr.rdata);
161    }
162    Ok(())
163}
164
165/// Build an mDNS response packet from answer and additional record lists.
166pub(super) fn build_response(answers: &[mdns::RR], additional: &[mdns::RR]) -> Result<Vec<u8>> {
167    let mut out = Vec::with_capacity(512);
168    out.write_u16::<BigEndian>(0)?; // transaction id
169    out.write_u16::<BigEndian>(0x8400)?; // flags: response, authoritative
170    out.write_u16::<BigEndian>(0)?; // questions
171    out.write_u16::<BigEndian>(answers.len() as u16)?;
172    out.write_u16::<BigEndian>(0)?; // authority
173    out.write_u16::<BigEndian>(additional.len() as u16)?;
174
175    for rr in answers {
176        encode_rr(rr, &mut out)?;
177    }
178    for rr in additional {
179        encode_rr(rr, &mut out)?;
180    }
181    Ok(out)
182}
183
184pub(super) fn create_multicast_socket_v4() -> Result<std::net::UdpSocket> {
185    let sock = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
186    sock.set_reuse_address(true)?;
187    #[cfg(not(target_os = "windows"))]
188    sock.set_reuse_port(true)?;
189    let addr: SocketAddrV4 = "0.0.0.0:5353".parse()?;
190    sock.bind(&socket2::SockAddr::from(addr))?;
191    let maddr: Ipv4Addr = "224.0.0.251".parse()?;
192    sock.join_multicast_v4(&maddr, &Ipv4Addr::UNSPECIFIED)?;
193    sock.set_nonblocking(true)?;
194    Ok(sock.into())
195}
196
197pub(super) fn create_multicast_socket_v6(interface: u32) -> Result<std::net::UdpSocket> {
198    let sock = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
199    sock.set_reuse_address(true)?;
200    #[cfg(not(target_os = "windows"))]
201    sock.set_reuse_port(true)?;
202    let addr: SocketAddrV6 = "[::]:5353".parse()?;
203    sock.bind(&socket2::SockAddr::from(addr))?;
204    let maddr: Ipv6Addr = "ff02::fb".parse()?;
205    sock.join_multicast_v6(&maddr, interface)?;
206    sock.set_multicast_if_v6(interface)?;
207    sock.set_nonblocking(true)?;
208    Ok(sock.into())
209}
210
211// TODO: ipv6 is disabled for testing
212pub(super) fn get_local_ips() -> (Vec<Ipv4Addr>, Vec<Ipv6Addr>) {
213    let mut v4 = Vec::new();
214    let v6 = Vec::new();
215    if let Ok(ifaces) = if_addrs::get_if_addrs() {
216        for iface in ifaces {
217            match iface.ip() {
218                std::net::IpAddr::V4(ip) if !ip.is_loopback() => v4.push(ip),
219                //std::net::IpAddr::V6(ip) if !ip.is_loopback() => v6.push(ip),
220                _ => {}
221            }
222        }
223    }
224    (v4, v6)
225}
226
/// A bound async socket paired with the multicast destination that [`send_loop`]
/// uses for it.
pub(super) struct McastSocket {
    /// Shared tokio socket; both multicast sends and unicast attempts go through it.
    pub sock: Arc<UdpSocket>,
    /// Destination string passed to `send_to` for [`SendCommand::Multicast`]
    /// (presumably `MDNS_ADDR_V4` or `MDNS_ADDR_V6`; confirm at the construction site).
    pub multicast_addr: &'static str,
}
231
232pub(super) async fn send_loop(
233    sockets: Vec<McastSocket>,
234    mut rx: UnboundedReceiver<SendCommand>,
235    cancel: CancellationToken,
236) {
237    loop {
238        let cmd = tokio::select! {
239            cmd = rx.recv() => {
240                match cmd {
241                    Some(c) => c,
242                    None => return,
243                }
244            }
245            _ = cancel.cancelled() => return,
246        };
247
248        match cmd {
249            SendCommand::Multicast(data) => {
250                for ms in &sockets {
251                    let _ = ms.sock.send_to(&data, ms.multicast_addr).await;
252                }
253            }
254            SendCommand::Unicast(data, addr) => {
255                // Send on first socket that succeeds
256                for ms in &sockets {
257                    if ms.sock.send_to(&data, addr).await.is_ok() {
258                        break;
259                    }
260                }
261            }
262        }
263    }
264}
265