1mod dnssd;
11mod protocol;
12
13pub use dnssd::{MdnsEvent, ServiceRegistration};
14pub use protocol::{CachedRecord, RecordCache};
15
16use std::collections::HashSet;
17use std::net::{Ipv4Addr, Ipv6Addr};
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20
21use anyhow::Result;
22use tokio::net::UdpSocket;
23use tokio::sync::broadcast;
24use tokio::sync::Mutex;
25use tokio::sync::mpsc::{self, UnboundedSender};
26use tokio_util::sync::CancellationToken;
27
28use crate::mdns;
29use dnssd::{PeriodicQuery, build_service_records, find_matching_services};
30use protocol::{
31 MDNS_ADDR_V4, MDNS_ADDR_V6, McastSocket, SendCommand, build_response,
32 create_multicast_socket_v4, create_multicast_socket_v6, get_local_ips, send_loop,
33};
34
/// Removes duplicate entries from `records` in place, keeping the first
/// occurrence of each value and preserving the relative order of survivors.
///
/// Generic over any hashable element so it works for `mdns::RR` as well as
/// other record-like types. Elements are only borrowed while detecting
/// duplicates, so no per-element clone is made.
fn dedup_records<T: Eq + std::hash::Hash>(records: &mut Vec<T>) {
    // First pass: decide which indices survive. The set borrows the elements,
    // so it must be dropped (end of this block) before `retain` mutates the Vec.
    let keep: Vec<bool> = {
        let mut seen = HashSet::with_capacity(records.len());
        records.iter().map(|r| seen.insert(r)).collect()
    };
    // `Vec::retain` visits every element exactly once, in order, so the mask
    // lines up index-for-index with the elements.
    let mut keep = keep.into_iter();
    records.retain(|_| keep.next().unwrap_or(true));
}
39
/// Mutable state shared between the public `MdnsService` handle and its
/// background tasks; always accessed through an `Arc<tokio::sync::Mutex<_>>`.
struct MdnsServiceInner {
    /// Records ingested from received responses; expired entries are evicted
    /// by the periodic task.
    cache: RecordCache,
    /// Queries that the periodic task re-sends whenever their interval elapses.
    queries: Vec<PeriodicQuery>,
    /// Locally registered services, used to answer incoming queries, to build
    /// announcements and to build goodbye packets.
    services: Vec<ServiceRegistration>,
    /// Local IPv4 addresses, refreshed on every periodic tick; used as the
    /// address set for registrations that do not specify their own.
    local_ips_v4: Vec<Ipv4Addr>,
    /// Local IPv6 addresses; same refresh and fallback rules as `local_ips_v4`.
    local_ips_v6: Vec<Ipv6Addr>,
}
47
/// Buffer size of the broadcast channel that carries `MdnsEvent`s.
/// A subscriber that falls more than this many events behind will lag
/// (see `tokio::sync::broadcast`).
const EVENT_CHANNEL_CAPACITY: usize = 1 << 8;
49
/// Handle to a running mDNS service.
///
/// Created by `MdnsService::new`. The background tasks share `inner` and stop
/// when `cancel` is triggered, either explicitly via `shutdown` or on drop.
pub struct MdnsService {
    /// State shared with the receive and periodic tasks.
    inner: Arc<Mutex<MdnsServiceInner>>,
    /// Queue of outgoing packets; the receiving end is owned by the send task.
    send_tx: UnboundedSender<SendCommand>,
    /// Broadcast sender for discovery/expiry events; see `subscribe`.
    event_tx: broadcast::Sender<MdnsEvent>,
    /// Parent cancellation token; every background task runs on a child token.
    cancel: CancellationToken,
}
57
58async fn recv_loop(
59 socket: Arc<UdpSocket>,
60 inner: Arc<Mutex<MdnsServiceInner>>,
61 send_tx: UnboundedSender<SendCommand>,
62 event_tx: broadcast::Sender<MdnsEvent>,
63 cancel: CancellationToken,
64) {
65 let mut buf = vec![0u8; 9000];
66 loop {
67 let (n, addr) = tokio::select! {
68 result = socket.recv_from(&mut buf) => {
69 match result {
70 Ok(v) => v,
71 Err(e) => {
72 log::debug!("mdns2 recv error: {}", e);
73 continue;
74 }
75 }
76 }
77 _ = cancel.cancelled() => return,
78 };
79
80 let data = &buf[..n];
81 let msg = match mdns::parse_dns(data, addr) {
82 Ok(m) => m,
83 Err(e) => {
84 log::trace!("mdns2: failed to parse packet from {}: {:?}", addr, e);
85 continue;
86 }
87 };
88
89 let is_response = msg.flags & 0x8000 != 0;
90
91 if is_response {
92 let mut state = inner.lock().await;
94 let all_records: Vec<mdns::RR> = msg
95 .answers
96 .iter()
97 .chain(msg.additional.iter())
98 .cloned()
99 .collect();
100
101 let mut new_ptr_records = Vec::new();
102 for rr in &all_records {
103 state.cache.ingest(rr);
104 if rr.typ == mdns::TYPE_PTR {
105 if let mdns::RRData::PTR(ref target) = rr.data {
106 new_ptr_records.push((rr.name.clone(), target.clone()));
107 }
108 }
109 }
110 for (name, target) in new_ptr_records {
111 let _ = event_tx.send(MdnsEvent::ServiceDiscovered {
112 name,
113 target,
114 records: all_records.clone(),
115 });
116 }
117 } else {
118 let state = inner.lock().await;
119 if state.services.is_empty() {
120 continue;
121 }
122 let mut all_answers = Vec::new();
123 let mut all_additional = Vec::new();
124 for q in &msg.queries {
125 let (ans, add) = find_matching_services(
126 &q.name,
127 q.typ,
128 &state.services,
129 &state.local_ips_v4,
130 &state.local_ips_v6,
131 );
132 all_answers.extend(ans);
133 all_additional.extend(add);
134 }
135 drop(state);
136
137 dedup_records(&mut all_answers);
139 dedup_records(&mut all_additional);
140 all_additional.retain(|r| !all_answers.contains(r));
142
143 if !all_answers.is_empty() {
144 if let Ok(packet) = build_response(&all_answers, &all_additional) {
145 let _ = send_tx.send(SendCommand::Multicast(packet));
146 }
147 }
148 }
149 }
150}
151
152async fn periodic_loop(
153 inner: Arc<Mutex<MdnsServiceInner>>,
154 send_tx: UnboundedSender<SendCommand>,
155 event_tx: broadcast::Sender<MdnsEvent>,
156 cancel: CancellationToken,
157) {
158 let mut interval = tokio::time::interval(Duration::from_secs(1));
159 loop {
160 tokio::select! {
161 _ = interval.tick() => {}
162 _ = cancel.cancelled() => return,
163 }
164
165 let mut state = inner.lock().await;
166
167 let expired = state.cache.evict_expired();
169 for (name, rtype) in expired {
170 let _ = event_tx.send(MdnsEvent::ServiceExpired { name, rtype });
171 }
172
173 let now = Instant::now();
175 let mut packets = Vec::new();
176 for q in &mut state.queries {
177 if now.duration_since(q.last_sent) >= q.interval {
178 if let Ok(pkt) = mdns::create_query(&q.label, q.qtype) {
179 packets.push(pkt);
180 }
181 q.last_sent = now;
182 }
183 }
184 drop(state);
185
186 for pkt in packets {
187 let _ = send_tx.send(SendCommand::Multicast(pkt));
188 }
189
190 let (v4, v6) = get_local_ips();
192 let mut state = inner.lock().await;
193 state.local_ips_v4 = v4;
194 state.local_ips_v6 = v6;
195 }
196}
197
198impl MdnsService {
199 pub async fn new() -> Result<Arc<Self>> {
204 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
205 let (send_tx, send_rx) = mpsc::unbounded_channel();
206 let cancel = CancellationToken::new();
207
208 let (v4, v6) = get_local_ips();
209 let inner = Arc::new(Mutex::new(MdnsServiceInner {
210 cache: RecordCache::new(),
211 queries: Vec::new(),
212 services: Vec::new(),
213 local_ips_v4: v4,
214 local_ips_v6: v6,
215 }));
216
217 let mut mcast_sockets: Vec<McastSocket> = Vec::new();
219
220 match create_multicast_socket_v4() {
222 Ok(std_sock) => match UdpSocket::from_std(std_sock) {
223 Ok(s) => mcast_sockets.push(McastSocket {
224 sock: Arc::new(s),
225 multicast_addr: MDNS_ADDR_V4,
226 }),
227 Err(e) => log::warn!("mdns2: failed to wrap v4 socket: {}", e),
228 },
229 Err(e) => log::warn!("mdns2: failed to create v4 socket: {}", e),
230 }
231
232 if let Ok(ifaces) = if_addrs::get_if_addrs() {
233 let mut seen_indices = std::collections::HashSet::new();
234 for iface in ifaces {
235 if !iface.ip().is_ipv6() {
236 continue;
237 }
238 if let Some(idx) = iface.index {
239 if !seen_indices.insert(idx) {
240 continue;
241 }
242 match create_multicast_socket_v6(idx) {
243 Ok(std_sock) => match UdpSocket::from_std(std_sock) {
244 Ok(s) => mcast_sockets.push(McastSocket {
245 sock: Arc::new(s),
246 multicast_addr: MDNS_ADDR_V6,
247 }),
248 Err(e) => {
249 log::debug!("mdns2: failed to wrap v6 socket idx={}: {}", idx, e)
250 }
251 },
252 Err(e) => {
253 log::debug!("mdns2: failed to create v6 socket idx={}: {}", idx, e)
254 }
255 }
256 }
257 }
258 }
259
260 if mcast_sockets.is_empty() {
261 anyhow::bail!("mdns2: no sockets could be created");
262 }
263
264 for ms in &mcast_sockets {
266 let sock = ms.sock.clone();
267 let inner = inner.clone();
268 let send_tx = send_tx.clone();
269 let event_tx = event_tx.clone();
270 let cancel = cancel.child_token();
271 tokio::spawn(async move {
272 recv_loop(sock, inner, send_tx, event_tx, cancel).await;
273 });
274 }
275
276 {
278 let inner = inner.clone();
279 let send_tx = send_tx.clone();
280 let event_tx = event_tx.clone();
281 let cancel = cancel.child_token();
282 tokio::spawn(async move {
283 periodic_loop(inner, send_tx, event_tx, cancel).await;
284 });
285 }
286
287 {
289 let cancel = cancel.child_token();
290 tokio::spawn(async move {
291 send_loop(mcast_sockets, send_rx, cancel).await;
292 });
293 }
294
295 let service = Arc::new(MdnsService {
296 inner,
297 send_tx,
298 event_tx,
299 cancel,
300 });
301
302 Ok(service)
303 }
304
305 pub fn subscribe(&self) -> broadcast::Receiver<MdnsEvent> {
313 self.event_tx.subscribe()
314 }
315
316 pub async fn add_query(&self, label: &str, qtype: u16, interval: Duration) {
318 let mut state = self.inner.lock().await;
319 let sent_at = Instant::now();
321 if let Ok(pkt) = mdns::create_query(label, qtype) {
322 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
323 }
324 state.queries.push(PeriodicQuery {
325 label: label.to_owned(),
326 qtype,
327 interval,
328 last_sent: sent_at,
329 });
330 }
331
332 pub async fn remove_query(&self, label: &str) {
334 let mut state = self.inner.lock().await;
335 state.queries.retain(|q| q.label != label);
336 }
337
338 pub async fn register_service(&self, reg: ServiceRegistration) {
340 let mut state = self.inner.lock().await;
341 state.services.push(reg);
342 }
343
344 pub async fn unregister_service(&self, instance: &str, service_type: &str) {
346 let mut state = self.inner.lock().await;
347 let idx = state
348 .services
349 .iter()
350 .position(|s| s.instance_name == instance && s.service_type == service_type);
351 if let Some(idx) = idx {
352 let reg = state.services.remove(idx);
353 let svc_v4 = reg.ips_v4.as_deref().unwrap_or(&state.local_ips_v4);
355 let svc_v6 = reg.ips_v6.as_deref().unwrap_or(&state.local_ips_v6);
356 let mut goodbye_records = build_service_records(®, svc_v4, svc_v6);
357 for rr in &mut goodbye_records {
358 rr.ttl = 0;
359 }
360 drop(state);
361 if let Ok(pkt) = build_response(&goodbye_records, &[]) {
362 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
363 }
364 }
365 }
366
367 pub async fn announce(&self) {
369 let state = self.inner.lock().await;
370 let mut all_answers = Vec::new();
371 let mut all_additional = Vec::new();
372 for reg in &state.services {
373 let svc_v4 = reg.ips_v4.as_deref().unwrap_or(&state.local_ips_v4);
374 let svc_v6 = reg.ips_v6.as_deref().unwrap_or(&state.local_ips_v6);
375 let records = build_service_records(reg, svc_v4, svc_v6);
376 for r in records {
378 if r.typ == mdns::TYPE_PTR {
379 all_answers.push(r);
380 } else {
381 all_additional.push(r);
382 }
383 }
384 }
385 drop(state);
386
387 if !all_answers.is_empty() {
388 if let Ok(pkt) = build_response(&all_answers, &all_additional) {
389 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
390 }
391 }
392 }
393
394 pub async fn lookup(&self, name: &str, qtype: u16) -> Vec<mdns::RR> {
396 let state = self.inner.lock().await;
397 if qtype == mdns::QTYPE_ANY {
398 state.cache.lookup_name(name)
399 } else {
400 state.cache.lookup(name, qtype)
401 }
402 }
403
404 pub async fn active_lookup(&self, name: &str, qtype: u16) {
405 if let Ok(pkt) = mdns::create_query(name, qtype) {
406 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
407 }
408 }
409
410 pub fn shutdown(&self) {
412 self.cancel.cancel();
413 }
414}
415
416impl Drop for MdnsService {
417 fn drop(&mut self) {
418 self.cancel.cancel();
419 }
420}