matc/
mdns.rs

1//! Very simple mdns client library
2
3use std::{borrow::Cow, collections::HashMap, io::{Cursor, Read, Write}};
4
5use anyhow::{Context, Result};
6
7use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
8use socket2::{Domain, Protocol, Type};
9
/// Resource record type code for an IPv4 host address (A).
pub const TYPE_A: u16 = 0x0001;
/// Resource record type code for a canonical-name alias (CNAME).
pub const TYPE_CNAME: u16 = 0x0005;
/// Resource record type code for a domain-name pointer (PTR).
pub const TYPE_PTR: u16 = 0x000C;
/// Resource record type code for text strings (TXT).
pub const TYPE_TXT: u16 = 0x0010;
/// Resource record type code for an IPv6 host address (AAAA).
pub const TYPE_AAAA: u16 = 0x001C;
/// Resource record type code for a service locator (SRV).
pub const TYPE_SRV: u16 = 0x0021;
/// Resource record type code for a naming-authority pointer (NAPTR).
pub const TYPE_NAPTR: u16 = 0x0023;
/// Query type asking for records of any type (`*`).
pub const QTYPE_ANY: u16 = 0x00FF;
18
19pub fn encode_label(label: &str, out: &mut Vec<u8>) -> Result<()> {
20    for seg in label.split(".") {
21        if seg.is_empty() {
22            continue;
23        }
24        let bytes = seg.as_bytes();
25        if bytes.len() > 63 {
26            anyhow::bail!("DNS label segment exceeds 63 bytes: {} bytes", bytes.len());
27        }
28        out.write_u8(bytes.len() as u8)?;
29        out.write_all(bytes)?;
30    }
31    out.write_u8(0)?;
32    Ok(())
33}
34
35pub fn encode_label_compressed(
36    label: &str,
37    out: &mut Vec<u8>,
38    name_offsets: &mut HashMap<String, usize>,
39) -> Result<()> {
40    let segments: Vec<&str> = label.split('.').filter(|s| !s.is_empty()).collect();
41
42    for i in 0..segments.len() {
43        let suffix = segments[i..].join(".");
44        if let Some(&offset) = name_offsets.get(&suffix) {
45            if offset < 0x3FFF {
46                // Write 2-byte compression pointer to the previously-written suffix
47                out.write_u8(0xC0 | ((offset >> 8) as u8))?;
48                out.write_u8((offset & 0xFF) as u8)?;
49                return Ok(());
50            }
51        }
52        // Record where this suffix starts, then write this segment
53        name_offsets.insert(suffix, out.len());
54        let bytes = segments[i].as_bytes();
55        if bytes.len() > 63 {
56            anyhow::bail!("DNS label segment exceeds 63 bytes: {} bytes", bytes.len());
57        }
58        out.write_u8(bytes.len() as u8)?;
59        out.write_all(bytes)?;
60    }
61
62    // No suffix matched — terminate with null
63    out.write_u8(0)?;
64    Ok(())
65}
66
67pub(crate) fn create_query(label: &str, qtype: u16) -> Result<Vec<u8>> {
68    let mut out = Vec::with_capacity(512);
69    out.write_u16::<BigEndian>(rand::random::<u16>())?; // transaction id
70    out.write_u16::<BigEndian>(0)?; // flags
71    out.write_u16::<BigEndian>(1)?; // questions
72    out.write_u16::<BigEndian>(0)?; // answers
73    out.write_u16::<BigEndian>(0)?; // authority
74    out.write_u16::<BigEndian>(0)?; // additional
75
76    encode_label(label, &mut out)?;
77
78    out.write_u16::<BigEndian>(qtype)?;
79    out.write_u16::<BigEndian>(0x0001)?; // class
80    Ok(out)
81}
82
83fn read_label(data: &[u8], cursor: &mut Cursor<&[u8]>) -> Result<String> {
84    let mut out = Vec::new();
85    let mut depth = 0;
86    loop {
87        depth += 1;
88        if depth > 64 {
89            anyhow::bail!("too many label indirections");
90        }
91        let n = cursor.read_u8()?;
92        if n == 0 {
93            break;
94        } else if n & 0xc0 == 0xc0 {
95            let off = {
96                let off = n & 0x3f;
97                ((off as usize) << 8) | (cursor.read_u8()? as u16) as usize
98            };
99            if off >= data.len() {
100                anyhow::bail!("invalid compression pointer offset");
101            }
102            let frag = read_label(data, &mut Cursor::new(&data[off..]))?;
103            out.extend_from_slice(frag.as_bytes());
104            break;
105        } else {
106            // RFC 1035: label length must be <= 63
107            if n > 63 {
108                anyhow::bail!("DNS label segment exceeds 63 bytes: {}", n);
109            }
110            let mut b = vec![0; n as usize];
111            cursor.read_exact(&mut b)?;
112            out.extend_from_slice(&b);
113            out.extend_from_slice(b".");
114        }
115    }
116    // RFC 1035: total domain name length must be <= 255
117    if out.len() > 1024 {
118        anyhow::bail!("DNS domain name exceeds 1024 bytes: {}", out.len());
119    }
120    Ok(std::str::from_utf8(&out)?.to_owned())
121}
122
123
/// Decoded RDATA of a resource record.
///
/// Record types that are not decoded are kept verbatim in
/// [`RRData::Unknown`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum RRData {
    /// IPv4 address from an A record.
    A(std::net::Ipv4Addr),
    /// IPv6 address from an AAAA record.
    AAAA(std::net::Ipv6Addr),
    /// Domain name an PTR record points to, in dotted form.
    PTR(String),
    /// Strings carried by a TXT record.
    TXT(Vec<String>),
    /// Service location from an SRV record.
    SRV {
        priority: u16,
        weight: u16,
        port: u16,
        /// Host name providing the service, in dotted form.
        target: String,
    },
    /// Raw RDATA bytes of a record that was not decoded.
    Unknown(Vec<u8>),
}
133
134#[derive(Debug, Eq, PartialEq, Hash, Clone)]
135pub struct RR {
136    pub name: String,
137    pub typ: u16,
138    pub class: u16,
139    pub ttl: u32,
140    pub rdata: Vec<u8>,
141    pub target: Option<String>,
142    pub data: RRData,
143}
144
/// An entry from the question section of a DNS message.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Query {
    /// Queried name in dotted form.
    pub name: String,
    /// Requested record type code.
    pub typ: u16,
    /// Requested class exactly as read from the wire.
    pub class: u16,
}
151
152#[derive(Debug, Eq, PartialEq, Hash)]
153pub struct DnsMessage {
154    pub source: std::net::SocketAddr,
155    pub transaction: u16,
156    pub flags: u16,
157    pub queries: Vec<Query>,
158    pub answers: Vec<RR>,
159    pub authority: Vec<RR>,
160    pub additional: Vec<RR>,
161}
162
163impl RR {
164    pub fn dump(&self, indent: usize) {
165        println!(
166            "{} {} {}",
167            " ".to_owned().repeat(indent),
168            self.name,
169            self.typ
170        )
171    }
172}
173
174fn rr_type_to_string(typ: u16) -> Cow<'static, str> {
175    match typ {
176        TYPE_A => "A".into(),
177        TYPE_PTR => "PTR".into(),
178        TYPE_TXT => "TXT".into(),
179        TYPE_AAAA => "AAAA".into(),
180        TYPE_SRV => "SRV".into(),
181        TYPE_NAPTR => "NAPTR".into(),
182        TYPE_CNAME => "CNAME".into(),
183        _ => std::fmt::format(format_args!("TYPE{}", typ)).into(),
184    }
185}
186
187impl std::fmt::Display for RR {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        write!(
190            f,
191            "{} {} TTL:{}",
192            self.name, rr_type_to_string(self.typ), self.ttl
193        )
194    }
195}
196
197impl Query {
198    pub fn dump(&self, indent: usize) {
199        println!(
200            "{} {} {}",
201            " ".to_owned().repeat(indent),
202            self.name,
203            self.typ
204        )
205    }
206}
207
208impl DnsMessage {
209    pub fn dump(&self) {
210        println!("{:?} {} {:x}", self.source, self.transaction, self.flags);
211        println!("  queries:");
212        for queries in &self.queries {
213            queries.dump(4);
214        }
215        println!("  answers:");
216        for answer in &self.answers {
217            answer.dump(4);
218        }
219        println!("  authority:");
220        for authority in &self.authority {
221            authority.dump(4);
222        }
223        println!("  additional:");
224        for additional in &self.additional {
225            additional.dump(4);
226        }
227    }
228}
229
230fn parse_rr(data: &[u8], cursor: &mut Cursor<&[u8]>) -> Result<RR> {
231    let name = read_label(data, cursor)?;
232    let typ = cursor.read_u16::<BigEndian>()?;
233    let class = cursor.read_u16::<BigEndian>()?;
234    let ttl = cursor.read_u32::<BigEndian>()?;
235    let dlen = cursor.read_u16::<BigEndian>()?;
236    let mut rdata = vec![0; dlen as usize];
237    cursor.read_exact(&mut rdata)?;
238    let mut target = None;
239    if typ == TYPE_SRV && rdata.len() >= 6 {
240        target = Some(read_label(data, &mut Cursor::new(&rdata[6..])).context("can't parse target from SRV")?);
241    }
242    let rrdata = match typ {
243        TYPE_A if rdata.len() == 4 => RRData::A(std::net::Ipv4Addr::from_octets(rdata[0..4].try_into().context("invalid A rdata length")?)),
244        TYPE_AAAA if rdata.len() == 16 => RRData::AAAA(std::net::Ipv6Addr::from_octets(rdata[0..16].try_into().context("invalid AAAA rdata length")?)),
245        TYPE_PTR => RRData::PTR(read_label(data, &mut Cursor::new(&rdata)).context("can't parse PTR rdata")?),
246        TYPE_TXT => RRData::TXT(rdata.split(|b| *b == 0).filter_map(|s| std::str::from_utf8(s).ok().map(|s| s.to_owned())).collect()),
247        TYPE_SRV if rdata.len() >= 6 => {
248            let mut cursor = Cursor::new(rdata.as_slice());
249            let priority = cursor.read_u16::<BigEndian>()?;
250            let weight = cursor.read_u16::<BigEndian>()?;
251            let port = cursor.read_u16::<BigEndian>()?;
252            let target = read_label(data, &mut cursor).context("can't parse target from SRV")?;
253            RRData::SRV { priority, weight, port, target }
254        }
255        _ => RRData::Unknown(rdata.clone()),
256    };
257
258    Ok(RR {
259        name,
260        typ,
261        class,
262        ttl,
263        rdata,
264        target,
265        data: rrdata,
266    })
267}
268
269fn parse_q(data: &[u8], cursor: &mut Cursor<&[u8]>) -> Result<Query> {
270    let name = read_label(data, cursor)?;
271    let typ = cursor.read_u16::<BigEndian>()?;
272    let class = cursor.read_u16::<BigEndian>()?;
273
274    Ok(Query { name, typ, class })
275}
276
277pub fn parse_dns(data: &[u8], source: std::net::SocketAddr) -> Result<DnsMessage> {
278    let mut cursor = Cursor::new(data);
279    let transaction = cursor.read_u16::<BigEndian>()?;
280    let flags = cursor.read_u16::<BigEndian>()?;
281    let nquestions = cursor.read_u16::<BigEndian>()?;
282    let nanswers = cursor.read_u16::<BigEndian>()?;
283    let nauthority = cursor.read_u16::<BigEndian>()?;
284    let nadditional = cursor.read_u16::<BigEndian>()?;
285
286    let mut queries = Vec::new();
287    let mut answers = Vec::new();
288    let mut additional = Vec::new();
289    let mut authority = Vec::new();
290
291    for _ in 0..nquestions {
292        queries.push(parse_q(data, &mut cursor)?);
293    }
294    for _ in 0..nanswers {
295        answers.push(parse_rr(data, &mut cursor)?);
296    }
297    for _ in 0..nauthority {
298        authority.push(parse_rr(data, &mut cursor)?);
299    }
300    for _ in 0..nadditional {
301        additional.push(parse_rr(data, &mut cursor)?);
302    }
303
304    Ok(DnsMessage {
305        source,
306        transaction,
307        flags,
308        queries,
309        answers,
310        authority,
311        additional,
312    })
313}
314
315async fn discoverv4(
316    label: &str,
317    qtype: u16,
318    sender: tokio::sync::mpsc::UnboundedSender<DnsMessage>,
319    cancel: tokio_util::sync::CancellationToken,
320) -> Result<()> {
321    let stdsocket = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
322    stdsocket.set_reuse_address(true)?;
323    #[cfg(not(target_os = "windows"))]
324    stdsocket.set_reuse_port(true)?;
325    let addr: std::net::SocketAddrV4 = "0.0.0.0:5353".parse()?;
326    stdsocket.bind(&socket2::SockAddr::from(addr))?;
327    let maddr: std::net::Ipv4Addr = "224.0.0.251".parse()?;
328    stdsocket.join_multicast_v4(&maddr, &std::net::Ipv4Addr::UNSPECIFIED)?;
329    stdsocket.set_nonblocking(true)?;
330    let socket = tokio::net::UdpSocket::from_std(stdsocket.into())?;
331    let query = create_query(label, qtype)?;
332    socket.send_to(&query, "224.0.0.251:5353").await?;
333    loop {
334        let mut buf = vec![0; 9000];
335        let (n, addr) = tokio::select! {
336            v = socket.recv_from(&mut buf) => v?,
337            _ = cancel.cancelled() => return Ok(())
338        };
339
340        buf.resize(n, 0);
341        let dns = parse_dns(&buf, addr);
342        let dns = match dns {
343            Ok(v) => v,
344            Err(e) => {
345                log::debug!("failed to parse mdns message: {}", e);
346                continue;
347            }
348        };
349        if dns.flags == 0 {
350            // ignore requests
351            continue;
352        }
353        sender.send(dns)?;
354    }
355}
356
357async fn discoverv6(
358    label: &str,
359    qtype: u16,
360    interface: u32,
361    sender: tokio::sync::mpsc::UnboundedSender<DnsMessage>,
362    cancel: tokio_util::sync::CancellationToken,
363) -> Result<()> {
364    let stdsocket = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
365    stdsocket.set_reuse_address(true)?;
366    #[cfg(not(target_os = "windows"))]
367    stdsocket.set_reuse_port(true)?;
368    let addr: std::net::SocketAddrV6 = "[::]:5353".parse()?;
369    stdsocket.bind(&socket2::SockAddr::from(addr))?;
370    let maddr: std::net::Ipv6Addr = "ff02::fb".parse()?;
371    stdsocket.join_multicast_v6(&maddr, interface)?;
372    stdsocket.set_multicast_if_v6(interface)?;
373    stdsocket.set_nonblocking(true)?;
374    let socket = tokio::net::UdpSocket::from_std(stdsocket.into())?;
375    let query = create_query(label, qtype)?;
376    socket.send_to(&query, "[ff02::fb]:5353").await?;
377    loop {
378        let mut buf = vec![0; 9000];
379        //let (n, addr) = socket.recv_from(&mut buf).await?;
380        let (n, addr) = tokio::select! {
381            v = socket.recv_from(&mut buf) => v?,
382            _ = cancel.cancelled() => return Ok(())
383        };
384        buf.resize(n, 0);
385        let dns = parse_dns(&buf, addr);
386        let dns = match dns {
387            Ok(v) => v,
388            Err(e) => {
389                log::debug!("failed to parse mdns message: {}", e);
390                continue;
391            }
392        };
393        if dns.flags == 0 {
394            // ignore requests
395            continue;
396        }
397        sender.send(dns)?;
398    }
399}
400
401pub async fn discover(
402    label: &str,
403    qtype: u16,
404    sender: tokio::sync::mpsc::UnboundedSender<DnsMessage>,
405    stop: tokio_util::sync::CancellationToken,
406) -> Result<()> {
407    let ifaces = if_addrs::get_if_addrs();
408    if let Ok(ifaces) = ifaces {
409        for iface in ifaces {
410            let stop_child = stop.child_token();
411            if !iface.ip().is_ipv6() {
412                continue;
413            }
414            if let Some(index) = iface.index {
415                let sender2 = sender.clone();
416                let label = label.to_owned();
417                tokio::spawn(async move {
418                    let e = discoverv6(&label, qtype, index, sender2, stop_child).await;
419                    if let Err(e) = e {
420                        log::warn!("mdns discover error: {}", e);
421                    }
422                });
423            }
424        }
425    };
426
427    let stop_child = stop.child_token();
428    let label = label.to_owned();
429    tokio::spawn(async move {
430        let e = discoverv4(&label, qtype, sender, stop_child).await;
431        if let Err(e) = e {
432            log::warn!("mdns discover error: {}", e);
433        }
434    });
435
436    Ok(())
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// With an empty offset table there is nothing to point at, so the
    /// compressed encoder must produce exactly the plain encoding.
    #[test]
    fn compressed_single_label_matches_uncompressed() {
        let name = "foo._tcp.local";

        let mut expected = Vec::new();
        encode_label(name, &mut expected).unwrap();

        let mut actual = Vec::new();
        encode_label_compressed(name, &mut actual, &mut HashMap::new()).unwrap();

        assert_eq!(expected, actual);
    }

    /// A second name sharing the `_tcp.local` suffix is written as its own
    /// first segment followed by a pointer back into the first name.
    #[test]
    fn compressed_reuses_shared_suffix() {
        let mut buf = Vec::new();
        let mut table = HashMap::new();

        encode_label_compressed("foo._tcp.local", &mut buf, &mut table).unwrap();
        let start = buf.len();
        encode_label_compressed("bar._tcp.local", &mut buf, &mut table).unwrap();

        // "bar" (1+3) + pointer (2) = 6 bytes, much less than full uncompressed
        let tail = &buf[start..];
        assert_eq!(tail.len(), 6);

        // The last two bytes should be a compression pointer to "_tcp.local" in the first label
        let (hi, lo) = (tail[4], tail[5]);
        assert_eq!(hi & 0xC0, 0xC0, "top 2 bits must be set for pointer");

        // "_tcp.local" starts at offset 4 in the first label (after \x03foo)
        let target = (usize::from(hi & 0x3F) << 8) | usize::from(lo);
        assert_eq!(target, 4);
    }

    /// Names produced by the compressed encoder decode back to their dotted
    /// form, including one that ends in a compression pointer.
    #[test]
    fn compressed_output_decodable_by_read_label() {
        let mut pkt = Vec::new();
        let mut table = HashMap::new();

        encode_label_compressed("foo._tcp.local", &mut pkt, &mut table).unwrap();
        let second = pkt.len();
        encode_label_compressed("bar._tcp.local", &mut pkt, &mut table).unwrap();

        let decode_from =
            |start: usize| read_label(&pkt, &mut Cursor::new(&pkt[start..])).unwrap();

        assert_eq!(decode_from(0), "foo._tcp.local.");
        // The second name relies on a pointer into the first.
        assert_eq!(decode_from(second), "bar._tcp.local.");
    }
}