1use std::{borrow::Cow, collections::HashMap, io::{Cursor, Read, Write}};
4
5use anyhow::{Context, Result};
6
7use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
8use socket2::{Domain, Protocol, Type};
9
/// RR type code for an IPv4 host address record.
pub const TYPE_A: u16 = 0x0001;
/// RR type code for a canonical-name alias record.
pub const TYPE_CNAME: u16 = 0x0005;
/// RR type code for a domain-name pointer record.
pub const TYPE_PTR: u16 = 0x000C;
/// RR type code for a text record.
pub const TYPE_TXT: u16 = 0x0010;
/// RR type code for an IPv6 host address record.
pub const TYPE_AAAA: u16 = 0x001C;
/// RR type code for a service-location record.
pub const TYPE_SRV: u16 = 0x0021;
/// RR type code for a naming-authority pointer record.
pub const TYPE_NAPTR: u16 = 0x0023;
/// Query type that asks for records of any type.
pub const QTYPE_ANY: u16 = 0x00FF;
18
19pub fn encode_label(label: &str, out: &mut Vec<u8>) -> Result<()> {
20 for seg in label.split(".") {
21 if seg.is_empty() {
22 continue;
23 }
24 let bytes = seg.as_bytes();
25 if bytes.len() > 63 {
26 anyhow::bail!("DNS label segment exceeds 63 bytes: {} bytes", bytes.len());
27 }
28 out.write_u8(bytes.len() as u8)?;
29 out.write_all(bytes)?;
30 }
31 out.write_u8(0)?;
32 Ok(())
33}
34
35pub fn encode_label_compressed(
36 label: &str,
37 out: &mut Vec<u8>,
38 name_offsets: &mut HashMap<String, usize>,
39) -> Result<()> {
40 let segments: Vec<&str> = label.split('.').filter(|s| !s.is_empty()).collect();
41
42 for i in 0..segments.len() {
43 let suffix = segments[i..].join(".");
44 if let Some(&offset) = name_offsets.get(&suffix) {
45 if offset < 0x3FFF {
46 out.write_u8(0xC0 | ((offset >> 8) as u8))?;
48 out.write_u8((offset & 0xFF) as u8)?;
49 return Ok(());
50 }
51 }
52 name_offsets.insert(suffix, out.len());
54 let bytes = segments[i].as_bytes();
55 if bytes.len() > 63 {
56 anyhow::bail!("DNS label segment exceeds 63 bytes: {} bytes", bytes.len());
57 }
58 out.write_u8(bytes.len() as u8)?;
59 out.write_all(bytes)?;
60 }
61
62 out.write_u8(0)?;
64 Ok(())
65}
66
67pub(crate) fn create_query(label: &str, qtype: u16) -> Result<Vec<u8>> {
68 let mut out = Vec::with_capacity(512);
69 out.write_u16::<BigEndian>(rand::random::<u16>())?; out.write_u16::<BigEndian>(0)?; out.write_u16::<BigEndian>(1)?; out.write_u16::<BigEndian>(0)?; out.write_u16::<BigEndian>(0)?; out.write_u16::<BigEndian>(0)?; encode_label(label, &mut out)?;
77
78 out.write_u16::<BigEndian>(qtype)?;
79 out.write_u16::<BigEndian>(0x0001)?; Ok(out)
81}
82
83fn read_label(data: &[u8], cursor: &mut Cursor<&[u8]>) -> Result<String> {
84 let mut out = Vec::new();
85 let mut depth = 0;
86 loop {
87 depth += 1;
88 if depth > 64 {
89 anyhow::bail!("too many label indirections");
90 }
91 let n = cursor.read_u8()?;
92 if n == 0 {
93 break;
94 } else if n & 0xc0 == 0xc0 {
95 let off = {
96 let off = n & 0x3f;
97 ((off as usize) << 8) | (cursor.read_u8()? as u16) as usize
98 };
99 if off >= data.len() {
100 anyhow::bail!("invalid compression pointer offset");
101 }
102 let frag = read_label(data, &mut Cursor::new(&data[off..]))?;
103 out.extend_from_slice(frag.as_bytes());
104 break;
105 } else {
106 if n > 63 {
108 anyhow::bail!("DNS label segment exceeds 63 bytes: {}", n);
109 }
110 let mut b = vec![0; n as usize];
111 cursor.read_exact(&mut b)?;
112 out.extend_from_slice(&b);
113 out.extend_from_slice(b".");
114 }
115 }
116 if out.len() > 1024 {
118 anyhow::bail!("DNS domain name exceeds 1024 bytes: {}", out.len());
119 }
120 Ok(std::str::from_utf8(&out)?.to_owned())
121}
122
123
/// Decoded RDATA of a resource record, as built by `parse_rr`.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub enum RRData {
    /// IPv4 address from an A record whose RDATA is exactly 4 bytes.
    A(std::net::Ipv4Addr),
    /// IPv6 address from an AAAA record whose RDATA is exactly 16 bytes.
    AAAA(std::net::Ipv6Addr),
    /// Target name of a PTR record, dotted with a trailing dot.
    PTR(String),
    /// Strings of a TXT record; pieces that are not valid UTF-8 are dropped.
    TXT(Vec<String>),
    /// SRV record fields; `target` is dotted with a trailing dot.
    SRV { priority: u16, weight: u16, port: u16, target: String },
    /// Raw RDATA for any other type, or for A/AAAA/SRV records whose RDATA
    /// length did not match the expected layout.
    Unknown(Vec<u8>),
}
133
/// A resource record parsed from the answer, authority or additional section.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub struct RR {
    /// Owner name, dotted with a trailing dot.
    pub name: String,
    /// Numeric RR type; compare against the `TYPE_*` constants.
    pub typ: u16,
    /// CLASS field exactly as read from the wire (no bits masked).
    // NOTE(review): in mDNS the top bit is the cache-flush flag; callers must mask it themselves.
    pub class: u16,
    /// TTL field exactly as read from the wire.
    pub ttl: u32,
    /// Undecoded RDATA bytes.
    pub rdata: Vec<u8>,
    /// For SRV records with at least 6 bytes of RDATA, the decoded target
    /// name; `None` otherwise.
    pub target: Option<String>,
    /// Decoded form of `rdata`.
    pub data: RRData,
}
144
/// One entry of a message's question section.
#[derive(Debug, Eq, PartialEq, Hash)]
pub struct Query {
    /// Queried name, dotted with a trailing dot.
    pub name: String,
    /// Numeric QTYPE; compare against the `TYPE_*` / `QTYPE_ANY` constants.
    pub typ: u16,
    /// QCLASS field exactly as read from the wire.
    pub class: u16,
}
151
/// A fully parsed DNS message plus the address it is attributed to.
#[derive(Debug, Eq, PartialEq, Hash)]
pub struct DnsMessage {
    /// Address passed to `parse_dns` as the message's origin (the datagram
    /// sender when called from the discovery loops).
    pub source: std::net::SocketAddr,
    /// Header ID field.
    pub transaction: u16,
    /// Raw second header word (QR/opcode/flags/RCODE), unmasked.
    pub flags: u16,
    /// Question section, in wire order.
    pub queries: Vec<Query>,
    /// Answer section, in wire order.
    pub answers: Vec<RR>,
    /// Authority section, in wire order.
    pub authority: Vec<RR>,
    /// Additional section, in wire order.
    pub additional: Vec<RR>,
}
162
163impl RR {
164 pub fn dump(&self, indent: usize) {
165 println!(
166 "{} {} {}",
167 " ".to_owned().repeat(indent),
168 self.name,
169 self.typ
170 )
171 }
172}
173
/// Returns the mnemonic for a known RR type code (e.g. `"A"`, `"SRV"`).
///
/// Codes without a mnemonic here (including `QTYPE_ANY`) are rendered as
/// `"TYPE<n>"`. Only that fallback allocates.
fn rr_type_to_string(typ: u16) -> Cow<'static, str> {
    let mnemonic: &'static str = match typ {
        TYPE_A => "A",
        TYPE_PTR => "PTR",
        TYPE_TXT => "TXT",
        TYPE_AAAA => "AAAA",
        TYPE_SRV => "SRV",
        TYPE_NAPTR => "NAPTR",
        TYPE_CNAME => "CNAME",
        other => return Cow::Owned(format!("TYPE{}", other)),
    };
    Cow::Borrowed(mnemonic)
}
186
187impl std::fmt::Display for RR {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 write!(
190 f,
191 "{} {} TTL:{}",
192 self.name, rr_type_to_string(self.typ), self.ttl
193 )
194 }
195}
196
197impl Query {
198 pub fn dump(&self, indent: usize) {
199 println!(
200 "{} {} {}",
201 " ".to_owned().repeat(indent),
202 self.name,
203 self.typ
204 )
205 }
206}
207
208impl DnsMessage {
209 pub fn dump(&self) {
210 println!("{:?} {} {:x}", self.source, self.transaction, self.flags);
211 println!(" queries:");
212 for queries in &self.queries {
213 queries.dump(4);
214 }
215 println!(" answers:");
216 for answer in &self.answers {
217 answer.dump(4);
218 }
219 println!(" authority:");
220 for authority in &self.authority {
221 authority.dump(4);
222 }
223 println!(" additional:");
224 for additional in &self.additional {
225 additional.dump(4);
226 }
227 }
228}
229
230fn parse_rr(data: &[u8], cursor: &mut Cursor<&[u8]>) -> Result<RR> {
231 let name = read_label(data, cursor)?;
232 let typ = cursor.read_u16::<BigEndian>()?;
233 let class = cursor.read_u16::<BigEndian>()?;
234 let ttl = cursor.read_u32::<BigEndian>()?;
235 let dlen = cursor.read_u16::<BigEndian>()?;
236 let mut rdata = vec![0; dlen as usize];
237 cursor.read_exact(&mut rdata)?;
238 let mut target = None;
239 if typ == TYPE_SRV && rdata.len() >= 6 {
240 target = Some(read_label(data, &mut Cursor::new(&rdata[6..])).context("can't parse target from SRV")?);
241 }
242 let rrdata = match typ {
243 TYPE_A if rdata.len() == 4 => RRData::A(std::net::Ipv4Addr::from_octets(rdata[0..4].try_into().context("invalid A rdata length")?)),
244 TYPE_AAAA if rdata.len() == 16 => RRData::AAAA(std::net::Ipv6Addr::from_octets(rdata[0..16].try_into().context("invalid AAAA rdata length")?)),
245 TYPE_PTR => RRData::PTR(read_label(data, &mut Cursor::new(&rdata)).context("can't parse PTR rdata")?),
246 TYPE_TXT => RRData::TXT(rdata.split(|b| *b == 0).filter_map(|s| std::str::from_utf8(s).ok().map(|s| s.to_owned())).collect()),
247 TYPE_SRV if rdata.len() >= 6 => {
248 let mut cursor = Cursor::new(rdata.as_slice());
249 let priority = cursor.read_u16::<BigEndian>()?;
250 let weight = cursor.read_u16::<BigEndian>()?;
251 let port = cursor.read_u16::<BigEndian>()?;
252 let target = read_label(data, &mut cursor).context("can't parse target from SRV")?;
253 RRData::SRV { priority, weight, port, target }
254 }
255 _ => RRData::Unknown(rdata.clone()),
256 };
257
258 Ok(RR {
259 name,
260 typ,
261 class,
262 ttl,
263 rdata,
264 target,
265 data: rrdata,
266 })
267}
268
269fn parse_q(data: &[u8], cursor: &mut Cursor<&[u8]>) -> Result<Query> {
270 let name = read_label(data, cursor)?;
271 let typ = cursor.read_u16::<BigEndian>()?;
272 let class = cursor.read_u16::<BigEndian>()?;
273
274 Ok(Query { name, typ, class })
275}
276
277pub fn parse_dns(data: &[u8], source: std::net::SocketAddr) -> Result<DnsMessage> {
278 let mut cursor = Cursor::new(data);
279 let transaction = cursor.read_u16::<BigEndian>()?;
280 let flags = cursor.read_u16::<BigEndian>()?;
281 let nquestions = cursor.read_u16::<BigEndian>()?;
282 let nanswers = cursor.read_u16::<BigEndian>()?;
283 let nauthority = cursor.read_u16::<BigEndian>()?;
284 let nadditional = cursor.read_u16::<BigEndian>()?;
285
286 let mut queries = Vec::new();
287 let mut answers = Vec::new();
288 let mut additional = Vec::new();
289 let mut authority = Vec::new();
290
291 for _ in 0..nquestions {
292 queries.push(parse_q(data, &mut cursor)?);
293 }
294 for _ in 0..nanswers {
295 answers.push(parse_rr(data, &mut cursor)?);
296 }
297 for _ in 0..nauthority {
298 authority.push(parse_rr(data, &mut cursor)?);
299 }
300 for _ in 0..nadditional {
301 additional.push(parse_rr(data, &mut cursor)?);
302 }
303
304 Ok(DnsMessage {
305 source,
306 transaction,
307 flags,
308 queries,
309 answers,
310 authority,
311 additional,
312 })
313}
314
315async fn discoverv4(
316 label: &str,
317 qtype: u16,
318 sender: tokio::sync::mpsc::UnboundedSender<DnsMessage>,
319 cancel: tokio_util::sync::CancellationToken,
320) -> Result<()> {
321 let stdsocket = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
322 stdsocket.set_reuse_address(true)?;
323 #[cfg(not(target_os = "windows"))]
324 stdsocket.set_reuse_port(true)?;
325 let addr: std::net::SocketAddrV4 = "0.0.0.0:5353".parse()?;
326 stdsocket.bind(&socket2::SockAddr::from(addr))?;
327 let maddr: std::net::Ipv4Addr = "224.0.0.251".parse()?;
328 stdsocket.join_multicast_v4(&maddr, &std::net::Ipv4Addr::UNSPECIFIED)?;
329 stdsocket.set_nonblocking(true)?;
330 let socket = tokio::net::UdpSocket::from_std(stdsocket.into())?;
331 let query = create_query(label, qtype)?;
332 socket.send_to(&query, "224.0.0.251:5353").await?;
333 loop {
334 let mut buf = vec![0; 9000];
335 let (n, addr) = tokio::select! {
336 v = socket.recv_from(&mut buf) => v?,
337 _ = cancel.cancelled() => return Ok(())
338 };
339
340 buf.resize(n, 0);
341 let dns = parse_dns(&buf, addr);
342 let dns = match dns {
343 Ok(v) => v,
344 Err(e) => {
345 log::debug!("failed to parse mdns message: {}", e);
346 continue;
347 }
348 };
349 if dns.flags == 0 {
350 continue;
352 }
353 sender.send(dns)?;
354 }
355}
356
357async fn discoverv6(
358 label: &str,
359 qtype: u16,
360 interface: u32,
361 sender: tokio::sync::mpsc::UnboundedSender<DnsMessage>,
362 cancel: tokio_util::sync::CancellationToken,
363) -> Result<()> {
364 let stdsocket = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
365 stdsocket.set_reuse_address(true)?;
366 #[cfg(not(target_os = "windows"))]
367 stdsocket.set_reuse_port(true)?;
368 let addr: std::net::SocketAddrV6 = "[::]:5353".parse()?;
369 stdsocket.bind(&socket2::SockAddr::from(addr))?;
370 let maddr: std::net::Ipv6Addr = "ff02::fb".parse()?;
371 stdsocket.join_multicast_v6(&maddr, interface)?;
372 stdsocket.set_multicast_if_v6(interface)?;
373 stdsocket.set_nonblocking(true)?;
374 let socket = tokio::net::UdpSocket::from_std(stdsocket.into())?;
375 let query = create_query(label, qtype)?;
376 socket.send_to(&query, "[ff02::fb]:5353").await?;
377 loop {
378 let mut buf = vec![0; 9000];
379 let (n, addr) = tokio::select! {
381 v = socket.recv_from(&mut buf) => v?,
382 _ = cancel.cancelled() => return Ok(())
383 };
384 buf.resize(n, 0);
385 let dns = parse_dns(&buf, addr);
386 let dns = match dns {
387 Ok(v) => v,
388 Err(e) => {
389 log::debug!("failed to parse mdns message: {}", e);
390 continue;
391 }
392 };
393 if dns.flags == 0 {
394 continue;
396 }
397 sender.send(dns)?;
398 }
399}
400
401pub async fn discover(
402 label: &str,
403 qtype: u16,
404 sender: tokio::sync::mpsc::UnboundedSender<DnsMessage>,
405 stop: tokio_util::sync::CancellationToken,
406) -> Result<()> {
407 let ifaces = if_addrs::get_if_addrs();
408 if let Ok(ifaces) = ifaces {
409 for iface in ifaces {
410 let stop_child = stop.child_token();
411 if !iface.ip().is_ipv6() {
412 continue;
413 }
414 if let Some(index) = iface.index {
415 let sender2 = sender.clone();
416 let label = label.to_owned();
417 tokio::spawn(async move {
418 let e = discoverv6(&label, qtype, index, sender2, stop_child).await;
419 if let Err(e) = e {
420 log::warn!("mdns discover error: {}", e);
421 }
422 });
423 }
424 }
425 };
426
427 let stop_child = stop.child_token();
428 let label = label.to_owned();
429 tokio::spawn(async move {
430 let e = discoverv4(&label, qtype, sender, stop_child).await;
431 if let Err(e) = e {
432 log::warn!("mdns discover error: {}", e);
433 }
434 });
435
436 Ok(())
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Encodes `names` back to back into one buffer with a shared offset table.
    /// Returns the buffer and the start offset of each name.
    fn encode_all(names: &[&str]) -> (Vec<u8>, Vec<usize>) {
        let mut buf = Vec::new();
        let mut table = HashMap::new();
        let mut starts = Vec::with_capacity(names.len());
        for name in names {
            starts.push(buf.len());
            encode_label_compressed(name, &mut buf, &mut table).unwrap();
        }
        (buf, starts)
    }

    #[test]
    fn compressed_single_label_matches_uncompressed() {
        let name = "foo._tcp.local";
        let mut expected = Vec::new();
        encode_label(name, &mut expected).unwrap();

        let (actual, _) = encode_all(&[name]);
        assert_eq!(expected, actual);
    }

    #[test]
    fn compressed_reuses_shared_suffix() {
        let (buf, starts) = encode_all(&["foo._tcp.local", "bar._tcp.local"]);
        let second = &buf[starts[1]..];

        // "bar" label (4 bytes) followed by a 2-byte pointer.
        assert_eq!(second.len(), 6);

        let (hi, lo) = (second[4], second[5]);
        assert_eq!(hi & 0xC0, 0xC0, "top 2 bits must be set for pointer");

        let target = (usize::from(hi & 0x3F) << 8) | usize::from(lo);
        assert_eq!(target, 4);
    }

    #[test]
    fn compressed_output_decodable_by_read_label() {
        let (pkt, starts) = encode_all(&["foo._tcp.local", "bar._tcp.local"]);
        let decode = |from: usize| read_label(&pkt, &mut Cursor::new(&pkt[from..])).unwrap();

        assert_eq!(decode(starts[0]), "foo._tcp.local.");
        assert_eq!(decode(starts[1]), "bar._tcp.local.");
    }
}