Skip to main content

matc/
discover.rs

//! Very simple mDNS-based discovery of Matter devices.
//! Applications usually discover devices with these functions and then filter the results by discriminator.
//! This module attempts to send mDNS queries over IPv4 and IPv6 multicast at the same time.
//! If more control over the discovery mechanism is required, it may be better to use an external mDNS library.
5
6use crate::{mdns::{self, DnsMessage}, mdns2};
7use anyhow::{Context, Result};
8use byteorder::ReadBytesExt;
9use std::{
10    collections::{BTreeMap, HashMap},
11    io::{Cursor, Read},
12    net::{IpAddr, Ipv4Addr, Ipv6Addr},
13    time::Duration,
14};
15use tokio_util::bytes::Buf;
16
/// Commissioning mode advertised by a device in the `CM` TXT key.
///
/// The parsers in this module map `CM=0` to [`No`](Self::No), `CM=1` to
/// [`Yes`](Self::Yes) and `CM=2` to [`WithPasscode`](Self::WithPasscode); any other
/// value is reported as unknown (`None`) rather than as a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CommissioningMode {
    /// Advertised as `CM=0`: not in commissioning mode.
    No,
    /// Advertised as `CM=1`: in commissioning mode.
    Yes,
    /// Advertised as `CM=2`. NOTE(review): presumably the Matter "enhanced
    /// commissioning method" (passcode-based) — confirm against the spec.
    WithPasscode,
}
23
/// Information about one Matter service instance discovered via mDNS.
///
/// Most optional fields come from the instance's TXT record. Which fields are filled in
/// depends on the function that produced the value (`to_matter_info`, `to_matter_info2`
/// or `extract_matter_info`).
#[derive(Debug, Clone)]
pub struct MatterDeviceInfo {
    /// Service instance name. Depending on the producer this is either the instance label
    /// (service-type suffix stripped) or the full instance name without its trailing dot.
    pub instance: String,
    /// Host name of the device (A/AAAA owner name or SRV target), with the `.local.`
    /// suffix or the trailing dot removed.
    pub device: String,
    /// IP addresses found for `device`; may be empty.
    pub ips: Vec<IpAddr>,
    /// Name from the `DN` TXT key.
    pub name: Option<String>,
    /// Part of the `VP` TXT key before `+`, kept as the raw string.
    pub vendor_id: Option<String>,
    /// Part of the `VP` TXT key after `+`, kept as the raw string.
    pub product_id: Option<String>,
    /// Value of the `D` TXT key, kept as the raw string.
    pub discriminator: Option<String>,
    /// Parsed `CM` TXT key; `None` when absent or not one of `0`/`1`/`2`.
    pub commissioning_mode: Option<CommissioningMode>,
    /// Value of the `PH` TXT key.
    pub pairing_hint: Option<String>,
    /// Source address of the mDNS response as text; empty when the producer does not know it.
    pub source_ip: String,
    /// Port from the SRV record, when one was parsed.
    pub port: Option<u16>,
    /// MRP idle interval (SII TXT key, milliseconds)
    pub session_idle_interval_ms: Option<u32>,
    /// MRP active interval (SAI TXT key, milliseconds)
    pub session_active_interval_ms: Option<u32>,
    /// MRP active threshold (SAT TXT key, milliseconds)
    pub session_active_threshold_ms: Option<u32>,
}
44
45impl MatterDeviceInfo {
46    /// MRP timing parameters from the advertised SII/SAI/SAT values,
47    /// with spec defaults for missing keys.
48    pub fn mrp_params(&self) -> crate::mrp::MrpParameters {
49        crate::mrp::MrpParameters::from_txt_ms(
50            self.session_idle_interval_ms,
51            self.session_active_interval_ms,
52            self.session_active_threshold_ms,
53        )
54    }
55
56    pub fn print_compact(&self) {
57        let mut info = format!("{} ({})", self.instance, self.device);
58        if let Some(name) = &self.name {
59            info += &format!(", name: {}", name);
60        }
61        if let Some(vendor_id) = &self.vendor_id {
62            info += &format!(", vendor_id: {}", vendor_id);
63        }
64        if let Some(product_id) = &self.product_id {
65            info += &format!(", product_id: {}", product_id);
66        }
67        if let Some(discriminator) = &self.discriminator {
68            info += &format!(", discriminator: {}", discriminator);
69        }
70        if let Some(cm) = &self.commissioning_mode {
71            info += &format!(", commissioning_mode: {:?}", cm);
72        }
73        if let Some(pairing_hint) = &self.pairing_hint {
74            info += &format!(", pairing_hint: {}", pairing_hint);
75        }
76        if let Some(port) = &self.port {
77            info += &format!(", port: {}", port);
78        }
79        if let Some(sii) = &self.session_idle_interval_ms {
80            info += &format!(", sii_ms: {}", sii);
81        }
82        if let Some(sai) = &self.session_active_interval_ms {
83            info += &format!(", sai_ms: {}", sai);
84        }
85        println!("{}", info);
86        if !self.ips.is_empty() {
87            println!("  ips:");
88            for ip in &self.ips {
89                println!("      {}", ip);
90            }
91        }
92
93    }
94}
95
96
97pub fn parse_txt_records(data: &[u8]) -> Result<HashMap<String, String>> {
98    let mut cursor = Cursor::new(data);
99    let mut out = HashMap::new();
100    while cursor.remaining() > 0 {
101        let len = cursor.read_u8()?;
102        let mut buf = vec![0; len as usize];
103        cursor.read_exact(buf.as_mut_slice())?;
104        let splitstr = std::str::from_utf8(&buf)?.splitn(2, "=");
105        let x: Vec<&str> = splitstr.collect();
106        if x.len() == 2 {
107            out.insert(x[0].to_owned(), x[1].to_owned());
108        }
109    }
110    Ok(out)
111}
112
/// Read the MRP timing keys `SII`, `SAI` and `SAT` from a parsed TXT map, in that order,
/// as millisecond values.
///
/// A key that is missing, or whose value is not a valid `u32`, yields `None`.
fn parse_mrp_txt(rec: &HashMap<String, String>) -> (Option<u32>, Option<u32>, Option<u32>) {
    fn millis(rec: &HashMap<String, String>, key: &str) -> Option<u32> {
        rec.get(key)?.parse().ok()
    }
    (millis(rec, "SII"), millis(rec, "SAI"), millis(rec, "SAT"))
}
119
/// Return `string` with one trailing occurrence of `suffix` removed, as an owned `String`.
///
/// If `string` does not end with `suffix`, an owned copy of `string` is returned unchanged.
fn remove_string_suffix(string: &str, suffix: &str) -> String {
    string.strip_suffix(suffix).unwrap_or(string).to_owned()
}
127
128pub fn to_matter_info2(msg: &DnsMessage, svc: &str) -> Result<Vec<MatterDeviceInfo>> {
129    let mut out = Vec::new();
130    let mut matter_service = false;
131    let svcname = ".".to_owned() + svc + ".";
132    for answer in &msg.answers {
133        if answer.name == svcname[1..] {
134            matter_service = true
135        }
136    }
137    if !matter_service {
138        return Err(anyhow::anyhow!("not matter service"));
139    }
140    let mut services = HashMap::new();
141    let mut targets = HashMap::new();
142    for additional in &msg.additional {
143        if additional.typ == mdns::TYPE_A {
144            let arr: [u8; 4] = match additional.rdata.clone().try_into() {
145                Ok(v) => v,
146                Err(_e) => return Err(anyhow::anyhow!("A record is not correct")),
147            };
148            let val = IpAddr::V4(Ipv4Addr::from_bits(u32::from_be_bytes(arr)));
149            if !targets.contains_key(&additional.name) {
150                targets.insert(additional.name.clone(), Vec::new());
151            }
152            targets.get_mut(&additional.name).unwrap().push(val);
153        }
154        if additional.typ == mdns::TYPE_AAAA {
155            let arr: [u8; 16] = match additional.rdata.clone().try_into() {
156                Ok(v) => v,
157                Err(_e) => return Err(anyhow::anyhow!("AAAA record is not correct")),
158            };
159            let val = IpAddr::V6(Ipv6Addr::from_bits(u128::from_be_bytes(arr)));
160            if !targets.contains_key(&additional.name) {
161                targets.insert(additional.name.clone(), Vec::new());
162            }
163            targets.get_mut(&additional.name).unwrap().push(val);
164        }
165    }
166    let mut all = msg.additional.to_vec();
167    all.append(&mut msg.answers.to_vec());
168    for additional in &all {
169        if additional.typ == mdns::TYPE_SRV {
170            let service_name = remove_string_suffix(&additional.name, &svcname);
171            if additional.rdata.len() < 6 {
172                continue;
173            }
174            let port = ((additional.rdata[4] as u16) << 8) | (additional.rdata[5] as u16);
175            let target_name = {
176                if let Some(at) = additional.target.as_ref() {
177                    at
178                } else {
179                    continue;
180                }
181            };
182            let target_ip = targets.get(target_name).cloned().unwrap_or_default();
183            let mi = MatterDeviceInfo {
184                instance: service_name.clone(),
185                device: remove_string_suffix(target_name, ".local.").to_owned(),
186                ips: target_ip,
187                name: None,
188                discriminator: None,
189                commissioning_mode: None,
190                pairing_hint: None,
191                source_ip: msg.source.to_string(),
192                vendor_id: None,
193                product_id: None,
194                port: Some(port),
195                session_idle_interval_ms: None,
196                session_active_interval_ms: None,
197                session_active_threshold_ms: None,
198            };
199            services.insert(service_name, mi);
200        }
201    }
202    for s in services.values() {
203        out.push(s.clone());
204    }
205
206    Ok(out)
207}
208
209pub fn to_matter_info(msg: &DnsMessage, svc: &str) -> Result<MatterDeviceInfo> {
210    let mut device = None;
211    let mut service = None;
212    let mut ips = BTreeMap::new();
213    let mut name = None;
214    let mut discriminator = None;
215    let mut cm = None;
216    let mut pairing_hint = None;
217    let mut vendor_id = None;
218    let mut product_id = None;
219    let mut port: Option<u16> = None;
220    let mut mrp = (None, None, None);
221
222    let mut matter_service = false;
223    let svcname = ".".to_owned() + svc + ".";
224    for answer in &msg.answers {
225        if answer.name == svcname[1..] {
226            matter_service = true
227        }
228    }
229    for additional in &msg.additional {
230        if additional.typ == mdns::TYPE_A {
231            let arr: [u8; 4] = match additional.rdata.clone().try_into() {
232                Ok(v) => v,
233                Err(_e) => return Err(anyhow::anyhow!("A record is not correct")),
234            };
235            let val = IpAddr::V4(Ipv4Addr::from_bits(u32::from_be_bytes(arr)));
236            ips.insert(val, true);
237            device = Some(remove_string_suffix(&additional.name, ".local."));
238        }
239        if additional.typ == mdns::TYPE_AAAA {
240            let arr: [u8; 16] = match additional.rdata.clone().try_into() {
241                Ok(v) => v,
242                Err(_e) => return Err(anyhow::anyhow!("AAAA record is not correct")),
243            };
244            let val = IpAddr::V6(Ipv6Addr::from_bits(u128::from_be_bytes(arr)));
245            ips.insert(val, true);
246            device = Some(remove_string_suffix(&additional.name, ".local."));
247        }
248        if additional.typ == mdns::TYPE_SRV {
249            service = Some(remove_string_suffix(&additional.name, &svcname));
250            if additional.rdata.len() >= 6 {
251                port = Some(((additional.rdata[4] as u16) << 8) | (additional.rdata[5] as u16))
252            }
253        }
254        if additional.typ == mdns::TYPE_TXT {
255            let rec = parse_txt_records(&additional.rdata)?;
256            name = rec.get("DN").cloned();
257            discriminator = rec.get("D").cloned();
258            pairing_hint = rec.get("PH").cloned();
259            mrp = parse_mrp_txt(&rec);
260            if let Some(vp) = rec.get("VP") {
261                let mut split = vp.split("+");
262                vendor_id = split.next().map(str::to_owned);
263                product_id = split.next().map(str::to_owned);
264            }
265            cm = match rec.get("CM") {
266                Some(v) => match v.as_str() {
267                    "0" => Some(CommissioningMode::No),
268                    "1" => Some(CommissioningMode::Yes),
269                    "2" => Some(CommissioningMode::WithPasscode),
270                    _ => None,
271                },
272                None => None,
273            };
274        }
275    }
276
277    if !matter_service {
278        return Err(anyhow::anyhow!("not matter service"));
279    }
280
281    Ok(MatterDeviceInfo {
282        instance: service.context("service name not detected")?,
283        device: device.context("device name not detected")?,
284        ips: ips.into_keys().collect(),
285        name,
286        discriminator,
287        commissioning_mode: cm,
288        pairing_hint,
289        source_ip: msg.source.to_string(),
290        vendor_id,
291        product_id,
292        port,
293        session_idle_interval_ms: mrp.0,
294        session_active_interval_ms: mrp.1,
295        session_active_threshold_ms: mrp.2,
296    })
297}
298
299async fn discover_common(timeout: Duration, svc_type: &str) -> Result<Vec<MatterDeviceInfo>> {
300    let stop = tokio_util::sync::CancellationToken::new();
301    let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<DnsMessage>();
302
303    mdns::discover(svc_type, mdns::QTYPE_ANY, sender, stop.child_token()).await?;
304
305    tokio::spawn(async move {
306        tokio::time::sleep(timeout).await;
307        stop.cancel();
308    });
309    let mut cache = HashMap::new();
310    let mut out = Vec::new();
311    while let Some(dns) = receiver.recv().await {
312        if cache.contains_key(&dns) {
313            continue;
314        }
315        let info = match to_matter_info(&dns, svc_type) {
316            Ok(info) => info,
317            Err(_) => continue,
318        };
319        out.push(info);
320        cache.insert(dns, true);
321    }
322    Ok(out)
323}
324
325/// Discover commissionable devices using mdns
326pub async fn discover_commissionable(timeout: Duration) -> Result<Vec<MatterDeviceInfo>> {
327    discover_common(timeout, "_matterc._udp.local").await
328}
329
330/// Discover commissioned devices using mdns
331pub async fn discover_commissioned(timeout: Duration) -> Result<Vec<MatterDeviceInfo>> {
332    discover_common(timeout, "_matter._tcp.local").await
333}
334
335
336async fn discover_common2(timeout: Duration, svc_type: &str) -> Result<Vec<MatterDeviceInfo>> {
337    let stop = tokio_util::sync::CancellationToken::new();
338    let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<DnsMessage>();
339
340    mdns::discover(svc_type, mdns::QTYPE_ANY, sender, stop.child_token()).await?;
341
342    tokio::spawn(async move {
343        tokio::time::sleep(timeout).await;
344        stop.cancel();
345    });
346    let mut cache = HashMap::new();
347    let mut out: Vec<MatterDeviceInfo> = Vec::new();
348    while let Some(dns) = receiver.recv().await {
349        if cache.contains_key(&dns) {
350            continue;
351        }
352        let info = match to_matter_info2(&dns, svc_type) {
353            Ok(info) => info,
354            Err(e) => {
355                log::trace!("failed to parse mdns message from {}: {:?}", dns.source, e);
356                continue;
357            },
358        };
359        for i in &info {
360            out.push(i.clone());
361        }
362        cache.insert(dns, true);
363    }
364    Ok(out)
365}
366
367/// Discover commissionable devices using mdns
368pub async fn discover_commissionable2(timeout: Duration) -> Result<Vec<MatterDeviceInfo>> {
369    discover_common2(timeout, "_matterc._udp.local").await
370}
371
372/// Discover commissioned devices using mdns
373pub async fn discover_commissioned2(timeout: Duration, device: &Option<String>) -> Result<Vec<MatterDeviceInfo>> {
374    let query = {
375        match device {
376            None => "_matter._tcp.local".to_owned(),
377            Some(d) => format!("{}._matter._tcp.local", d),
378        }
379    };
380    discover_common2(timeout, &query).await
381}
382
383
384
385/// Discover the first device matching a predicate.
386///
387/// Subscribes to the broadcast channel, sends `query` as an active mDNS lookup, then
388/// drains events until one matching `service_name` passes `predicate`. Lag events (dropped
389/// due to buffer overflow) are logged and skipped; discovery continues normally.
390///
391/// `predicate` receives the full instance target string and the parsed `MatterDeviceInfo`.
392pub async fn discover_one<F>(
393    mdns: &mdns2::MdnsService,
394    query: &str,
395    service_name: &str,
396    timeout: Duration,
397    predicate: F,
398) -> Result<(String, MatterDeviceInfo)>
399where
400    F: Fn(&str, &MatterDeviceInfo) -> bool,
401{
402    let mut rx = mdns.subscribe();
403    mdns.active_lookup(query, mdns::QTYPE_ANY).await;
404    let deadline = std::time::Instant::now() + timeout;
405    loop {
406        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
407        if remaining.is_zero() {
408            anyhow::bail!("mDNS discovery timeout for {}", query);
409        }
410        match tokio::time::timeout(remaining, rx.recv()).await {
411            Err(_) => anyhow::bail!("mDNS discovery timeout for {}", query),
412            Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(n))) => {
413                log::warn!("mDNS discovery: dropped {} events due to lag, continuing", n);
414            }
415            Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => {
416                anyhow::bail!("mDNS service closed");
417            }
418            Ok(Ok(mdns2::MdnsEvent::ServiceExpired { .. })) => {}
419            Ok(Ok(mdns2::MdnsEvent::ServiceDiscovered { name, target, .. })) => {
420                if name != service_name {
421                    continue;
422                }
423                let info = match extract_matter_info(&target, mdns).await {
424                    Ok(i) => i,
425                    Err(e) => {
426                        log::debug!("failed to extract Matter info from {}: {}", target, e);
427                        continue;
428                    }
429                };
430                if predicate(&target, &info) {
431                    return Ok((target, info));
432                }
433            }
434        }
435    }
436}
437
438/// Discover all matching devices until the timeout expires.
439///
440/// Like [`discover_one`] but collects every device whose `ServiceDiscovered` event
441/// matches `service_name` and for which `extract_matter_info` succeeds, until `timeout`
442/// elapses. Returns an empty `Vec` if no devices are found (not an error).
443pub async fn discover_all(
444    mdns: &mdns2::MdnsService,
445    query: &str,
446    service_name: &str,
447    timeout: Duration,
448) -> Result<Vec<(String, MatterDeviceInfo)>> {
449    let mut rx = mdns.subscribe();
450    mdns.active_lookup(query, mdns::QTYPE_ANY).await;
451    let deadline = std::time::Instant::now() + timeout;
452    let mut out = Vec::new();
453    loop {
454        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
455        if remaining.is_zero() {
456            break;
457        }
458        match tokio::time::timeout(remaining, rx.recv()).await {
459            Err(_) => break,
460            Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(n))) => {
461                log::warn!("mDNS discover_all: dropped {} events due to lag, continuing", n);
462            }
463            Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
464            Ok(Ok(mdns2::MdnsEvent::ServiceExpired { .. })) => {}
465            Ok(Ok(mdns2::MdnsEvent::ServiceDiscovered { name, target, .. })) => {
466                if name != service_name {
467                    continue;
468                }
469                match extract_matter_info(&target, mdns).await {
470                    Ok(info) => out.push((target, info)),
471                    Err(e) => {
472                        log::debug!("failed to extract Matter info from {}: {}", target, e);
473                    }
474                }
475            }
476        }
477    }
478    Ok(out)
479}
480
481pub async fn extract_matter_info(target: &str, mdns: &mdns2::MdnsService) -> Result<MatterDeviceInfo> {
482    let txt_records = mdns.lookup(target, mdns::TYPE_TXT).await;
483    let mut txt_info = HashMap::new();
484    for txt_rr in txt_records {
485        txt_info.extend(parse_txt_records(&txt_rr.rdata)?);
486    }
487    let srv_records = mdns.lookup(target, mdns::TYPE_SRV).await;
488    let srv_rr = srv_records.first().ok_or_else(|| anyhow::anyhow!("No SRV record found for {}", target))?;
489    let (srv_target, port) = match srv_rr.data {
490        mdns::RRData::SRV { ref target, port, .. } => (target.clone(), port),
491        _ => return Err(anyhow::anyhow!("Invalid SRV record for {}", target)),
492    };
493    let mut ips = Vec::new();
494    let a_records = mdns.lookup(&srv_target, mdns::TYPE_A).await;
495    for a_rr in a_records {
496        if let mdns::RRData::A(ip) = a_rr.data {
497            ips.push(ip.into());
498        }
499    }
500    let aaaa_records = mdns.lookup(&srv_target, mdns::TYPE_AAAA).await;
501    for aaaa_rr in aaaa_records {
502        if let mdns::RRData::AAAA(ip) = aaaa_rr.data {
503            ips.push(ip.into());
504        }
505    }
506    let (vendor_id, product_id) = {
507        let vp = txt_info.get("VP");
508        if let Some(vp) = vp {
509            let mut parts = vp.split('+');
510            let vendor_id = parts.next();
511            let product_id = parts.next();
512            (vendor_id.map(|v| v.to_owned()), product_id.map(|p| p.to_owned()))
513        } else {
514            (None, None)
515        }
516    };
517    let discriminator = txt_info.get("D").cloned();
518    let name = txt_info.get("DN").cloned();
519    let commissioning_mode = match txt_info.get("CM") {
520                Some(v) => match v.as_str() {
521                    "0" => Some(CommissioningMode::No),
522                    "1" => Some(CommissioningMode::Yes),
523                    "2" => Some(CommissioningMode::WithPasscode),
524                    _ => None,
525                },
526                None => None,
527            };
528    let pairing_hint = txt_info.get("PH").cloned();
529    let (sii, sai, sat) = parse_mrp_txt(&txt_info);
530    Ok(MatterDeviceInfo {
531        name,
532        instance: target.trim_end_matches('.').to_owned(),
533        device: srv_target.trim_end_matches('.').to_owned(),
534        ips,
535        vendor_id,
536        product_id,
537        discriminator,
538        commissioning_mode,
539        pairing_hint,
540        source_ip: "".to_owned(),
541        port: Some(port),
542        session_idle_interval_ms: sii,
543        session_active_interval_ms: sai,
544        session_active_threshold_ms: sat,
545    })
546}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `entries` as DNS TXT rdata: each entry becomes a length byte followed by its bytes.
    fn txt_rdata(entries: &[&str]) -> Vec<u8> {
        let mut out = Vec::new();
        for e in entries {
            out.push(e.len() as u8);
            out.extend_from_slice(e.as_bytes());
        }
        out
    }

    #[test]
    fn test_parse_mrp_txt() {
        let rec = parse_txt_records(&txt_rdata(&["SII=5000", "SAI=300", "SAT=4000", "D=840"]))
            .unwrap();
        assert_eq!(parse_mrp_txt(&rec), (Some(5000), Some(300), Some(4000)));

        let rec = parse_txt_records(&txt_rdata(&["SII=abc", "D=840"])).unwrap();
        assert_eq!(parse_mrp_txt(&rec), (None, None, None));

        let rec = parse_txt_records(&txt_rdata(&["D=840"])).unwrap();
        assert_eq!(parse_mrp_txt(&rec), (None, None, None));
    }

    #[test]
    fn test_parse_txt_records_framing() {
        // Empty rdata and a single zero-length string both yield no entries.
        assert!(parse_txt_records(&[]).unwrap().is_empty());
        assert!(parse_txt_records(&[0]).unwrap().is_empty());

        // Entries without '=' are ignored; only the first '=' separates key and value.
        let rec = parse_txt_records(&txt_rdata(&["flag", "VP=65521+32769", "K=a=b"])).unwrap();
        assert_eq!(rec.len(), 2);
        assert_eq!(rec.get("VP").map(String::as_str), Some("65521+32769"));
        assert_eq!(rec.get("K").map(String::as_str), Some("a=b"));

        // A length byte that runs past the end of the buffer is an error.
        assert!(parse_txt_records(&[5, b'a']).is_err());
        // Non-UTF-8 content is an error.
        assert!(parse_txt_records(&[2, 0xff, 0xfe]).is_err());
    }

    #[test]
    fn test_remove_string_suffix() {
        assert_eq!(remove_string_suffix("host.local.", ".local."), "host");
        assert_eq!(remove_string_suffix("host", ".local."), "host");
    }
}