matc/mdns2/
mod.rs

//! Minimal mDNS service with continuous discovery, record caching, and service registration.
//!
//! This provides a long-running service that:
//! - runs continuous discovery with periodic re-queries;
//! - caches discovered records with TTL-based expiration;
//! - registers local services and responds to incoming mDNS queries.
7
8mod dnssd;
9mod protocol;
10
11pub use dnssd::{MdnsEvent, ServiceRegistration};
12pub use protocol::{CachedRecord, RecordCache};
13
14use std::collections::HashSet;
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18
19use anyhow::Result;
20use tokio::net::UdpSocket;
21use tokio::sync::Mutex;
22use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
23use tokio_util::sync::CancellationToken;
24
25use crate::mdns;
26use dnssd::{PeriodicQuery, build_service_records, find_matching_services};
27use protocol::{
28    MDNS_ADDR_V4, MDNS_ADDR_V6, McastSocket, SendCommand, build_response,
29    create_multicast_socket_v4, create_multicast_socket_v6, get_local_ips, send_loop,
30};
31
/// Remove duplicate records in place, keeping the first occurrence of each.
///
/// The relative order of the surviving elements is preserved. Generic so it
/// works for `mdns::RR` as well as any other hashable, clonable value; each
/// element is cloned once into the tracking set.
fn dedup_records<T: std::hash::Hash + Eq + Clone>(records: &mut Vec<T>) {
    let mut seen = HashSet::with_capacity(records.len());
    records.retain(|r| seen.insert(r.clone()));
}
36
37struct MdnsServiceInner {
38    cache: RecordCache,
39    queries: Vec<PeriodicQuery>,
40    services: Vec<ServiceRegistration>,
41    local_ips_v4: Vec<Ipv4Addr>,
42    local_ips_v6: Vec<Ipv6Addr>,
43}
44
45/// Long-running mDNS service with discovery, caching, and service registration.
46pub struct MdnsService {
47    inner: Arc<Mutex<MdnsServiceInner>>,
48    send_tx: UnboundedSender<SendCommand>,
49    cancel: CancellationToken,
50}
51
52async fn recv_loop(
53    socket: Arc<UdpSocket>,
54    inner: Arc<Mutex<MdnsServiceInner>>,
55    send_tx: UnboundedSender<SendCommand>,
56    event_tx: UnboundedSender<MdnsEvent>,
57    cancel: CancellationToken,
58) {
59    let mut buf = vec![0u8; 9000];
60    loop {
61        let (n, addr) = tokio::select! {
62            result = socket.recv_from(&mut buf) => {
63                match result {
64                    Ok(v) => v,
65                    Err(e) => {
66                        log::debug!("mdns2 recv error: {}", e);
67                        continue;
68                    }
69                }
70            }
71            _ = cancel.cancelled() => return,
72        };
73
74        let data = &buf[..n];
75        let msg = match mdns::parse_dns(data, addr) {
76            Ok(m) => m,
77            Err(e) => {
78                log::trace!("mdns2: failed to parse packet from {}: {:?}", addr, e);
79                continue;
80            }
81        };
82
83        let is_response = msg.flags & 0x8000 != 0;
84
85        if is_response {
86            // Ingest all records into cache
87            let mut state = inner.lock().await;
88            let all_records: Vec<mdns::RR> = msg
89                .answers
90                .iter()
91                .chain(msg.additional.iter())
92                .cloned()
93                .collect();
94
95            let mut new_ptr_records = Vec::new();
96            for rr in &all_records {
97                state.cache.ingest(rr);
98                if rr.typ == mdns::TYPE_PTR {
99                    if let mdns::RRData::PTR(ref target) = rr.data {
100                        new_ptr_records.push((rr.name.clone(), target.clone()));
101                    }
102                }
103            }
104            for (name, target) in new_ptr_records {
105                let _ = event_tx.send(MdnsEvent::ServiceDiscovered {
106                    name,
107                    target,
108                    records: all_records.clone(),
109                });
110            }
111        } else {
112            // Incoming query — check if we have matching local services
113            let state = inner.lock().await;
114            if state.services.is_empty() {
115                continue;
116            }
117            let mut all_answers = Vec::new();
118            let mut all_additional = Vec::new();
119            for q in &msg.queries {
120                let (ans, add) = find_matching_services(
121                    &q.name,
122                    q.typ,
123                    &state.services,
124                    &state.local_ips_v4,
125                    &state.local_ips_v6,
126                );
127                all_answers.extend(ans);
128                all_additional.extend(add);
129            }
130            drop(state);
131
132            // Deduplicate records that matched multiple queries
133            dedup_records(&mut all_answers);
134            dedup_records(&mut all_additional);
135            // Don't repeat answer records in additional
136            all_additional.retain(|r| !all_answers.contains(r));
137
138            if !all_answers.is_empty() {
139                if let Ok(packet) = build_response(&all_answers, &all_additional) {
140                    let _ = send_tx.send(SendCommand::Multicast(packet));
141                }
142            }
143        }
144    }
145}
146
147async fn periodic_loop(
148    inner: Arc<Mutex<MdnsServiceInner>>,
149    send_tx: UnboundedSender<SendCommand>,
150    event_tx: UnboundedSender<MdnsEvent>,
151    cancel: CancellationToken,
152) {
153    let mut interval = tokio::time::interval(Duration::from_secs(1));
154    loop {
155        tokio::select! {
156            _ = interval.tick() => {}
157            _ = cancel.cancelled() => return,
158        }
159
160        let mut state = inner.lock().await;
161
162        // Evict expired cache entries
163        let expired = state.cache.evict_expired();
164        for (name, rtype) in expired {
165            let _ = event_tx.send(MdnsEvent::ServiceExpired { name, rtype });
166        }
167
168        // Send due queries
169        let now = Instant::now();
170        let mut packets = Vec::new();
171        for q in &mut state.queries {
172            if now.duration_since(q.last_sent) >= q.interval {
173                if let Ok(pkt) = mdns::create_query(&q.label, q.qtype) {
174                    packets.push(pkt);
175                }
176                q.last_sent = now;
177            }
178        }
179        drop(state);
180
181        for pkt in packets {
182            let _ = send_tx.send(SendCommand::Multicast(pkt));
183        }
184
185        // Refresh local IPs periodically (cheap operation)
186        let (v4, v6) = get_local_ips();
187        let mut state = inner.lock().await;
188        state.local_ips_v4 = v4;
189        state.local_ips_v6 = v6;
190    }
191}
192
193impl MdnsService {
194    /// Create a new mDNS service. Returns the service handle and a receiver for events.
195    pub async fn new() -> Result<(Arc<Self>, UnboundedReceiver<MdnsEvent>)> {
196        let (event_tx, event_rx) = mpsc::unbounded_channel();
197        let (send_tx, send_rx) = mpsc::unbounded_channel();
198        let cancel = CancellationToken::new();
199
200        let (v4, v6) = get_local_ips();
201        let inner = Arc::new(Mutex::new(MdnsServiceInner {
202            cache: RecordCache::new(),
203            queries: Vec::new(),
204            services: Vec::new(),
205            local_ips_v4: v4,
206            local_ips_v6: v6,
207        }));
208
209        // Create sockets
210        let mut mcast_sockets: Vec<McastSocket> = Vec::new();
211
212        // IPv4
213        match create_multicast_socket_v4() {
214            Ok(std_sock) => match UdpSocket::from_std(std_sock) {
215                Ok(s) => mcast_sockets.push(McastSocket {
216                    sock: Arc::new(s),
217                    multicast_addr: MDNS_ADDR_V4,
218                }),
219                Err(e) => log::warn!("mdns2: failed to wrap v4 socket: {}", e),
220            },
221            Err(e) => log::warn!("mdns2: failed to create v4 socket: {}", e),
222        }
223
224        // IPv6 — one per interface
225        if let Ok(ifaces) = if_addrs::get_if_addrs() {
226            let mut seen_indices = std::collections::HashSet::new();
227            for iface in ifaces {
228                if !iface.ip().is_ipv6() {
229                    continue;
230                }
231                if let Some(idx) = iface.index {
232                    if !seen_indices.insert(idx) {
233                        continue;
234                    }
235                    match create_multicast_socket_v6(idx) {
236                        Ok(std_sock) => match UdpSocket::from_std(std_sock) {
237                            Ok(s) => mcast_sockets.push(McastSocket {
238                                sock: Arc::new(s),
239                                multicast_addr: MDNS_ADDR_V6,
240                            }),
241                            Err(e) => {
242                                log::debug!("mdns2: failed to wrap v6 socket idx={}: {}", idx, e)
243                            }
244                        },
245                        Err(e) => {
246                            log::debug!("mdns2: failed to create v6 socket idx={}: {}", idx, e)
247                        }
248                    }
249                }
250            }
251        }
252
253        if mcast_sockets.is_empty() {
254            anyhow::bail!("mdns2: no sockets could be created");
255        }
256
257        // Spawn recv loops (one per socket)
258        for ms in &mcast_sockets {
259            let sock = ms.sock.clone();
260            let inner = inner.clone();
261            let send_tx = send_tx.clone();
262            let event_tx = event_tx.clone();
263            let cancel = cancel.child_token();
264            tokio::spawn(async move {
265                recv_loop(sock, inner, send_tx, event_tx, cancel).await;
266            });
267        }
268
269        // Spawn periodic loop
270        {
271            let inner = inner.clone();
272            let send_tx = send_tx.clone();
273            let event_tx = event_tx.clone();
274            let cancel = cancel.child_token();
275            tokio::spawn(async move {
276                periodic_loop(inner, send_tx, event_tx, cancel).await;
277            });
278        }
279
280        // Spawn send loop
281        {
282            let cancel = cancel.child_token();
283            tokio::spawn(async move {
284                send_loop(mcast_sockets, send_rx, cancel).await;
285            });
286        }
287
288        let service = Arc::new(MdnsService {
289            inner,
290            send_tx,
291            cancel,
292        });
293
294        Ok((service, event_rx))
295    }
296
297    /// Add a periodic query. The query will be sent immediately, then every interval.
298    pub async fn add_query(&self, label: &str, qtype: u16, interval: Duration) {
299        let mut state = self.inner.lock().await;
300        // Send immediately
301        let sent_at = Instant::now();
302        if let Ok(pkt) = mdns::create_query(label, qtype) {
303            let _ = self.send_tx.send(SendCommand::Multicast(pkt));
304        }
305        state.queries.push(PeriodicQuery {
306            label: label.to_owned(),
307            qtype,
308            interval,
309            last_sent: sent_at,
310        });
311    }
312
313    /// Remove a periodic query by label.
314    pub async fn remove_query(&self, label: &str) {
315        let mut state = self.inner.lock().await;
316        state.queries.retain(|q| q.label != label);
317    }
318
319    /// Register a local service to be advertised.
320    pub async fn register_service(&self, reg: ServiceRegistration) {
321        let mut state = self.inner.lock().await;
322        state.services.push(reg);
323    }
324
325    /// Unregister a local service. Sends a goodbye (TTL=0) for the service records.
326    pub async fn unregister_service(&self, instance: &str, service_type: &str) {
327        let mut state = self.inner.lock().await;
328        let idx = state
329            .services
330            .iter()
331            .position(|s| s.instance_name == instance && s.service_type == service_type);
332        if let Some(idx) = idx {
333            let reg = state.services.remove(idx);
334            // Build goodbye records (TTL=0)
335            let mut goodbye_records =
336                build_service_records(&reg, &state.local_ips_v4, &state.local_ips_v6);
337            for rr in &mut goodbye_records {
338                rr.ttl = 0;
339            }
340            drop(state);
341            if let Ok(pkt) = build_response(&goodbye_records, &[]) {
342                let _ = self.send_tx.send(SendCommand::Multicast(pkt));
343            }
344        }
345    }
346
347    /// Send a gratuitous announcement of all registered services.
348    pub async fn announce(&self) {
349        let state = self.inner.lock().await;
350        let mut all_answers = Vec::new();
351        let mut all_additional = Vec::new();
352        for reg in &state.services {
353            let records = build_service_records(reg, &state.local_ips_v4, &state.local_ips_v6);
354            // PTR goes as answer, everything else as additional
355            for r in records {
356                if r.typ == mdns::TYPE_PTR {
357                    all_answers.push(r);
358                } else {
359                    all_additional.push(r);
360                }
361            }
362        }
363        drop(state);
364
365        if !all_answers.is_empty() {
366            if let Ok(pkt) = build_response(&all_answers, &all_additional) {
367                let _ = self.send_tx.send(SendCommand::Multicast(pkt));
368            }
369        }
370    }
371
372    /// Lookup cached records by name and type.
373    pub async fn lookup(&self, name: &str, qtype: u16) -> Vec<mdns::RR> {
374        let state = self.inner.lock().await;
375        if qtype == mdns::QTYPE_ANY {
376            state.cache.lookup_name(name)
377        } else {
378            state.cache.lookup(name, qtype)
379        }
380    }
381
382    pub async fn active_lookup(&self, name: &str, qtype: u16) {
383        if let Ok(pkt) = mdns::create_query(name, qtype) {
384            let _ = self.send_tx.send(SendCommand::Multicast(pkt));
385        }
386    }
387
388    /// Shut down all background tasks.
389    pub fn shutdown(&self) {
390        self.cancel.cancel();
391    }
392}
393
394impl Drop for MdnsService {
395    fn drop(&mut self) {
396        self.cancel.cancel();
397    }
398}