1mod dnssd;
9mod protocol;
10
11pub use dnssd::{MdnsEvent, ServiceRegistration};
12pub use protocol::{CachedRecord, RecordCache};
13
14use std::collections::HashSet;
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18
19use anyhow::Result;
20use tokio::net::UdpSocket;
21use tokio::sync::Mutex;
22use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
23use tokio_util::sync::CancellationToken;
24
25use crate::mdns;
26use dnssd::{PeriodicQuery, build_service_records, find_matching_services};
27use protocol::{
28 MDNS_ADDR_V4, MDNS_ADDR_V6, McastSocket, SendCommand, build_response,
29 create_multicast_socket_v4, create_multicast_socket_v6, get_local_ips, send_loop,
30};
31
/// Removes duplicate entries from `records` in place.
///
/// The first occurrence of each value is kept and the relative order of the
/// surviving entries is preserved. Generic over any hashable, cloneable type;
/// callers in this module pass `Vec<mdns::RR>`.
fn dedup_records<T: Clone + Eq + std::hash::Hash>(records: &mut Vec<T>) {
    // At most `len` distinct values can be inserted, so size the set once.
    let mut seen = HashSet::with_capacity(records.len());
    records.retain(|r| seen.insert(r.clone()));
}
36
/// Mutable state shared, behind `Arc<Mutex<_>>`, between the public
/// `MdnsService` handle and the background tasks spawned by `MdnsService::new`.
struct MdnsServiceInner {
    /// Records ingested from received mDNS responses; the periodic task calls
    /// `evict_expired` on it every tick.
    cache: RecordCache,
    /// Queries that the periodic task re-sends whenever their interval has
    /// elapsed since `last_sent`.
    queries: Vec<PeriodicQuery>,
    /// Locally registered services: matched against incoming queries and used
    /// to build announcements and goodbye packets.
    services: Vec<ServiceRegistration>,
    /// Local IPv4 addresses, refreshed each periodic tick; used for a
    /// registration whose `ips_v4` is `None`.
    local_ips_v4: Vec<Ipv4Addr>,
    /// Local IPv6 addresses, refreshed each periodic tick; used for a
    /// registration whose `ips_v6` is `None`.
    local_ips_v6: Vec<Ipv6Addr>,
}
44
/// Handle to a running mDNS querier/responder created by `MdnsService::new`.
///
/// Calling `shutdown`, or dropping the handle, cancels the background tasks.
pub struct MdnsService {
    /// State shared with the receive and periodic tasks.
    inner: Arc<Mutex<MdnsServiceInner>>,
    /// Outgoing-packet queue drained by the send task spawned in `new`.
    send_tx: UnboundedSender<SendCommand>,
    /// Parent token; every background task runs on a child token of it.
    cancel: CancellationToken,
}
51
52async fn recv_loop(
53 socket: Arc<UdpSocket>,
54 inner: Arc<Mutex<MdnsServiceInner>>,
55 send_tx: UnboundedSender<SendCommand>,
56 event_tx: UnboundedSender<MdnsEvent>,
57 cancel: CancellationToken,
58) {
59 let mut buf = vec![0u8; 9000];
60 loop {
61 let (n, addr) = tokio::select! {
62 result = socket.recv_from(&mut buf) => {
63 match result {
64 Ok(v) => v,
65 Err(e) => {
66 log::debug!("mdns2 recv error: {}", e);
67 continue;
68 }
69 }
70 }
71 _ = cancel.cancelled() => return,
72 };
73
74 let data = &buf[..n];
75 let msg = match mdns::parse_dns(data, addr) {
76 Ok(m) => m,
77 Err(e) => {
78 log::trace!("mdns2: failed to parse packet from {}: {:?}", addr, e);
79 continue;
80 }
81 };
82
83 let is_response = msg.flags & 0x8000 != 0;
84
85 if is_response {
86 let mut state = inner.lock().await;
88 let all_records: Vec<mdns::RR> = msg
89 .answers
90 .iter()
91 .chain(msg.additional.iter())
92 .cloned()
93 .collect();
94
95 let mut new_ptr_records = Vec::new();
96 for rr in &all_records {
97 state.cache.ingest(rr);
98 if rr.typ == mdns::TYPE_PTR {
99 if let mdns::RRData::PTR(ref target) = rr.data {
100 new_ptr_records.push((rr.name.clone(), target.clone()));
101 }
102 }
103 }
104 for (name, target) in new_ptr_records {
105 let _ = event_tx.send(MdnsEvent::ServiceDiscovered {
106 name,
107 target,
108 records: all_records.clone(),
109 });
110 }
111 } else {
112 let state = inner.lock().await;
113 if state.services.is_empty() {
114 continue;
115 }
116 let mut all_answers = Vec::new();
117 let mut all_additional = Vec::new();
118 for q in &msg.queries {
119 let (ans, add) = find_matching_services(
120 &q.name,
121 q.typ,
122 &state.services,
123 &state.local_ips_v4,
124 &state.local_ips_v6,
125 );
126 all_answers.extend(ans);
127 all_additional.extend(add);
128 }
129 drop(state);
130
131 dedup_records(&mut all_answers);
133 dedup_records(&mut all_additional);
134 all_additional.retain(|r| !all_answers.contains(r));
136
137 if !all_answers.is_empty() {
138 if let Ok(packet) = build_response(&all_answers, &all_additional) {
139 let _ = send_tx.send(SendCommand::Multicast(packet));
140 }
141 }
142 }
143 }
144}
145
146async fn periodic_loop(
147 inner: Arc<Mutex<MdnsServiceInner>>,
148 send_tx: UnboundedSender<SendCommand>,
149 event_tx: UnboundedSender<MdnsEvent>,
150 cancel: CancellationToken,
151) {
152 let mut interval = tokio::time::interval(Duration::from_secs(1));
153 loop {
154 tokio::select! {
155 _ = interval.tick() => {}
156 _ = cancel.cancelled() => return,
157 }
158
159 let mut state = inner.lock().await;
160
161 let expired = state.cache.evict_expired();
163 for (name, rtype) in expired {
164 let _ = event_tx.send(MdnsEvent::ServiceExpired { name, rtype });
165 }
166
167 let now = Instant::now();
169 let mut packets = Vec::new();
170 for q in &mut state.queries {
171 if now.duration_since(q.last_sent) >= q.interval {
172 if let Ok(pkt) = mdns::create_query(&q.label, q.qtype) {
173 packets.push(pkt);
174 }
175 q.last_sent = now;
176 }
177 }
178 drop(state);
179
180 for pkt in packets {
181 let _ = send_tx.send(SendCommand::Multicast(pkt));
182 }
183
184 let (v4, v6) = get_local_ips();
186 let mut state = inner.lock().await;
187 state.local_ips_v4 = v4;
188 state.local_ips_v6 = v6;
189 }
190}
191
192impl MdnsService {
193 pub async fn new() -> Result<(Arc<Self>, UnboundedReceiver<MdnsEvent>)> {
195 let (event_tx, event_rx) = mpsc::unbounded_channel();
196 let (send_tx, send_rx) = mpsc::unbounded_channel();
197 let cancel = CancellationToken::new();
198
199 let (v4, v6) = get_local_ips();
200 let inner = Arc::new(Mutex::new(MdnsServiceInner {
201 cache: RecordCache::new(),
202 queries: Vec::new(),
203 services: Vec::new(),
204 local_ips_v4: v4,
205 local_ips_v6: v6,
206 }));
207
208 let mut mcast_sockets: Vec<McastSocket> = Vec::new();
210
211 match create_multicast_socket_v4() {
213 Ok(std_sock) => match UdpSocket::from_std(std_sock) {
214 Ok(s) => mcast_sockets.push(McastSocket {
215 sock: Arc::new(s),
216 multicast_addr: MDNS_ADDR_V4,
217 }),
218 Err(e) => log::warn!("mdns2: failed to wrap v4 socket: {}", e),
219 },
220 Err(e) => log::warn!("mdns2: failed to create v4 socket: {}", e),
221 }
222
223 if let Ok(ifaces) = if_addrs::get_if_addrs() {
224 let mut seen_indices = std::collections::HashSet::new();
225 for iface in ifaces {
226 if !iface.ip().is_ipv6() {
227 continue;
228 }
229 if let Some(idx) = iface.index {
230 if !seen_indices.insert(idx) {
231 continue;
232 }
233 match create_multicast_socket_v6(idx) {
234 Ok(std_sock) => match UdpSocket::from_std(std_sock) {
235 Ok(s) => mcast_sockets.push(McastSocket {
236 sock: Arc::new(s),
237 multicast_addr: MDNS_ADDR_V6,
238 }),
239 Err(e) => {
240 log::debug!("mdns2: failed to wrap v6 socket idx={}: {}", idx, e)
241 }
242 },
243 Err(e) => {
244 log::debug!("mdns2: failed to create v6 socket idx={}: {}", idx, e)
245 }
246 }
247 }
248 }
249 }
250
251 if mcast_sockets.is_empty() {
252 anyhow::bail!("mdns2: no sockets could be created");
253 }
254
255 for ms in &mcast_sockets {
257 let sock = ms.sock.clone();
258 let inner = inner.clone();
259 let send_tx = send_tx.clone();
260 let event_tx = event_tx.clone();
261 let cancel = cancel.child_token();
262 tokio::spawn(async move {
263 recv_loop(sock, inner, send_tx, event_tx, cancel).await;
264 });
265 }
266
267 {
269 let inner = inner.clone();
270 let send_tx = send_tx.clone();
271 let event_tx = event_tx.clone();
272 let cancel = cancel.child_token();
273 tokio::spawn(async move {
274 periodic_loop(inner, send_tx, event_tx, cancel).await;
275 });
276 }
277
278 {
280 let cancel = cancel.child_token();
281 tokio::spawn(async move {
282 send_loop(mcast_sockets, send_rx, cancel).await;
283 });
284 }
285
286 let service = Arc::new(MdnsService {
287 inner,
288 send_tx,
289 cancel,
290 });
291
292 Ok((service, event_rx))
293 }
294
295 pub async fn add_query(&self, label: &str, qtype: u16, interval: Duration) {
297 let mut state = self.inner.lock().await;
298 let sent_at = Instant::now();
300 if let Ok(pkt) = mdns::create_query(label, qtype) {
301 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
302 }
303 state.queries.push(PeriodicQuery {
304 label: label.to_owned(),
305 qtype,
306 interval,
307 last_sent: sent_at,
308 });
309 }
310
311 pub async fn remove_query(&self, label: &str) {
313 let mut state = self.inner.lock().await;
314 state.queries.retain(|q| q.label != label);
315 }
316
317 pub async fn register_service(&self, reg: ServiceRegistration) {
319 let mut state = self.inner.lock().await;
320 state.services.push(reg);
321 }
322
323 pub async fn unregister_service(&self, instance: &str, service_type: &str) {
325 let mut state = self.inner.lock().await;
326 let idx = state
327 .services
328 .iter()
329 .position(|s| s.instance_name == instance && s.service_type == service_type);
330 if let Some(idx) = idx {
331 let reg = state.services.remove(idx);
332 let svc_v4 = reg.ips_v4.as_deref().unwrap_or(&state.local_ips_v4);
334 let svc_v6 = reg.ips_v6.as_deref().unwrap_or(&state.local_ips_v6);
335 let mut goodbye_records = build_service_records(®, svc_v4, svc_v6);
336 for rr in &mut goodbye_records {
337 rr.ttl = 0;
338 }
339 drop(state);
340 if let Ok(pkt) = build_response(&goodbye_records, &[]) {
341 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
342 }
343 }
344 }
345
346 pub async fn announce(&self) {
348 let state = self.inner.lock().await;
349 let mut all_answers = Vec::new();
350 let mut all_additional = Vec::new();
351 for reg in &state.services {
352 let svc_v4 = reg.ips_v4.as_deref().unwrap_or(&state.local_ips_v4);
353 let svc_v6 = reg.ips_v6.as_deref().unwrap_or(&state.local_ips_v6);
354 let records = build_service_records(reg, svc_v4, svc_v6);
355 for r in records {
357 if r.typ == mdns::TYPE_PTR {
358 all_answers.push(r);
359 } else {
360 all_additional.push(r);
361 }
362 }
363 }
364 drop(state);
365
366 if !all_answers.is_empty() {
367 if let Ok(pkt) = build_response(&all_answers, &all_additional) {
368 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
369 }
370 }
371 }
372
373 pub async fn lookup(&self, name: &str, qtype: u16) -> Vec<mdns::RR> {
375 let state = self.inner.lock().await;
376 if qtype == mdns::QTYPE_ANY {
377 state.cache.lookup_name(name)
378 } else {
379 state.cache.lookup(name, qtype)
380 }
381 }
382
383 pub async fn active_lookup(&self, name: &str, qtype: u16) {
384 if let Ok(pkt) = mdns::create_query(name, qtype) {
385 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
386 }
387 }
388
389 pub fn shutdown(&self) {
391 self.cancel.cancel();
392 }
393}
394
395impl Drop for MdnsService {
396 fn drop(&mut self) {
397 self.cancel.cancel();
398 }
399}