matc/mdns2/
mod.rs

//! Minimal mDNS service with continuous discovery, record caching, and service registration.
//!
//! This provides a long-running service that:
//! - runs continuous discovery with periodic re-queries,
//! - caches discovered records with TTL-based expiration, and
//! - registers local services and responds to incoming mDNS queries.
7
8mod dnssd;
9mod protocol;
10
11pub use dnssd::{MdnsEvent, ServiceRegistration};
12pub use protocol::{CachedRecord, RecordCache};
13
14use std::collections::HashSet;
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18
19use anyhow::Result;
20use tokio::net::UdpSocket;
21use tokio::sync::Mutex;
22use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
23use tokio_util::sync::CancellationToken;
24
25use crate::mdns;
26use dnssd::{PeriodicQuery, build_service_records, find_matching_services};
27use protocol::{
28    MDNS_ADDR_V4, MDNS_ADDR_V6, McastSocket, SendCommand, build_response,
29    create_multicast_socket_v4, create_multicast_socket_v6, get_local_ips, send_loop,
30};
31
32fn dedup_records(records: &mut Vec<mdns::RR>) {
33    let mut seen = HashSet::new();
34    records.retain(|r| seen.insert(r.clone()));
35}
36
37struct MdnsServiceInner {
38    cache: RecordCache,
39    queries: Vec<PeriodicQuery>,
40    services: Vec<ServiceRegistration>,
41    local_ips_v4: Vec<Ipv4Addr>,
42    local_ips_v6: Vec<Ipv6Addr>,
43}
44
45/// Long-running mDNS service with discovery, caching, and service registration.
46pub struct MdnsService {
47    inner: Arc<Mutex<MdnsServiceInner>>,
48    send_tx: UnboundedSender<SendCommand>,
49    cancel: CancellationToken,
50}
51
52async fn recv_loop(
53    socket: Arc<UdpSocket>,
54    inner: Arc<Mutex<MdnsServiceInner>>,
55    send_tx: UnboundedSender<SendCommand>,
56    event_tx: UnboundedSender<MdnsEvent>,
57    cancel: CancellationToken,
58) {
59    let mut buf = vec![0u8; 9000];
60    loop {
61        let (n, addr) = tokio::select! {
62            result = socket.recv_from(&mut buf) => {
63                match result {
64                    Ok(v) => v,
65                    Err(e) => {
66                        log::debug!("mdns2 recv error: {}", e);
67                        continue;
68                    }
69                }
70            }
71            _ = cancel.cancelled() => return,
72        };
73
74        let data = &buf[..n];
75        let msg = match mdns::parse_dns(data, addr) {
76            Ok(m) => m,
77            Err(e) => {
78                log::trace!("mdns2: failed to parse packet from {}: {:?}", addr, e);
79                continue;
80            }
81        };
82
83        let is_response = msg.flags & 0x8000 != 0;
84
85        if is_response {
86            // Ingest all records into cache
87            let mut state = inner.lock().await;
88            let all_records: Vec<mdns::RR> = msg
89                .answers
90                .iter()
91                .chain(msg.additional.iter())
92                .cloned()
93                .collect();
94
95            let mut new_ptr_records = Vec::new();
96            for rr in &all_records {
97                state.cache.ingest(rr);
98                if rr.typ == mdns::TYPE_PTR {
99                    if let mdns::RRData::PTR(ref target) = rr.data {
100                        new_ptr_records.push((rr.name.clone(), target.clone()));
101                    }
102                }
103            }
104            for (name, target) in new_ptr_records {
105                let _ = event_tx.send(MdnsEvent::ServiceDiscovered {
106                    name,
107                    target,
108                    records: all_records.clone(),
109                });
110            }
111        } else {
112            let state = inner.lock().await;
113            if state.services.is_empty() {
114                continue;
115            }
116            let mut all_answers = Vec::new();
117            let mut all_additional = Vec::new();
118            for q in &msg.queries {
119                let (ans, add) = find_matching_services(
120                    &q.name,
121                    q.typ,
122                    &state.services,
123                    &state.local_ips_v4,
124                    &state.local_ips_v6,
125                );
126                all_answers.extend(ans);
127                all_additional.extend(add);
128            }
129            drop(state);
130
131            // Deduplicate records that matched multiple queries
132            dedup_records(&mut all_answers);
133            dedup_records(&mut all_additional);
134            // Don't repeat answer records in additional
135            all_additional.retain(|r| !all_answers.contains(r));
136
137            if !all_answers.is_empty() {
138                if let Ok(packet) = build_response(&all_answers, &all_additional) {
139                    let _ = send_tx.send(SendCommand::Multicast(packet));
140                }
141            }
142        }
143    }
144}
145
146async fn periodic_loop(
147    inner: Arc<Mutex<MdnsServiceInner>>,
148    send_tx: UnboundedSender<SendCommand>,
149    event_tx: UnboundedSender<MdnsEvent>,
150    cancel: CancellationToken,
151) {
152    let mut interval = tokio::time::interval(Duration::from_secs(1));
153    loop {
154        tokio::select! {
155            _ = interval.tick() => {}
156            _ = cancel.cancelled() => return,
157        }
158
159        let mut state = inner.lock().await;
160
161        // Evict expired cache entries
162        let expired = state.cache.evict_expired();
163        for (name, rtype) in expired {
164            let _ = event_tx.send(MdnsEvent::ServiceExpired { name, rtype });
165        }
166
167        // Send due queries
168        let now = Instant::now();
169        let mut packets = Vec::new();
170        for q in &mut state.queries {
171            if now.duration_since(q.last_sent) >= q.interval {
172                if let Ok(pkt) = mdns::create_query(&q.label, q.qtype) {
173                    packets.push(pkt);
174                }
175                q.last_sent = now;
176            }
177        }
178        drop(state);
179
180        for pkt in packets {
181            let _ = send_tx.send(SendCommand::Multicast(pkt));
182        }
183
184        // Refresh local IPs periodically (cheap operation)
185        let (v4, v6) = get_local_ips();
186        let mut state = inner.lock().await;
187        state.local_ips_v4 = v4;
188        state.local_ips_v6 = v6;
189    }
190}
191
192impl MdnsService {
193    /// Create a new mDNS service. Returns the service handle and a receiver for events.
194    pub async fn new() -> Result<(Arc<Self>, UnboundedReceiver<MdnsEvent>)> {
195        let (event_tx, event_rx) = mpsc::unbounded_channel();
196        let (send_tx, send_rx) = mpsc::unbounded_channel();
197        let cancel = CancellationToken::new();
198
199        let (v4, v6) = get_local_ips();
200        let inner = Arc::new(Mutex::new(MdnsServiceInner {
201            cache: RecordCache::new(),
202            queries: Vec::new(),
203            services: Vec::new(),
204            local_ips_v4: v4,
205            local_ips_v6: v6,
206        }));
207
208        // Create sockets
209        let mut mcast_sockets: Vec<McastSocket> = Vec::new();
210
211        // IPv4
212        match create_multicast_socket_v4() {
213            Ok(std_sock) => match UdpSocket::from_std(std_sock) {
214                Ok(s) => mcast_sockets.push(McastSocket {
215                    sock: Arc::new(s),
216                    multicast_addr: MDNS_ADDR_V4,
217                }),
218                Err(e) => log::warn!("mdns2: failed to wrap v4 socket: {}", e),
219            },
220            Err(e) => log::warn!("mdns2: failed to create v4 socket: {}", e),
221        }
222
223        if let Ok(ifaces) = if_addrs::get_if_addrs() {
224            let mut seen_indices = std::collections::HashSet::new();
225            for iface in ifaces {
226                if !iface.ip().is_ipv6() {
227                    continue;
228                }
229                if let Some(idx) = iface.index {
230                    if !seen_indices.insert(idx) {
231                        continue;
232                    }
233                    match create_multicast_socket_v6(idx) {
234                        Ok(std_sock) => match UdpSocket::from_std(std_sock) {
235                            Ok(s) => mcast_sockets.push(McastSocket {
236                                sock: Arc::new(s),
237                                multicast_addr: MDNS_ADDR_V6,
238                            }),
239                            Err(e) => {
240                                log::debug!("mdns2: failed to wrap v6 socket idx={}: {}", idx, e)
241                            }
242                        },
243                        Err(e) => {
244                            log::debug!("mdns2: failed to create v6 socket idx={}: {}", idx, e)
245                        }
246                    }
247                }
248            }
249        }
250
251        if mcast_sockets.is_empty() {
252            anyhow::bail!("mdns2: no sockets could be created");
253        }
254
255        // Spawn recv loops (one per socket)
256        for ms in &mcast_sockets {
257            let sock = ms.sock.clone();
258            let inner = inner.clone();
259            let send_tx = send_tx.clone();
260            let event_tx = event_tx.clone();
261            let cancel = cancel.child_token();
262            tokio::spawn(async move {
263                recv_loop(sock, inner, send_tx, event_tx, cancel).await;
264            });
265        }
266
267        // Spawn periodic loop
268        {
269            let inner = inner.clone();
270            let send_tx = send_tx.clone();
271            let event_tx = event_tx.clone();
272            let cancel = cancel.child_token();
273            tokio::spawn(async move {
274                periodic_loop(inner, send_tx, event_tx, cancel).await;
275            });
276        }
277
278        // Spawn send loop
279        {
280            let cancel = cancel.child_token();
281            tokio::spawn(async move {
282                send_loop(mcast_sockets, send_rx, cancel).await;
283            });
284        }
285
286        let service = Arc::new(MdnsService {
287            inner,
288            send_tx,
289            cancel,
290        });
291
292        Ok((service, event_rx))
293    }
294
295    /// Add a periodic query. The query will be sent immediately, then every interval.
296    pub async fn add_query(&self, label: &str, qtype: u16, interval: Duration) {
297        let mut state = self.inner.lock().await;
298        // Send immediately
299        let sent_at = Instant::now();
300        if let Ok(pkt) = mdns::create_query(label, qtype) {
301            let _ = self.send_tx.send(SendCommand::Multicast(pkt));
302        }
303        state.queries.push(PeriodicQuery {
304            label: label.to_owned(),
305            qtype,
306            interval,
307            last_sent: sent_at,
308        });
309    }
310
311    /// Remove a periodic query by label.
312    pub async fn remove_query(&self, label: &str) {
313        let mut state = self.inner.lock().await;
314        state.queries.retain(|q| q.label != label);
315    }
316
317    /// Register a local service to be advertised.
318    pub async fn register_service(&self, reg: ServiceRegistration) {
319        let mut state = self.inner.lock().await;
320        state.services.push(reg);
321    }
322
323    /// Unregister a local service. Sends a goodbye (TTL=0) for the service records.
324    pub async fn unregister_service(&self, instance: &str, service_type: &str) {
325        let mut state = self.inner.lock().await;
326        let idx = state
327            .services
328            .iter()
329            .position(|s| s.instance_name == instance && s.service_type == service_type);
330        if let Some(idx) = idx {
331            let reg = state.services.remove(idx);
332            // Build goodbye records (TTL=0)
333            let svc_v4 = reg.ips_v4.as_deref().unwrap_or(&state.local_ips_v4);
334            let svc_v6 = reg.ips_v6.as_deref().unwrap_or(&state.local_ips_v6);
335            let mut goodbye_records = build_service_records(&reg, svc_v4, svc_v6);
336            for rr in &mut goodbye_records {
337                rr.ttl = 0;
338            }
339            drop(state);
340            if let Ok(pkt) = build_response(&goodbye_records, &[]) {
341                let _ = self.send_tx.send(SendCommand::Multicast(pkt));
342            }
343        }
344    }
345
346    /// Send a gratuitous announcement of all registered services.
347    pub async fn announce(&self) {
348        let state = self.inner.lock().await;
349        let mut all_answers = Vec::new();
350        let mut all_additional = Vec::new();
351        for reg in &state.services {
352            let svc_v4 = reg.ips_v4.as_deref().unwrap_or(&state.local_ips_v4);
353            let svc_v6 = reg.ips_v6.as_deref().unwrap_or(&state.local_ips_v6);
354            let records = build_service_records(reg, svc_v4, svc_v6);
355            // PTR goes as answer, everything else as additional
356            for r in records {
357                if r.typ == mdns::TYPE_PTR {
358                    all_answers.push(r);
359                } else {
360                    all_additional.push(r);
361                }
362            }
363        }
364        drop(state);
365
366        if !all_answers.is_empty() {
367            if let Ok(pkt) = build_response(&all_answers, &all_additional) {
368                let _ = self.send_tx.send(SendCommand::Multicast(pkt));
369            }
370        }
371    }
372
373    /// Lookup cached records by name and type.
374    pub async fn lookup(&self, name: &str, qtype: u16) -> Vec<mdns::RR> {
375        let state = self.inner.lock().await;
376        if qtype == mdns::QTYPE_ANY {
377            state.cache.lookup_name(name)
378        } else {
379            state.cache.lookup(name, qtype)
380        }
381    }
382
383    pub async fn active_lookup(&self, name: &str, qtype: u16) {
384        if let Ok(pkt) = mdns::create_query(name, qtype) {
385            let _ = self.send_tx.send(SendCommand::Multicast(pkt));
386        }
387    }
388
389    /// Shut down all background tasks.
390    pub fn shutdown(&self) {
391        self.cancel.cancel();
392    }
393}
394
395impl Drop for MdnsService {
396    fn drop(&mut self) {
397        self.cancel.cancel();
398    }
399}