Skip to main content

matc/mdns2/
mod.rs

//! Minimal mDNS service with continuous discovery, record caching, and service registration.
//!
//! This provides a long-running service that:
//! - Runs continuous discovery with periodic re-queries
//! - Caches discovered records with TTL-based expiration
//! - Registers local services and responds to incoming mDNS queries
//! - Emits discovery events via a broadcast channel; call [`MdnsService::subscribe`] to get
//!   an independent event stream per caller, supporting concurrent discovery operations

10mod dnssd;
11mod protocol;
12
13pub use dnssd::{MdnsEvent, ServiceRegistration};
14pub use protocol::{CachedRecord, RecordCache};
15
16use std::collections::HashSet;
17use std::net::{Ipv4Addr, Ipv6Addr};
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20
21use anyhow::Result;
22use tokio::net::UdpSocket;
23use tokio::sync::broadcast;
24use tokio::sync::Mutex;
25use tokio::sync::mpsc::{self, UnboundedSender};
26use tokio_util::sync::CancellationToken;
27
28use crate::mdns;
29use dnssd::{PeriodicQuery, build_service_records, find_matching_services};
30use protocol::{
31    MDNS_ADDR_V4, MDNS_ADDR_V6, McastSocket, SendCommand, build_response,
32    create_multicast_socket_v4, create_multicast_socket_v6, get_local_ips, send_loop,
33};
34
/// Remove duplicate entries from `records` in place.
///
/// The first occurrence of each value is kept and the relative order of the surviving
/// entries is preserved. Generic over any hashable, cloneable type; in this module it is
/// used on `Vec<mdns::RR>` to drop records that matched more than one question.
fn dedup_records<T: Clone + Eq + std::hash::Hash>(records: &mut Vec<T>) {
    // `retain` holds `records` mutably, so the seen-set must own its keys (hence the clone).
    let mut seen = HashSet::with_capacity(records.len());
    records.retain(|r| seen.insert(r.clone()));
}
39
40struct MdnsServiceInner {
41    cache: RecordCache,
42    queries: Vec<PeriodicQuery>,
43    services: Vec<ServiceRegistration>,
44    local_ips_v4: Vec<Ipv4Addr>,
45    local_ips_v6: Vec<Ipv6Addr>,
46}
47
/// Capacity of the discovery-event broadcast channel.
///
/// A subscriber that falls more than this many events behind loses the oldest ones and
/// observes `RecvError::Lagged` on its next receive.
const EVENT_CHANNEL_CAPACITY: usize = 256;
49
50/// Long-running mDNS service with discovery, caching, and service registration.
51pub struct MdnsService {
52    inner: Arc<Mutex<MdnsServiceInner>>,
53    send_tx: UnboundedSender<SendCommand>,
54    event_tx: broadcast::Sender<MdnsEvent>,
55    cancel: CancellationToken,
56}
57
58async fn recv_loop(
59    socket: Arc<UdpSocket>,
60    inner: Arc<Mutex<MdnsServiceInner>>,
61    send_tx: UnboundedSender<SendCommand>,
62    event_tx: broadcast::Sender<MdnsEvent>,
63    cancel: CancellationToken,
64) {
65    let mut buf = vec![0u8; 9000];
66    loop {
67        let (n, addr) = tokio::select! {
68            result = socket.recv_from(&mut buf) => {
69                match result {
70                    Ok(v) => v,
71                    Err(e) => {
72                        log::debug!("mdns2 recv error: {}", e);
73                        continue;
74                    }
75                }
76            }
77            _ = cancel.cancelled() => return,
78        };
79
80        let data = &buf[..n];
81        let msg = match mdns::parse_dns(data, addr) {
82            Ok(m) => m,
83            Err(e) => {
84                log::trace!("mdns2: failed to parse packet from {}: {:?}", addr, e);
85                continue;
86            }
87        };
88
89        let is_response = msg.flags & 0x8000 != 0;
90
91        if is_response {
92            // Ingest all records into cache
93            let mut state = inner.lock().await;
94            let all_records: Vec<mdns::RR> = msg
95                .answers
96                .iter()
97                .chain(msg.additional.iter())
98                .cloned()
99                .collect();
100
101            let mut new_ptr_records = Vec::new();
102            for rr in &all_records {
103                state.cache.ingest(rr);
104                if rr.typ == mdns::TYPE_PTR {
105                    if let mdns::RRData::PTR(ref target) = rr.data {
106                        new_ptr_records.push((rr.name.clone(), target.clone()));
107                    }
108                }
109            }
110            for (name, target) in new_ptr_records {
111                let _ = event_tx.send(MdnsEvent::ServiceDiscovered {
112                    name,
113                    target,
114                    records: all_records.clone(),
115                });
116            }
117        } else {
118            let state = inner.lock().await;
119            if state.services.is_empty() {
120                continue;
121            }
122            let mut all_answers = Vec::new();
123            let mut all_additional = Vec::new();
124            for q in &msg.queries {
125                let (ans, add) = find_matching_services(
126                    &q.name,
127                    q.typ,
128                    &state.services,
129                    &state.local_ips_v4,
130                    &state.local_ips_v6,
131                );
132                all_answers.extend(ans);
133                all_additional.extend(add);
134            }
135            drop(state);
136
137            // Deduplicate records that matched multiple queries
138            dedup_records(&mut all_answers);
139            dedup_records(&mut all_additional);
140            // Don't repeat answer records in additional
141            all_additional.retain(|r| !all_answers.contains(r));
142
143            if !all_answers.is_empty() {
144                if let Ok(packet) = build_response(&all_answers, &all_additional) {
145                    let _ = send_tx.send(SendCommand::Multicast(packet));
146                }
147            }
148        }
149    }
150}
151
152async fn periodic_loop(
153    inner: Arc<Mutex<MdnsServiceInner>>,
154    send_tx: UnboundedSender<SendCommand>,
155    event_tx: broadcast::Sender<MdnsEvent>,
156    cancel: CancellationToken,
157) {
158    let mut interval = tokio::time::interval(Duration::from_secs(1));
159    loop {
160        tokio::select! {
161            _ = interval.tick() => {}
162            _ = cancel.cancelled() => return,
163        }
164
165        let mut state = inner.lock().await;
166
167        // Evict expired cache entries
168        let expired = state.cache.evict_expired();
169        for (name, rtype) in expired {
170            let _ = event_tx.send(MdnsEvent::ServiceExpired { name, rtype });
171        }
172
173        // Send due queries
174        let now = Instant::now();
175        let mut packets = Vec::new();
176        for q in &mut state.queries {
177            if now.duration_since(q.last_sent) >= q.interval {
178                if let Ok(pkt) = mdns::create_query(&q.label, q.qtype) {
179                    packets.push(pkt);
180                }
181                q.last_sent = now;
182            }
183        }
184        drop(state);
185
186        for pkt in packets {
187            let _ = send_tx.send(SendCommand::Multicast(pkt));
188        }
189
190        // Refresh local IPs periodically (cheap operation)
191        let (v4, v6) = get_local_ips();
192        let mut state = inner.lock().await;
193        state.local_ips_v4 = v4;
194        state.local_ips_v6 = v6;
195    }
196}
197
198impl MdnsService {
199    /// Create a new mDNS service.
200    ///
201    /// Call [`subscribe`](Self::subscribe) on the returned handle to receive discovery events.
202    /// Multiple independent subscribers may receive events concurrently.
203    pub async fn new() -> Result<Arc<Self>> {
204        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
205        let (send_tx, send_rx) = mpsc::unbounded_channel();
206        let cancel = CancellationToken::new();
207
208        let (v4, v6) = get_local_ips();
209        let inner = Arc::new(Mutex::new(MdnsServiceInner {
210            cache: RecordCache::new(),
211            queries: Vec::new(),
212            services: Vec::new(),
213            local_ips_v4: v4,
214            local_ips_v6: v6,
215        }));
216
217        // Create sockets
218        let mut mcast_sockets: Vec<McastSocket> = Vec::new();
219
220        // IPv4
221        match create_multicast_socket_v4() {
222            Ok(std_sock) => match UdpSocket::from_std(std_sock) {
223                Ok(s) => mcast_sockets.push(McastSocket {
224                    sock: Arc::new(s),
225                    multicast_addr: MDNS_ADDR_V4,
226                }),
227                Err(e) => log::warn!("mdns2: failed to wrap v4 socket: {}", e),
228            },
229            Err(e) => log::warn!("mdns2: failed to create v4 socket: {}", e),
230        }
231
232        if let Ok(ifaces) = if_addrs::get_if_addrs() {
233            let mut seen_indices = std::collections::HashSet::new();
234            for iface in ifaces {
235                if !iface.ip().is_ipv6() {
236                    continue;
237                }
238                if let Some(idx) = iface.index {
239                    if !seen_indices.insert(idx) {
240                        continue;
241                    }
242                    match create_multicast_socket_v6(idx) {
243                        Ok(std_sock) => match UdpSocket::from_std(std_sock) {
244                            Ok(s) => mcast_sockets.push(McastSocket {
245                                sock: Arc::new(s),
246                                multicast_addr: MDNS_ADDR_V6,
247                            }),
248                            Err(e) => {
249                                log::debug!("mdns2: failed to wrap v6 socket idx={}: {}", idx, e)
250                            }
251                        },
252                        Err(e) => {
253                            log::debug!("mdns2: failed to create v6 socket idx={}: {}", idx, e)
254                        }
255                    }
256                }
257            }
258        }
259
260        if mcast_sockets.is_empty() {
261            anyhow::bail!("mdns2: no sockets could be created");
262        }
263
264        // Spawn recv loops (one per socket)
265        for ms in &mcast_sockets {
266            let sock = ms.sock.clone();
267            let inner = inner.clone();
268            let send_tx = send_tx.clone();
269            let event_tx = event_tx.clone();
270            let cancel = cancel.child_token();
271            tokio::spawn(async move {
272                recv_loop(sock, inner, send_tx, event_tx, cancel).await;
273            });
274        }
275
276        // Spawn periodic loop
277        {
278            let inner = inner.clone();
279            let send_tx = send_tx.clone();
280            let event_tx = event_tx.clone();
281            let cancel = cancel.child_token();
282            tokio::spawn(async move {
283                periodic_loop(inner, send_tx, event_tx, cancel).await;
284            });
285        }
286
287        // Spawn send loop
288        {
289            let cancel = cancel.child_token();
290            tokio::spawn(async move {
291                send_loop(mcast_sockets, send_rx, cancel).await;
292            });
293        }
294
295        let service = Arc::new(MdnsService {
296            inner,
297            send_tx,
298            event_tx,
299            cancel,
300        });
301
302        Ok(service)
303    }
304
305    /// Subscribe to discovery events.
306    ///
307    /// Returns an independent [`broadcast::Receiver`]; each subscriber receives every event.
308    /// Subscribe before calling [`active_lookup`](Self::active_lookup) to avoid missing
309    /// responses that arrive before the next `recv()` call.
310    /// On lag (`RecvError::Lagged`), log a warning and keep draining — events are recoverable
311    /// by re-issuing [`active_lookup`](Self::active_lookup).
312    pub fn subscribe(&self) -> broadcast::Receiver<MdnsEvent> {
313        self.event_tx.subscribe()
314    }
315
316    /// Add a periodic query. The query will be sent immediately, then every interval.
317    pub async fn add_query(&self, label: &str, qtype: u16, interval: Duration) {
318        let mut state = self.inner.lock().await;
319        // Send immediately
320        let sent_at = Instant::now();
321        if let Ok(pkt) = mdns::create_query(label, qtype) {
322            let _ = self.send_tx.send(SendCommand::Multicast(pkt));
323        }
324        state.queries.push(PeriodicQuery {
325            label: label.to_owned(),
326            qtype,
327            interval,
328            last_sent: sent_at,
329        });
330    }
331
332    /// Remove a periodic query by label.
333    pub async fn remove_query(&self, label: &str) {
334        let mut state = self.inner.lock().await;
335        state.queries.retain(|q| q.label != label);
336    }
337
338    /// Register a local service to be advertised.
339    pub async fn register_service(&self, reg: ServiceRegistration) {
340        let mut state = self.inner.lock().await;
341        state.services.push(reg);
342    }
343
344    /// Unregister a local service. Sends a goodbye (TTL=0) for the service records.
345    pub async fn unregister_service(&self, instance: &str, service_type: &str) {
346        let mut state = self.inner.lock().await;
347        let idx = state
348            .services
349            .iter()
350            .position(|s| s.instance_name == instance && s.service_type == service_type);
351        if let Some(idx) = idx {
352            let reg = state.services.remove(idx);
353            // Build goodbye records (TTL=0)
354            let svc_v4 = reg.ips_v4.as_deref().unwrap_or(&state.local_ips_v4);
355            let svc_v6 = reg.ips_v6.as_deref().unwrap_or(&state.local_ips_v6);
356            let mut goodbye_records = build_service_records(&reg, svc_v4, svc_v6);
357            for rr in &mut goodbye_records {
358                rr.ttl = 0;
359            }
360            drop(state);
361            if let Ok(pkt) = build_response(&goodbye_records, &[]) {
362                let _ = self.send_tx.send(SendCommand::Multicast(pkt));
363            }
364        }
365    }
366
367    /// Send a gratuitous announcement of all registered services.
368    pub async fn announce(&self) {
369        let state = self.inner.lock().await;
370        let mut all_answers = Vec::new();
371        let mut all_additional = Vec::new();
372        for reg in &state.services {
373            let svc_v4 = reg.ips_v4.as_deref().unwrap_or(&state.local_ips_v4);
374            let svc_v6 = reg.ips_v6.as_deref().unwrap_or(&state.local_ips_v6);
375            let records = build_service_records(reg, svc_v4, svc_v6);
376            // PTR goes as answer, everything else as additional
377            for r in records {
378                if r.typ == mdns::TYPE_PTR {
379                    all_answers.push(r);
380                } else {
381                    all_additional.push(r);
382                }
383            }
384        }
385        drop(state);
386
387        if !all_answers.is_empty() {
388            if let Ok(pkt) = build_response(&all_answers, &all_additional) {
389                let _ = self.send_tx.send(SendCommand::Multicast(pkt));
390            }
391        }
392    }
393
394    /// Lookup cached records by name and type.
395    pub async fn lookup(&self, name: &str, qtype: u16) -> Vec<mdns::RR> {
396        let state = self.inner.lock().await;
397        if qtype == mdns::QTYPE_ANY {
398            state.cache.lookup_name(name)
399        } else {
400            state.cache.lookup(name, qtype)
401        }
402    }
403
404    pub async fn active_lookup(&self, name: &str, qtype: u16) {
405        if let Ok(pkt) = mdns::create_query(name, qtype) {
406            let _ = self.send_tx.send(SendCommand::Multicast(pkt));
407        }
408    }
409
410    /// Shut down all background tasks.
411    pub fn shutdown(&self) {
412        self.cancel.cancel();
413    }
414}
415
416impl Drop for MdnsService {
417    fn drop(&mut self) {
418        self.cancel.cancel();
419    }
420}