1mod dnssd;
9mod protocol;
10
11pub use dnssd::{MdnsEvent, ServiceRegistration};
12pub use protocol::{CachedRecord, RecordCache};
13
14use std::collections::HashSet;
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18
19use anyhow::Result;
20use tokio::net::UdpSocket;
21use tokio::sync::Mutex;
22use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
23use tokio_util::sync::CancellationToken;
24
25use crate::mdns;
26use dnssd::{PeriodicQuery, build_service_records, find_matching_services};
27use protocol::{
28 MDNS_ADDR_V4, MDNS_ADDR_V6, McastSocket, SendCommand, build_response,
29 create_multicast_socket_v4, create_multicast_socket_v6, get_local_ips, send_loop,
30};
31
32fn dedup_records(records: &mut Vec<mdns::RR>) {
33 let mut seen = HashSet::new();
34 records.retain(|r| seen.insert(r.clone()));
35}
36
37struct MdnsServiceInner {
38 cache: RecordCache,
39 queries: Vec<PeriodicQuery>,
40 services: Vec<ServiceRegistration>,
41 local_ips_v4: Vec<Ipv4Addr>,
42 local_ips_v6: Vec<Ipv6Addr>,
43}
44
45pub struct MdnsService {
47 inner: Arc<Mutex<MdnsServiceInner>>,
48 send_tx: UnboundedSender<SendCommand>,
49 cancel: CancellationToken,
50}
51
52async fn recv_loop(
53 socket: Arc<UdpSocket>,
54 inner: Arc<Mutex<MdnsServiceInner>>,
55 send_tx: UnboundedSender<SendCommand>,
56 event_tx: UnboundedSender<MdnsEvent>,
57 cancel: CancellationToken,
58) {
59 let mut buf = vec![0u8; 9000];
60 loop {
61 let (n, addr) = tokio::select! {
62 result = socket.recv_from(&mut buf) => {
63 match result {
64 Ok(v) => v,
65 Err(e) => {
66 log::debug!("mdns2 recv error: {}", e);
67 continue;
68 }
69 }
70 }
71 _ = cancel.cancelled() => return,
72 };
73
74 let data = &buf[..n];
75 let msg = match mdns::parse_dns(data, addr) {
76 Ok(m) => m,
77 Err(e) => {
78 log::trace!("mdns2: failed to parse packet from {}: {:?}", addr, e);
79 continue;
80 }
81 };
82
83 let is_response = msg.flags & 0x8000 != 0;
84
85 if is_response {
86 let mut state = inner.lock().await;
88 let all_records: Vec<mdns::RR> = msg
89 .answers
90 .iter()
91 .chain(msg.additional.iter())
92 .cloned()
93 .collect();
94
95 let mut new_ptr_records = Vec::new();
96 for rr in &all_records {
97 state.cache.ingest(rr);
98 if rr.typ == mdns::TYPE_PTR {
99 if let mdns::RRData::PTR(ref target) = rr.data {
100 new_ptr_records.push((rr.name.clone(), target.clone()));
101 }
102 }
103 }
104 for (name, target) in new_ptr_records {
105 let _ = event_tx.send(MdnsEvent::ServiceDiscovered {
106 name,
107 target,
108 records: all_records.clone(),
109 });
110 }
111 } else {
112 let state = inner.lock().await;
114 if state.services.is_empty() {
115 continue;
116 }
117 let mut all_answers = Vec::new();
118 let mut all_additional = Vec::new();
119 for q in &msg.queries {
120 let (ans, add) = find_matching_services(
121 &q.name,
122 q.typ,
123 &state.services,
124 &state.local_ips_v4,
125 &state.local_ips_v6,
126 );
127 all_answers.extend(ans);
128 all_additional.extend(add);
129 }
130 drop(state);
131
132 dedup_records(&mut all_answers);
134 dedup_records(&mut all_additional);
135 all_additional.retain(|r| !all_answers.contains(r));
137
138 if !all_answers.is_empty() {
139 if let Ok(packet) = build_response(&all_answers, &all_additional) {
140 let _ = send_tx.send(SendCommand::Multicast(packet));
141 }
142 }
143 }
144 }
145}
146
147async fn periodic_loop(
148 inner: Arc<Mutex<MdnsServiceInner>>,
149 send_tx: UnboundedSender<SendCommand>,
150 event_tx: UnboundedSender<MdnsEvent>,
151 cancel: CancellationToken,
152) {
153 let mut interval = tokio::time::interval(Duration::from_secs(1));
154 loop {
155 tokio::select! {
156 _ = interval.tick() => {}
157 _ = cancel.cancelled() => return,
158 }
159
160 let mut state = inner.lock().await;
161
162 let expired = state.cache.evict_expired();
164 for (name, rtype) in expired {
165 let _ = event_tx.send(MdnsEvent::ServiceExpired { name, rtype });
166 }
167
168 let now = Instant::now();
170 let mut packets = Vec::new();
171 for q in &mut state.queries {
172 if now.duration_since(q.last_sent) >= q.interval {
173 if let Ok(pkt) = mdns::create_query(&q.label, q.qtype) {
174 packets.push(pkt);
175 }
176 q.last_sent = now;
177 }
178 }
179 drop(state);
180
181 for pkt in packets {
182 let _ = send_tx.send(SendCommand::Multicast(pkt));
183 }
184
185 let (v4, v6) = get_local_ips();
187 let mut state = inner.lock().await;
188 state.local_ips_v4 = v4;
189 state.local_ips_v6 = v6;
190 }
191}
192
193impl MdnsService {
194 pub async fn new() -> Result<(Arc<Self>, UnboundedReceiver<MdnsEvent>)> {
196 let (event_tx, event_rx) = mpsc::unbounded_channel();
197 let (send_tx, send_rx) = mpsc::unbounded_channel();
198 let cancel = CancellationToken::new();
199
200 let (v4, v6) = get_local_ips();
201 let inner = Arc::new(Mutex::new(MdnsServiceInner {
202 cache: RecordCache::new(),
203 queries: Vec::new(),
204 services: Vec::new(),
205 local_ips_v4: v4,
206 local_ips_v6: v6,
207 }));
208
209 let mut mcast_sockets: Vec<McastSocket> = Vec::new();
211
212 match create_multicast_socket_v4() {
214 Ok(std_sock) => match UdpSocket::from_std(std_sock) {
215 Ok(s) => mcast_sockets.push(McastSocket {
216 sock: Arc::new(s),
217 multicast_addr: MDNS_ADDR_V4,
218 }),
219 Err(e) => log::warn!("mdns2: failed to wrap v4 socket: {}", e),
220 },
221 Err(e) => log::warn!("mdns2: failed to create v4 socket: {}", e),
222 }
223
224 if let Ok(ifaces) = if_addrs::get_if_addrs() {
226 let mut seen_indices = std::collections::HashSet::new();
227 for iface in ifaces {
228 if !iface.ip().is_ipv6() {
229 continue;
230 }
231 if let Some(idx) = iface.index {
232 if !seen_indices.insert(idx) {
233 continue;
234 }
235 match create_multicast_socket_v6(idx) {
236 Ok(std_sock) => match UdpSocket::from_std(std_sock) {
237 Ok(s) => mcast_sockets.push(McastSocket {
238 sock: Arc::new(s),
239 multicast_addr: MDNS_ADDR_V6,
240 }),
241 Err(e) => {
242 log::debug!("mdns2: failed to wrap v6 socket idx={}: {}", idx, e)
243 }
244 },
245 Err(e) => {
246 log::debug!("mdns2: failed to create v6 socket idx={}: {}", idx, e)
247 }
248 }
249 }
250 }
251 }
252
253 if mcast_sockets.is_empty() {
254 anyhow::bail!("mdns2: no sockets could be created");
255 }
256
257 for ms in &mcast_sockets {
259 let sock = ms.sock.clone();
260 let inner = inner.clone();
261 let send_tx = send_tx.clone();
262 let event_tx = event_tx.clone();
263 let cancel = cancel.child_token();
264 tokio::spawn(async move {
265 recv_loop(sock, inner, send_tx, event_tx, cancel).await;
266 });
267 }
268
269 {
271 let inner = inner.clone();
272 let send_tx = send_tx.clone();
273 let event_tx = event_tx.clone();
274 let cancel = cancel.child_token();
275 tokio::spawn(async move {
276 periodic_loop(inner, send_tx, event_tx, cancel).await;
277 });
278 }
279
280 {
282 let cancel = cancel.child_token();
283 tokio::spawn(async move {
284 send_loop(mcast_sockets, send_rx, cancel).await;
285 });
286 }
287
288 let service = Arc::new(MdnsService {
289 inner,
290 send_tx,
291 cancel,
292 });
293
294 Ok((service, event_rx))
295 }
296
297 pub async fn add_query(&self, label: &str, qtype: u16, interval: Duration) {
299 let mut state = self.inner.lock().await;
300 let sent_at = Instant::now();
302 if let Ok(pkt) = mdns::create_query(label, qtype) {
303 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
304 }
305 state.queries.push(PeriodicQuery {
306 label: label.to_owned(),
307 qtype,
308 interval,
309 last_sent: sent_at,
310 });
311 }
312
313 pub async fn remove_query(&self, label: &str) {
315 let mut state = self.inner.lock().await;
316 state.queries.retain(|q| q.label != label);
317 }
318
319 pub async fn register_service(&self, reg: ServiceRegistration) {
321 let mut state = self.inner.lock().await;
322 state.services.push(reg);
323 }
324
325 pub async fn unregister_service(&self, instance: &str, service_type: &str) {
327 let mut state = self.inner.lock().await;
328 let idx = state
329 .services
330 .iter()
331 .position(|s| s.instance_name == instance && s.service_type == service_type);
332 if let Some(idx) = idx {
333 let reg = state.services.remove(idx);
334 let mut goodbye_records =
336 build_service_records(®, &state.local_ips_v4, &state.local_ips_v6);
337 for rr in &mut goodbye_records {
338 rr.ttl = 0;
339 }
340 drop(state);
341 if let Ok(pkt) = build_response(&goodbye_records, &[]) {
342 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
343 }
344 }
345 }
346
347 pub async fn announce(&self) {
349 let state = self.inner.lock().await;
350 let mut all_answers = Vec::new();
351 let mut all_additional = Vec::new();
352 for reg in &state.services {
353 let records = build_service_records(reg, &state.local_ips_v4, &state.local_ips_v6);
354 for r in records {
356 if r.typ == mdns::TYPE_PTR {
357 all_answers.push(r);
358 } else {
359 all_additional.push(r);
360 }
361 }
362 }
363 drop(state);
364
365 if !all_answers.is_empty() {
366 if let Ok(pkt) = build_response(&all_answers, &all_additional) {
367 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
368 }
369 }
370 }
371
372 pub async fn lookup(&self, name: &str, qtype: u16) -> Vec<mdns::RR> {
374 let state = self.inner.lock().await;
375 if qtype == mdns::QTYPE_ANY {
376 state.cache.lookup_name(name)
377 } else {
378 state.cache.lookup(name, qtype)
379 }
380 }
381
382 pub async fn active_lookup(&self, name: &str, qtype: u16) {
383 if let Ok(pkt) = mdns::create_query(name, qtype) {
384 let _ = self.send_tx.send(SendCommand::Multicast(pkt));
385 }
386 }
387
388 pub fn shutdown(&self) {
390 self.cancel.cancel();
391 }
392}
393
394impl Drop for MdnsService {
395 fn drop(&mut self) {
396 self.cancel.cancel();
397 }
398}