Skip to main content

matc/
controller.rs

1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use crate::{
4    active_connection::{ActiveConnection, Exchange},
5    cert_matter, certmanager, commission, fabric, im,
6    messages::{self, Message},
7    retransmit, session, sigma, spake2p,
8    tlv::TlvItemValue,
9    transport::{self, ConnectionTrait},
10    util::cryptoutil,
11};
12use anyhow::{Context, Result};
13use tokio::sync::mpsc;
14use byteorder::{LittleEndian, WriteBytesExt};
15
16pub struct Controller {
17    certmanager: Arc<dyn certmanager::CertManager>,
18    #[allow(dead_code)]
19    transport: Arc<transport::Transport>,
20    fabric: fabric::Fabric,
21    /// In-memory CASE session resumption records keyed by peer node ID.
22    resumption: Arc<tokio::sync::Mutex<HashMap<u64, sigma::ResumptionRecord>>>,
23}
24
25pub struct Connection {
26    active: ActiveConnection,
27}
28//trait IsSync: Sync {}
29//impl IsSync for Controller {}
30
/// CA identifier passed to `fabric::Fabric::new` when building the controller's fabric.
const CA_ID: u64 = 0x1;
32
/// Error signalling that the CASE responder answered with BUSY.
///
/// `wait_ms` holds the minimum wait (in milliseconds) reported alongside the
/// BUSY answer, when one was supplied.
#[derive(Debug, Clone, Copy)]
pub struct SigmaBusy {
    pub wait_ms: Option<u32>,
}

impl std::fmt::Display for SigmaBusy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("responder BUSY")?;
        // The wait hint is optional; only render the suffix when present.
        if let Some(min_wait) = self.wait_ms {
            write!(f, " (min wait {} ms)", min_wait)?;
        }
        Ok(())
    }
}

impl std::error::Error for SigmaBusy {}
46
47impl Controller {
48    pub fn new(
49        certmanager: &Arc<dyn certmanager::CertManager>,
50        transport: &Arc<transport::Transport>,
51        fabric_id: u64,
52    ) -> Result<Arc<Self>> {
53        let fabric = fabric::Fabric::new(
54            fabric_id,
55            CA_ID,
56            &certmanager.get_ca_public_key()?,
57            &certmanager.get_ipk_epoch_key(),
58        );
59        Ok(Arc::new(Self {
60            certmanager: certmanager.clone(),
61            transport: transport.clone(),
62            fabric,
63            resumption: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
64        }))
65    }
66
67    /// commission device
68    /// - authenticate using pin
69    /// - push CA certificate to device
70    /// - sign device's certificate
71    /// - set controller id - user which can control device
72    /// - return authenticated connection which can be used to send additional commands
73    pub async fn commission(
74        &self,
75        connection: &Arc<dyn ConnectionTrait>,
76        pin: u32,
77        node_id: u64,
78        controller_id: u64,
79    ) -> Result<Connection> {
80        let mut session = auth_spake(connection.as_ref(), pin).await?;
81        let session = commission::commission(
82            connection.as_ref(),
83            &mut session,
84            &self.fabric,
85            self.certmanager.as_ref(),
86            node_id,
87            controller_id,
88        )
89        .await?;
90        Ok(Connection {
91            active: ActiveConnection::new(connection.clone(), session),
92        })
93    }
94
95    /// create authenticated connection to control device
96    pub async fn auth_sigma(
97        &self,
98        connection: &Arc<dyn ConnectionTrait>,
99        node_id: u64,
100        controller_id: u64,
101    ) -> Result<Connection> {
102        let (session, resumption) = auth_sigma(
103            connection.as_ref(),
104            &self.fabric,
105            self.certmanager.as_ref(),
106            node_id,
107            controller_id,
108        )
109        .await?;
110        if let Some(record) = resumption {
111            self.resumption.lock().await.insert(node_id, record);
112        }
113        Ok(Connection {
114            active: ActiveConnection::new(connection.clone(), session),
115        })
116    }
117
118    /// Run auth_sigma with automatic BUSY retry.
119    /// Attempts a CASE session resumption first; falls back to full SIGMA on failure.
120    /// Returns only the Session so that both initial connect and in-place reauth can use it.
121    pub async fn auth_sigma_with_busy_retry(
122        &self,
123        connection: &Arc<dyn ConnectionTrait>,
124        node_id: u64,
125        controller_id: u64,
126    ) -> Result<session::Session> {
127        if let Some(ses) = self.try_auth_sigma_resume(connection, node_id, controller_id).await? {
128            return Ok(ses);
129        }
130
131        const MAX_BUSY_RETRIES: u32 = 5;
132        const DEFAULT_BUSY_WAIT: Duration = Duration::from_millis(3000);
133        const MAX_BUSY_WAIT: Duration = Duration::from_secs(60);
134
135        let mut busy_retries = 0u32;
136        loop {
137            match auth_sigma(connection.as_ref(), &self.fabric, self.certmanager.as_ref(), node_id, controller_id).await {
138                Ok((ses, resumption)) => {
139                    if let Some(record) = resumption {
140                        self.resumption.lock().await.insert(node_id, record);
141                    }
142                    return Ok(ses);
143                }
144                Err(e) => {
145                    if let Some(busy) = e.downcast_ref::<SigmaBusy>() {
146                        if busy_retries < MAX_BUSY_RETRIES {
147                            let wait = busy.wait_ms
148                                .map(|ms| Duration::from_millis(ms.into()))
149                                .unwrap_or(DEFAULT_BUSY_WAIT)
150                                .min(MAX_BUSY_WAIT);
151                            log::info!(
152                                "CASE responder BUSY, waiting {:?} before retry ({}/{})",
153                                wait, busy_retries + 1, MAX_BUSY_RETRIES
154                            );
155                            tokio::time::sleep(wait).await;
156                            busy_retries += 1;
157                            continue;
158                        }
159                        return Err(e).context(format!(
160                            "still BUSY after {} retries", MAX_BUSY_RETRIES
161                        ));
162                    }
163                    return Err(e);
164                }
165            }
166        }
167    }
168
169    async fn try_auth_sigma_resume(
170        &self,
171        connection: &Arc<dyn ConnectionTrait>,
172        node_id: u64,
173        controller_id: u64,
174    ) -> Result<Option<session::Session>> {
175        let record = {
176            let map = self.resumption.lock().await;
177            map.get(&node_id).cloned()
178        };
179        let record = match record {
180            Some(r) => r,
181            None => return Ok(None),
182        };
183
184        let exchange: u16 = rand::random();
185        let session = session::Session::new();
186        let mut retrctx = retransmit::RetrContext::new(connection.as_ref(), &session);
187        retrctx.subscribe_exchange(exchange);
188
189        let mut ctx = sigma::SigmaContext::new(node_id);
190        let ca_pubkey = self.certmanager.get_ca_key()?.public_key().to_sec1_bytes();
191        sigma::sigma1_resume(&self.fabric, &mut ctx, &ca_pubkey, &record)?;
192        let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
193
194        log::debug!("CASE resume: send Sigma1Resume exchange:{}", exchange);
195        retrctx.send(&s1).await?;
196
197        let sigma2 = retrctx.get_next_message().await?;
198
199        // Responder sent a status report instead of Sigma2 / Sigma2Resume - this includes
200        // Fall back to full SIGMA in all cases.
201        if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
202            && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
203        {
204            log::debug!(
205                "CASE resume: responder rejected with status report, falling back to full SIGMA (exchange:{} {:?})",
206                exchange,
207                sigma2.status_report_info
208            );
209            return Ok(None);
210        }
211
212        if !sigma::is_sigma2_resume(&sigma2.payload) {
213            // Responder gracefully fell back to full Sigma2 - evict the stale record so
214            // the next reconnect does a fresh full SIGMA.
215            log::debug!("CASE resume: responder sent full Sigma2, falling back");
216            self.resumption.lock().await.remove(&node_id);
217            return Ok(None);
218        }
219
220        let parsed = match sigma::parse_sigma2_resume(&sigma2.payload) {
221            Ok(p) => p,
222            Err(e) => {
223                log::debug!("CASE resume: malformed Sigma2Resume ({:?}), falling back to full SIGMA", e);
224                self.resumption.lock().await.remove(&node_id);
225                return Ok(None);
226            }
227        };
228
229        if let Err(e) = sigma::verify_sigma2_resume_mic(
230            &record.shared_secret,
231            &ctx.initiator_random,
232            &parsed.new_resumption_id,
233            &parsed.sigma2_resume_mic,
234        ) {
235            log::debug!("CASE resume: MIC verification failed: {:?}, falling back to full SIGMA", e);
236            self.resumption.lock().await.remove(&node_id);
237            return Ok(None);
238        }
239
240        let sr = messages::status_report_success(exchange)?;
241        if let Err(e) = retrctx.send(&sr).await {
242            log::debug!("CASE resume: failed to send StatusReport ({:?}), falling back to full SIGMA", e);
243            self.resumption.lock().await.remove(&node_id);
244            return Ok(None);
245        }
246
247        let keypack = sigma::derive_resumed_session_keys(
248            &record.shared_secret,
249            &ctx.initiator_random,
250            &record.resumption_id,
251        )?;
252
253        let mut ses = session::Session::new();
254        ses.session_id = parsed.responder_session_id;
255        ses.my_session_id = ctx.session_id;
256        ses.set_decrypt_key(&keypack[16..32]);
257        ses.set_encrypt_key(&keypack[..16]);
258
259        let mut local_node = Vec::new();
260        local_node.write_u64::<LittleEndian>(controller_id)?;
261        ses.local_node = Some(local_node);
262
263        let mut remote_node = Vec::new();
264        remote_node.write_u64::<LittleEndian>(node_id)?;
265        ses.remote_node = Some(remote_node);
266
267        // Rotate the resumption ID to the one the responder issued for the next round.
268        {
269            let mut map = self.resumption.lock().await;
270            if let Some(entry) = map.get_mut(&node_id) {
271                entry.resumption_id = parsed.new_resumption_id;
272            }
273        }
274
275        log::info!("CASE session resumed for node_id={}", node_id);
276        Ok(Some(ses))
277    }
278
279    /// Commission a device that is advertising over BLE.
280    ///
281    /// 1. Scans for a commissionable BLE device with the given `discriminator`.
282    /// 2. Runs PASE over BTP (BLE transport protocol).
283    /// 3. Pushes the CA cert, signs the device cert (AddNOC).
284    /// 4. Sends ArmFailSafe + SetRegulatoryConfig.
285    /// 5. Optionally provisions network credentials (Wi-Fi / Thread).
286    /// 6. Drops the BLE connection.
287    /// 7. Discovers the device on the IP network via mDNS.
288    /// 8. Establishes CASE + sends CommissioningComplete over UDP.
289    /// 9. Returns an authenticated [`Connection`] ready for commands.
290    ///
291    /// Requires the `ble` Cargo feature.
292    #[cfg(feature = "ble")]
293    pub async fn commission_ble(
294        &self,
295        discriminator: u16,
296        short_discriminator: bool,
297        pin: u32,
298        node_id: u64,
299        controller_id: u64,
300        network_creds: commission::NetworkCreds,
301        mdns: &std::sync::Arc<crate::mdns2::MdnsService>,
302    ) -> Result<Connection> {
303        use crate::btp::BtpConnection;
304
305        // 1. BLE scan + GATT connect + BTP handshake
306        let peripheral = crate::ble::find_by_discriminator(discriminator, short_discriminator, std::time::Duration::from_secs(30))
307            .await
308            .context("BLE scan")?;
309        log::debug!("BLE device found: z2");
310        let btp_conn = BtpConnection::connect(peripheral).await.context("BTP connect")?;
311
312        // 2. PASE
313        let mut pase_session = auth_spake(btp_conn.as_ref(), pin).await.context("PASE over BLE")?;
314
315        // 3. BLE-side commissioning phase
316        commission::commission_ble_phase(
317            btp_conn.as_ref(),
318            &mut pase_session,
319            &self.fabric,
320            self.certmanager.as_ref(),
321            node_id,
322            controller_id,
323            &network_creds,
324        )
325        .await
326        .context("BLE commissioning phase")?;
327        tokio::time::sleep(std::time::Duration::from_secs(5)).await; // wait for device to finish BLE-side commissioning before dropping connection
328
329        // 4. Drop BTP (BLE connection closes when btp_conn is dropped)
330        drop(btp_conn);
331
332        // 5 + 6. Rediscover via operational mDNS and commission over UDP
333        for attempt in 0..5 {
334            let addresses = match self.discover_operational_addresses(node_id, mdns).await {
335                Ok(a) => a,
336                Err(e) => {
337                    log::debug!("mDNS discovery failed (attempt {}/{}): {:?}", attempt + 1, 5, e);
338                    continue;
339                }
340            };
341            for address in &addresses {
342                log::debug!("Trying to commission over UDP at {}... (attempt {}/{})", address, attempt + 1, 5);
343                let udp_conn = self.transport.create_connection(&address).await;
344                let ses = commission::commissioning_complete_udp(
345                    udp_conn.as_ref(),
346                    self.certmanager.as_ref(),
347                    node_id,
348                    controller_id,
349                    &self.fabric,
350                )
351                .await;
352                if let Ok(ses) = ses {
353                    return Ok(Connection {
354                        active: ActiveConnection::new(udp_conn, ses),
355                    });
356                } else {
357                    log::debug!("Failed to commission over UDP at {}: {:?}", address, ses.err());
358                }
359            }
360        }
361        Err(anyhow::anyhow!("failed to commission device over UDP at any discovered address"))
362    }
363
364    #[cfg(feature = "ble")]
365    async fn discover_operational_addresses(
366        &self,
367        node_id: u64,
368        mdns: &std::sync::Arc<crate::mdns2::MdnsService>,
369    ) -> Result<Vec<String>> {
370        use crate::discover;
371
372        let ca_pubkey = self.certmanager.get_ca_public_key()?;
373        let fabric_tmp = fabric::Fabric::new(self.fabric.id, 0, &ca_pubkey, &self.certmanager.get_ipk_epoch_key());
374        let compressed = fabric_tmp.compressed().context("compressed fabric ID")?;
375        let instance = format!("{}-{:016X}", hex::encode_upper(&compressed), node_id);
376        let expected_target = format!("{}._matter._tcp.local.", instance);
377
378        log::debug!("Discovering operational device via mDNS with target {}", expected_target);
379        let (_, info) = discover::discover_one(
380            mdns,
381            "_matter._tcp.local",
382            "_matter._tcp.local.",
383            std::time::Duration::from_secs(120),
384            move |target, _| target == expected_target,
385        ).await.context(format!("operational mDNS timeout for {}", instance))?;
386        log::debug!("Operational mDNS discovered device: {:?}", info);
387
388        let port = info.port.unwrap_or(5540);
389        let addresses: Vec<String> = info
390            .ips
391            .iter()
392            .map(|ip| crate::discover::addr_string(ip, port, info.scope_id))
393            .collect();
394        log::info!("Device discovered at {}", addresses.join(", "));
395        Ok(addresses)
396    }
397}
398
399/// Authenticated virtual connection can be used to send commands to device.
400impl Connection {
401    /// Build a Connection from a transport-layer connection and an established session.
402    pub(crate) fn from_parts(conn: Arc<dyn ConnectionTrait>, session: session::Session) -> Self {
403        Self { active: ActiveConnection::new(conn, session) }
404    }
405
406    /// Read attribute from device and return parsed matter protocol response.
407    pub async fn read_request(
408        &self,
409        endpoint: u16,
410        cluster: u32,
411        attr: u32,
412    ) -> Result<Message> {
413        let exchange: u16 = rand::random();
414        let msg = messages::im_read_request(endpoint, cluster, attr, exchange)?;
415        self.active.request(exchange, &msg).await
416    }
417
418    /// Read attribute from device and return tlv with attribute value.
419    /// Reassembles chunked reports (MoreChunkedMessages) transparently.
420    pub async fn read_request2(
421        &self,
422        endpoint: u16,
423        cluster: u32,
424        attr: u32,
425    ) -> Result<TlvItemValue> {
426        let exchange: u16 = rand::random();
427        let msg = messages::im_read_request(endpoint, cluster, attr, exchange)?;
428        let mut ex = self.active.open_exchange(exchange);
429        ex.send(&msg).await?;
430        let report = self.collect_reports(&mut ex).await?;
431        let first = report
432            .attribute_reports
433            .into_iter()
434            .next()
435            .context("report data contains no attribute reports")?;
436        match first.data {
437            im::AttributeData::Value(v) => Ok(v),
438            im::AttributeData::Status { status, .. } => {
439                Err(anyhow::anyhow!("report data with status {}", status))
440            }
441        }
442    }
443
444    /// Receive ReportData chunks on the exchange until the last chunk,
445    /// sending the IM StatusResponse between chunks as required, and return
446    /// the merged report. The final StatusResponse is only sent when the
447    /// device did not set SuppressResponse (e.g. subscribe priming reports).
448    async fn collect_reports(&self, exchange: &mut Exchange<'_>) -> Result<im::ReportData> {
449        let mut merged: Option<im::ReportData> = None;
450        loop {
451            let msg = exchange.recv().await?;
452            if let Some(status) = &msg.status_report_info {
453                return Err(anyhow::anyhow!(
454                    "status report while waiting for report data: {:?}",
455                    status
456                ));
457            }
458            if msg.protocol_header.protocol_id
459                != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
460                || msg.protocol_header.opcode
461                    != messages::ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA
462            {
463                return Err(anyhow::anyhow!(
464                    "response is not expected report_data {:?}",
465                    msg.protocol_header
466                ));
467            }
468            let report = im::ReportData::parse(&msg.tlv)?;
469            let more = report.more_chunks;
470            let respond = more || !report.suppress_response;
471            match merged.as_mut() {
472                Some(m) => m.merge(report),
473                None => merged = Some(report),
474            }
475            if respond {
476                let flags = messages::im_status_flags_for(msg.protocol_header.exchange_flags);
477                let resp = messages::im_status_response(
478                    exchange.id,
479                    flags,
480                    msg.message_header.message_counter,
481                )?;
482                exchange.send(&resp).await?;
483            }
484            if !more {
485                return merged.context("no report data received");
486            }
487        }
488    }
489
490    /// Invoke command
491    pub async fn invoke_request(
492        &self,
493        endpoint: u16,
494        cluster: u32,
495        command: u32,
496        payload: &[u8],
497    ) -> Result<Message> {
498        let exchange: u16 = rand::random();
499        log::debug!(
500            "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
501            exchange,
502            endpoint,
503            cluster,
504            command
505        );
506        let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, false)?;
507        self.active.request(exchange, &msg).await
508    }
509
510    /// Invoke command and return result TLV
511    pub async fn invoke_request2(
512        &self,
513        endpoint: u16,
514        cluster: u32,
515        command: u32,
516        payload: &[u8],
517    ) -> Result<TlvItemValue> {
518        let res = self.invoke_request(endpoint, cluster, command, payload).await?;
519        let o = res.tlv.get(&[1, 0, 1, 1]).context("result not found")?;
520        Ok(o.clone())
521    }
522
523    pub async fn write_request(
524        &self,
525        endpoint: u16,
526        cluster: u32,
527        attr: u32,
528        payload: &[u8],
529    ) -> Result<()> {
530        let exchange: u16 = rand::random();
531        log::debug!(
532            "write_request exch:{} endpoint:{} cluster:{} attr:{}",
533            exchange,
534            endpoint,
535            cluster,
536            attr,
537        );
538
539        let msg = messages::im_write_request(endpoint, cluster, attr, exchange, payload)?;
540        let res = self.active.request(exchange, &msg).await?;
541        if res.status_report_info.is_some() {
542            return Err(anyhow::anyhow!(
543                "write_request failed with status {:?}",
544                res.status_report_info
545            ))
546        };
547        if res.protocol_header.protocol_id
548            == messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
549            && res.protocol_header.opcode
550                == messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
551        {
552            let stat = res
553                .tlv
554                .get_int(&[0])
555                .context("status not found in status response")?;
556            res.tlv.dump(1);
557            return Err(anyhow::anyhow!(
558                "response is not expected status_resp 0x{:x}",
559                stat
560            ))
561        };
562        if res.protocol_header.protocol_id
563            != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
564            || res.protocol_header.opcode
565                != messages::ProtocolMessageHeader::INTERACTION_OPCODE_WRITE_RESP
566        {
567            return Err(anyhow::anyhow!(
568                "response is not expected write_resp {:?}",
569                res.protocol_header
570            ))
571        };
572        let stat = res.tlv.get_int(&[0, 0, 1, 0]).context("status not found in write response")?;
573        if stat != 0 {
574            return Err(anyhow::anyhow!("write failed with status 0x{:x}", stat));
575        }
576        Ok(())
577    }
578
579    /// Subscribe to attribute changes. `None` path fields act as wildcards.
580    /// Set `keep_subscriptions = true` when adding a second subscription on the same
581    /// connection so the device does not cancel the first one.
582    ///
583    /// Handles the full subscribe transaction (chunked priming report, IM
584    /// StatusResponse acks, SubscribeResponse) and returns a [Subscription]
585    /// delivering decoded updates; updates are acked automatically by the
586    /// background read loop.
587    pub async fn subscribe_attrs(
588        &self,
589        endpoint: Option<u16>,
590        cluster: Option<u32>,
591        attr: Option<u32>,
592        keep_subscriptions: bool,
593    ) -> Result<Subscription> {
594        let exchange: u16 = rand::random();
595        log::debug!(
596            "subscribe_attrs exch:{} endpoint:{:?} cluster:{:?} attr:{:?} keep:{}",
597            exchange, endpoint, cluster, attr, keep_subscriptions
598        );
599        let msg = messages::im_subscribe_request_attr(endpoint, cluster, attr, exchange, keep_subscriptions)?;
600        self.subscribe_internal(exchange, &msg).await
601    }
602
603    /// Subscribe to events. `None` path fields act as wildcards.
604    /// See [Connection::subscribe_attrs] for transaction details.
605    pub async fn subscribe_events(
606        &self,
607        endpoint: Option<u16>,
608        cluster: Option<u32>,
609        event: Option<u32>,
610        keep_subscriptions: bool,
611    ) -> Result<Subscription> {
612        let exchange: u16 = rand::random();
613        log::debug!(
614            "subscribe_events exch:{} endpoint:{:?} cluster:{:?} event:{:?} keep:{}",
615            exchange, endpoint, cluster, event, keep_subscriptions
616        );
617        let msg = messages::im_subscribe_request_event(endpoint, cluster, event, exchange, keep_subscriptions)?;
618        self.subscribe_internal(exchange, &msg).await
619    }
620
621    async fn subscribe_internal(&self, exchange_id: u16, msg: &[u8]) -> Result<Subscription> {
622        let mut exchange = self.active.open_exchange(exchange_id);
623        exchange.send(msg).await?;
624        let priming = self.collect_reports(&mut exchange).await?;
625        let subscription_id = priming
626            .subscription_id
627            .context("priming report missing subscription id")?;
628
629        // Register before awaiting the SubscribeResponse so no update can be
630        // missed; the device cannot report before the transaction completes.
631        let rx = self.active.register_subscription(subscription_id);
632        let registry = self.active.subscriptions_handle();
633
634        let response = async {
635            let resp = exchange.recv().await?;
636            if resp.protocol_header.protocol_id
637                != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
638                || resp.protocol_header.opcode
639                    != messages::ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_RESP
640            {
641                anyhow::bail!(
642                    "response is not expected subscribe_resp {:?}",
643                    resp.protocol_header
644                );
645            }
646            let sr = im::SubscribeResponse::parse(&resp.tlv)?;
647            if sr.subscription_id != subscription_id {
648                anyhow::bail!(
649                    "subscribe response id {} does not match priming report id {}",
650                    sr.subscription_id,
651                    subscription_id
652                );
653            }
654            Ok(sr)
655        }
656        .await;
657
658        match response {
659            Ok(sr) => Ok(Subscription {
660                subscription_id,
661                max_interval: sr.max_interval,
662                priming_attribute_reports: priming.attribute_reports,
663                priming_event_reports: priming.event_reports,
664                rx,
665                registry,
666            }),
667            Err(e) => {
668                registry.lock().unwrap().remove(&subscription_id);
669                Err(e)
670            }
671        }
672    }
673
674    /// Cancel all subscriptions on this session by sending a SubscribeRequest with
675    /// `KeepSubscriptions = false` and no paths. The device drops all prior subscriptions.
676    pub async fn im_unsubscribe_all(&self) -> Result<Message> {
677        let exchange: u16 = rand::random();
678        log::debug!("im_unsubscribe_all exch:{}", exchange);
679        let msg = messages::im_unsubscribe_all(exchange)?;
680        self.active.request(exchange, &msg).await
681    }
682
683    /// Enable or disable automatic IM StatusResponse replies to unsolicited
684    /// ReportData (enabled by default). Disable only when acking reports
685    /// manually via the raw message API.
686    pub fn set_auto_status_response(&self, enabled: bool) {
687        self.active.set_auto_status_response(enabled);
688    }
689
690    /// Invoke command with timed interaction
691    pub async fn invoke_request_timed(
692        &self,
693        endpoint: u16,
694        cluster: u32,
695        command: u32,
696        payload: &[u8],
697        timeout: u16,
698    ) -> Result<Message> {
699        let exchange: u16 = rand::random();
700
701        // Send timed request first
702        let tr = messages::im_timed_request(exchange, timeout)?;
703        let result = self.active.request(exchange, &tr).await?;
704
705        if result.protocol_header.protocol_id
706            != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
707            || result.protocol_header.opcode
708                != messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
709        {
710            return Err(anyhow::anyhow!(
711                "invoke_request_timed: unexpected response {:?}",
712                result
713            ));
714        }
715        let status = result
716            .tlv
717            .get_int(&[0])
718            .context("invoke_request_timed: status not found")?;
719        if status != 0 {
720            return Err(anyhow::anyhow!(
721                "invoke_request_timed: unexpected status {}",
722                status
723            ));
724        }
725
726        log::debug!(
727            "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
728            exchange,
729            endpoint,
730            cluster,
731            command
732        );
733        let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, true)?;
734        self.active.request(exchange, &msg).await
735    }
736
737    /// Receive next unsolicited raw message not handled elsewhere (subscription
738    /// reports are delivered decoded via [Subscription]; only reports with an
739    /// unknown subscription id and other unsolicited messages end up here).
740    /// Returns None when connection is closed. Messages may be dropped when
741    /// nobody drains this channel.
742    pub async fn recv_event(&self) -> Option<Message> {
743        self.active.recv_event().await
744    }
745
746    /// Try receive event without blocking.
747    pub fn try_recv_event(&self) -> Option<Message> {
748        self.active.try_recv_event()
749    }
750
751    /// Re-run CASE over the existing transport channel without tearing it down.
752    /// Stops the active read loop, runs auth_sigma (with BUSY retry), swaps the session,
753    /// and restarts the read loop -- all on the same underlying UDP channel registration.
754    pub async fn reauth(
755        &self,
756        controller: &Controller,
757        node_id: u64,
758        controller_id: u64,
759    ) -> Result<()> {
760        self.active.pause_read_loop().await;
761        let new_session = controller
762            .auth_sigma_with_busy_retry(&self.active.transport_conn, node_id, controller_id)
763            .await?;
764        self.active.reauth_with_session(new_session).await
765    }
766}
767
/// Active subscription created by [Connection::subscribe_attrs] or
/// [Connection::subscribe_events]. Decoded updates are delivered via [Subscription::next];
/// the background read loop acks them automatically. Dropping the handle stops
/// delivery (the device-side subscription stays active until it expires or is
/// cancelled via [Connection::im_unsubscribe_all]).
pub struct Subscription {
    /// Subscription id returned by the device in its SubscribeResponse.
    ///
    /// `Drop` unregisters this handle by looking up the *current* value of
    /// this field in `registry`, so changing it would remove a different entry.
    pub subscription_id: u32,
    /// Maximum reporting interval in seconds granted by the device.
    pub max_interval: u16,
    /// Attribute reports from the priming report (current values at subscribe time).
    pub priming_attribute_reports: Vec<im::AttributeReport>,
    /// Event reports from the priming report.
    pub priming_event_reports: Vec<im::EventReport>,
    // Receiving end for decoded updates; drained by `next`.
    rx: mpsc::Receiver<im::ReportUpdate>,
    // Shared map from subscription id to an update sender (presumably the one
    // paired with `rx`); the entry for `subscription_id` is removed on drop.
    registry: Arc<std::sync::Mutex<HashMap<u32, mpsc::Sender<im::ReportUpdate>>>>,
}
784
785impl Subscription {
786    /// Receive the next decoded update. Returns None when the connection is
787    /// closed or re-authenticated (the subscription is then gone; resubscribe).
788    pub async fn next(&mut self) -> Option<im::ReportUpdate> {
789        self.rx.recv().await
790    }
791}
792
793impl Drop for Subscription {
794    fn drop(&mut self) {
795        self.registry.lock().unwrap().remove(&self.subscription_id);
796    }
797}
798
/// Encodes a setup PIN as the passcode bytes passed to the SPAKE2+ engine:
/// the PIN's four bytes in little-endian order.
///
/// This cannot fail. The `Result` return type is kept so existing callers
/// (which use `?`) continue to compile unchanged.
pub fn pin_to_passcode(pin: u32) -> Result<Vec<u8>> {
    Ok(pin.to_le_bytes().to_vec())
}
804
805pub(crate) async fn auth_spake(connection: &dyn ConnectionTrait, pin: u32) -> Result<session::Session> {
806    let exchange = rand::random();
807    log::debug!("start auth_spake");
808    let mut session = session::Session::new();
809    session.my_session_id = 1;
810    let mut retrctx = retransmit::RetrContext::new(connection, &session);
811    // send pbkdf
812    log::debug!("send pbkdf request");
813    let pbkdf_req_protocol_message = messages::pbkdf_req(exchange)?;
814    retrctx.send(&pbkdf_req_protocol_message).await?;
815
816    // get pbkdf response
817    let pbkdf_response = retrctx.get_next_message().await?;
818    if pbkdf_response.protocol_header.protocol_id
819        != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
820        || pbkdf_response.protocol_header.opcode
821            != messages::ProtocolMessageHeader::OPCODE_PBKDF_RESP
822    {
823        return Err(anyhow::anyhow!("pbkdf response not received"));
824    }
825
826    let iterations = pbkdf_response
827        .tlv
828        .get_int(&[4, 1])
829        .context("pbkdf_response - iterations missing")?;
830    let salt = pbkdf_response
831        .tlv
832        .get_octet_string(&[4, 2])
833        .context("pbkdf_response - salt missing")?;
834    let p_session = pbkdf_response
835        .tlv
836        .get_int(&[3])
837        .context("pbkdf_response - session missing")?;
838
839    // send pake1
840    let engine = spake2p::Engine::new()?;
841    let mut ctx = engine.start(&pin_to_passcode(pin)?, salt, iterations as u32)?;
842    log::debug!("send pake1 request");
843    let pake1_protocol_message = messages::pake1(exchange, ctx.x.as_bytes(), -1)?;
844    retrctx.send(&pake1_protocol_message).await?;
845
846    // receive pake2
847    let pake2 = retrctx.get_next_message().await?;
848    if pake2.protocol_header.protocol_id
849        != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
850        || pake2.protocol_header.opcode != messages::ProtocolMessageHeader::OPCODE_PASE_PAKE2
851    {
852        return Err(anyhow::anyhow!("pake2 not received"));
853    }
854    let pake2_pb = pake2
855        .tlv
856        .get_octet_string(&[1])
857        .context("pake2 pb tlv missing")?;
858    ctx.y = p256::EncodedPoint::from_bytes(pake2_pb)?;
859
860    let pake2_cb = pake2
861        .tlv
862        .get_octet_string(&[2])
863        .context("pake2 cb tlv missing")?;
864
865    // send pake3
866    let mut hash_seed = "CHIP PAKE V1 Commissioning".as_bytes().to_vec();
867    hash_seed.extend_from_slice(&pbkdf_req_protocol_message[6..]);
868    hash_seed.extend_from_slice(&pbkdf_response.payload);
869    engine.finish(&mut ctx, &hash_seed, pake2_cb)?;
870    let pake3_protocol_message = messages::pake3(
871        exchange,
872        &ctx.ca.context("ca value not present in context")?,
873        -1,
874    )?;
875    log::debug!("send pake3 request");
876    retrctx.send(&pake3_protocol_message).await?;
877
878    let pake3_resp = retrctx.get_next_message().await?;
879    match &pake3_resp.status_report_info {
880        Some(s) => {
881            if !s.is_ok() {
882                return Err(anyhow::anyhow!("pake3 resp not ok), got {:?}", pake3_resp));
883            }
884        }
885        None => {
886            return Err(anyhow::anyhow!(
887                "expecting status report (pake3 resp), got {:?}",
888                pake3_resp
889            ))
890        }
891    }
892
893    session.set_encrypt_key(&ctx.encrypt_key.context("encrypt key missing")?);
894    session.set_decrypt_key(&ctx.decrypt_key.context("decrypt key missing")?);
895    session.session_id = p_session as u16;
896    log::debug!("auth_spake ok; session: {}", session.session_id);
897    Ok(session)
898}
899
900pub(crate) async fn auth_sigma(
901    connection: &dyn ConnectionTrait,
902    fabric: &fabric::Fabric,
903    cm: &dyn certmanager::CertManager,
904    node_id: u64,
905    controller_id: u64,
906) -> Result<(session::Session, Option<sigma::ResumptionRecord>)> {
907    log::debug!("auth_sigma");
908    let exchange = rand::random();
909    let session = session::Session::new();
910    let mut retrctx = retransmit::RetrContext::new(connection, &session);
911    retrctx.subscribe_exchange(exchange);
912    let mut ctx = sigma::SigmaContext::new(node_id);
913    let ca_pubkey = cm.get_ca_key()?.public_key().to_sec1_bytes();
914    sigma::sigma1(fabric, &mut ctx, &ca_pubkey)?;
915    let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
916
917    log::debug!("send sigma1 {}", exchange);
918    retrctx.send(&s1).await?;
919
920    // receive sigma2
921    log::debug!("receive sigma2 {}", exchange);
922    let sigma2 = retrctx.get_next_message().await?;
923    log::debug!("sigma2 received {:?}", sigma2);
924    if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
925        && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
926    {
927        let sri = sigma2.status_report_info.context("status report info missing")?;
928        if sri.is_busy() {
929            return Err(anyhow::Error::new(SigmaBusy { wait_ms: sri.minimum_wait_time_ms() }));
930        }
931        return Err(anyhow::anyhow!("sigma2 not received, status: {}", sri));
932    }
933    ctx.sigma2_payload = sigma2.payload;
934    ctx.responder_session = sigma2
935        .tlv
936        .get_int(&[2])
937        .context("responder session tlv missing in sigma2")? as u16;
938    ctx.responder_public = sigma2
939        .tlv
940        .get_octet_string(&[3])
941        .context("responder public tlv missing in sigma2")?
942        .to_vec();
943
944    log::debug!("verify sigma2 {}", exchange);
945    let resumption_id =
946        sigma::verify_sigma2(fabric, &ctx, &ca_pubkey).context("sigma2 verification failed")?;
947
948    let controller_private = cm.get_user_key(controller_id)?;
949    let controller_x509 = cm.get_user_cert(controller_id)?;
950    let controller_matter_cert =
951        cert_matter::convert_x509_bytes_to_matter(&controller_x509, &ca_pubkey)?;
952
953    // send sigma3
954    log::debug!("send sigma3 {} with piggyback ack for {}", exchange, sigma2.message_header.message_counter);
955    sigma::sigma3(
956        fabric,
957        &mut ctx,
958        &controller_private.to_sec1_der()?,
959        &controller_matter_cert,
960    )?;
961    let sigma3 = messages::sigma3(exchange, &ctx.sigma3_payload, sigma2.message_header.message_counter)?;
962    retrctx.send(&sigma3).await?;
963
964    log::debug!("receive result {}", exchange);
965    let status = retrctx.get_next_message().await?;
966    if !status
967        .status_report_info
968        .as_ref()
969        .context("sigma3 status resp not received")?
970        .is_ok()
971    {
972        return Err(anyhow::anyhow!(format!(
973            "response to sigma3 does not contain status ok {:?}",
974            status
975        )));
976    }
977
978    //session keys
979    let mut th = ctx.sigma1_payload.clone();
980    th.extend_from_slice(&ctx.sigma2_payload);
981
982    let mut transcript = th;
983    transcript.extend_from_slice(&ctx.sigma3_payload);
984    let transcript_hash = cryptoutil::sha256(&transcript);
985    let mut salt = fabric.signed_ipk()?;
986    salt.extend_from_slice(&transcript_hash);
987    let shared = ctx.shared.context("shared secret not in context")?;
988    let shared_bytes: [u8; 32] = shared.raw_secret_bytes().as_slice()
989        .try_into()
990        .map_err(|_| anyhow::anyhow!("shared secret wrong length"))?;
991    let keypack = cryptoutil::hkdf_sha256(
992        &salt,
993        &shared_bytes,
994        "SessionKeys".as_bytes(),
995        16 * 3,
996    )?;
997    let mut ses = session::Session::new();
998    ses.session_id = ctx.responder_session;
999    ses.my_session_id = ctx.session_id;
1000    ses.set_decrypt_key(&keypack[16..32]);
1001    ses.set_encrypt_key(&keypack[..16]);
1002
1003    let mut local_node = Vec::new();
1004    local_node.write_u64::<LittleEndian>(controller_id)?;
1005    ses.local_node = Some(local_node);
1006
1007    let mut remote_node = Vec::new();
1008    remote_node.write_u64::<LittleEndian>(node_id)?;
1009    ses.remote_node = Some(remote_node);
1010
1011    let resumption = resumption_id
1012        .map(|id| sigma::ResumptionRecord { resumption_id: id, shared_secret: shared_bytes });
1013
1014    if resumption.is_none() {
1015        log::debug!("auth_sigma: responder did not include a NewResumptionID - resumption unavailable for node {}", node_id);
1016    }
1017
1018    Ok((ses, resumption))
1019}
1020
#[cfg(test)]
mod tests {
    // Loopback tests: each test drives a `Connection` built over `MockConn`
    // and plays the device side itself through `MockDevice`.
    use super::*;
    use crate::messages::ProtocolMessageHeader;
    use crate::tlv;
    use std::time::Duration;

    // Loopback transport: the test acts as the device on the other end.
    struct MockConn {
        // Frames the device queued for the controller (fed by `MockDevice::tx`).
        inbound: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
        // Frames the controller sent (drained through `MockDevice::rx`).
        outbound: mpsc::UnboundedSender<Vec<u8>>,
        // Returned verbatim by `is_reliable`.
        reliable: bool,
        // Returned by `mrp_params`, overwritten by `set_mrp_params`.
        mrp: std::sync::Mutex<crate::mrp::MrpParameters>,
    }

    #[async_trait::async_trait]
    impl ConnectionTrait for MockConn {
        // Forwards a controller frame to the device; errors once the device side is dropped.
        async fn send(&self, data: &[u8]) -> Result<()> {
            self.outbound
                .send(data.to_vec())
                .map_err(|_| anyhow::anyhow!("mock closed"))
        }
        // Waits up to `timeout` for the next device frame. A closed channel and
        // an elapsed timeout are both reported as errors.
        async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
            let mut rx = self.inbound.lock().await;
            match tokio::time::timeout(timeout, rx.recv()).await {
                Ok(Some(d)) => Ok(d),
                Ok(None) => Err(anyhow::anyhow!("mock closed")),
                Err(_) => Err(anyhow::anyhow!("timeout")),
            }
        }
        fn is_reliable(&self) -> bool {
            self.reliable
        }
        fn mrp_params(&self) -> crate::mrp::MrpParameters {
            *self.mrp.lock().unwrap()
        }
        fn set_mrp_params(&self, params: crate::mrp::MrpParameters) {
            *self.mrp.lock().unwrap() = params;
        }
    }

    // Device side of the loopback pair created by `mock_pair_with`.
    struct MockDevice {
        // Frames sent by the controller (other end of `MockConn::outbound`).
        rx: mpsc::UnboundedReceiver<Vec<u8>>,
        // Delivers frames to the controller (other end of `MockConn::inbound`).
        tx: mpsc::Sender<Vec<u8>>,
        // Session used to encode device-originated frames.
        session: session::Session,
    }

    impl MockDevice {
        // Next controller frame, decoded. Panics if nothing arrives within 2 s
        // or the channel is closed.
        async fn recv(&mut self) -> Message {
            let data = tokio::time::timeout(Duration::from_secs(2), self.rx.recv())
                .await
                .expect("timeout waiting for controller message")
                .expect("mock closed");
            Message::decode(&data).unwrap()
        }

        // Asserts the next controller frame is an IM StatusResponse:
        // - exchange flags exactly FLAG_RELIABILITY | `want_flags`;
        // - ack counter `want_ack`;
        // - status (tag 0) equal to 0.
        async fn expect_status_response(&mut self, want_flags: u8, want_ack: u32) {
            let msg = self.recv().await;
            assert_eq!(
                msg.protocol_header.protocol_id,
                ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
            );
            assert_eq!(
                msg.protocol_header.opcode,
                ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
            );
            assert_eq!(
                msg.protocol_header.exchange_flags,
                ProtocolMessageHeader::FLAG_RELIABILITY | want_flags
            );
            assert_eq!(msg.protocol_header.ack_counter, want_ack);
            assert_eq!(msg.tlv.get_int(&[0]), Some(0));
        }

        // Asserts the controller sends nothing for 200 ms. The recv must time
        // out; a closed channel also fails this assertion.
        async fn expect_silence(&mut self) {
            assert!(
                tokio::time::timeout(Duration::from_millis(200), self.rx.recv())
                    .await
                    .is_err(),
                "unexpected message from controller"
            );
        }

        // Encodes `payload` with the device session and delivers it to the
        // controller. Returns the message counter used, so tests can check
        // the matching ack.
        async fn send(&self, payload: &[u8]) -> u32 {
            let encoded = self.session.encode_message(payload).unwrap();
            let (header, _) = messages::MessageHeader::decode(&encoded).unwrap();
            self.tx.send(encoded).await.unwrap();
            header.message_counter
        }

        // Like `recv` with a caller-chosen timeout. Returns None on timeout or
        // when the channel is closed instead of panicking.
        async fn recv_within(&mut self, d: Duration) -> Option<Message> {
            match tokio::time::timeout(d, self.rx.recv()).await {
                Ok(Some(data)) => Some(Message::decode(&data).unwrap()),
                _ => None,
            }
        }
    }

    // Connection/device pair over a reliable mock transport with default MRP parameters.
    fn mock_pair() -> (Connection, MockDevice) {
        mock_pair_with(true, Default::default())
    }

    // Connection/device pair over an unreliable mock transport with the given
    // MRP parameters (used by the retransmission tests below).
    fn mock_pair_unreliable(mrp: crate::mrp::MrpParameters) -> (Connection, MockDevice) {
        mock_pair_with(false, mrp)
    }

    // Cross-wires two channels into a MockConn / MockDevice pair. Both sides
    // start from a fresh `session::Session::new()`.
    fn mock_pair_with(reliable: bool, mrp: crate::mrp::MrpParameters) -> (Connection, MockDevice) {
        let (to_ctrl_tx, to_ctrl_rx) = mpsc::channel(32);
        let (to_dev_tx, to_dev_rx) = mpsc::unbounded_channel();
        let mock = Arc::new(MockConn {
            inbound: tokio::sync::Mutex::new(to_ctrl_rx),
            outbound: to_dev_tx,
            reliable,
            mrp: std::sync::Mutex::new(mrp),
        });
        let conn = Connection::from_parts(mock, session::Session::new());
        let device = MockDevice {
            rx: to_dev_rx,
            tx: to_ctrl_tx,
            session: session::Session::new(),
        };
        (conn, device)
    }

    // Builds a ReportData payload on `exchange` with the given exchange `flags`.
    // Each `(endpoint, value)` becomes one attribute report for cluster 6,
    // attribute 0, data version 0, carrying a boolean value. `sub_id` is
    // written as tag 0 when present; `more` and `suppress` are written as
    // tags 3 and 4 only when true.
    fn report_data(
        exchange: u16,
        flags: u8,
        sub_id: Option<u32>,
        values: &[(u16, bool)],
        more: bool,
        suppress: bool,
    ) -> Vec<u8> {
        let b = ProtocolMessageHeader {
            exchange_flags: flags,
            opcode: ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA,
            exchange_id: exchange,
            protocol_id: ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
            ack_counter: 0,
        }
        .encode()
        .unwrap();
        let mut t = tlv::TlvBuffer::from_vec(b);
        t.write_anon_struct().unwrap();
        if let Some(id) = sub_id {
            t.write_uint32(0, id).unwrap();
        }
        t.write_array(1).unwrap();
        for (endpoint, value) in values {
            t.write_anon_struct().unwrap();
            t.write_struct(1).unwrap();
            t.write_uint32(0, 0).unwrap();
            // path: endpoint (tag 2), cluster 6 (tag 3), attribute 0 (tag 4)
            t.write_list(1).unwrap();
            t.write_uint16(2, *endpoint).unwrap();
            t.write_uint32(3, 6).unwrap();
            t.write_uint32(4, 0).unwrap();
            t.write_struct_end().unwrap();
            t.write_bool(2, *value).unwrap();
            t.write_struct_end().unwrap();
            t.write_struct_end().unwrap();
        }
        t.write_struct_end().unwrap();
        if more {
            t.write_bool(3, true).unwrap();
        }
        if suppress {
            t.write_bool(4, true).unwrap();
        }
        t.write_struct_end().unwrap();
        t.data
    }

    // Builds a SubscribeResponse payload with no exchange flags: subscription id
    // as tag 0, max interval as tag 2.
    fn subscribe_response(exchange: u16, sub_id: u32, max_interval: u16) -> Vec<u8> {
        let b = ProtocolMessageHeader {
            exchange_flags: 0,
            opcode: ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_RESP,
            exchange_id: exchange,
            protocol_id: ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
            ack_counter: 0,
        }
        .encode()
        .unwrap();
        let mut t = tlv::TlvBuffer::from_vec(b);
        t.write_anon_struct().unwrap();
        t.write_uint32(0, sub_id).unwrap();
        t.write_uint16(2, max_interval).unwrap();
        t.write_struct_end().unwrap();
        t.data
    }

    // Device reply on an exchange the controller opened.
    const FLAGS_RESPONDER: u8 = 0;
    // Device-initiated (unsolicited) message.
    const FLAGS_DEVICE_INITIATED: u8 = ProtocolMessageHeader::FLAG_INITIATOR;
    // Flags the controller's StatusResponse carries on a controller-initiated exchange.
    const ACK_AND_INITIATOR: u8 =
        ProtocolMessageHeader::FLAG_INITIATOR | ProtocolMessageHeader::FLAG_ACK;

    /// A single ReportData with SuppressResponse set completes the read, and
    /// the controller sends no StatusResponse.
    #[tokio::test]
    async fn test_read_request2_single_chunk() {
        let (conn, mut device) = mock_pair();
        let task = tokio::spawn(async move {
            let req = device.recv().await;
            assert_eq!(
                req.protocol_header.opcode,
                ProtocolMessageHeader::INTERACTION_OPCODE_READ_REQ
            );
            let exchange = req.protocol_header.exchange_id;
            device
                .send(&report_data(exchange, FLAGS_RESPONDER, None, &[(1, true)], false, true))
                .await;
            device.expect_silence().await;
        });
        let val = conn.read_request2(1, 6, 0).await.unwrap();
        assert_eq!(val, TlvItemValue::Bool(true));
        task.await.unwrap();
    }

    /// Chunked read behaves as follows:
    /// - the first chunk (MoreChunkedMessages) is acked with a StatusResponse;
    /// - the final suppressed chunk gets no response;
    /// - the value from the first chunk is returned.
    #[tokio::test]
    async fn test_read_request2_chunked() {
        let (conn, mut device) = mock_pair();
        let task = tokio::spawn(async move {
            let req = device.recv().await;
            let exchange = req.protocol_header.exchange_id;
            let counter = device
                .send(&report_data(exchange, FLAGS_RESPONDER, None, &[(1, true)], true, false))
                .await;
            device.expect_status_response(ACK_AND_INITIATOR, counter).await;
            device
                .send(&report_data(exchange, FLAGS_RESPONDER, None, &[(2, false)], false, true))
                .await;
            device.expect_silence().await;
        });
        let val = conn.read_request2(1, 6, 0).await.unwrap();
        assert_eq!(val, TlvItemValue::Bool(true));
        task.await.unwrap();
    }

    /// Subscribe with a two-chunk priming report, then receive a device-initiated
    /// update through the returned `Subscription`.
    #[tokio::test]
    async fn test_subscribe_and_updates() {
        let (conn, mut device) = mock_pair();
        let task = tokio::spawn(async move {
            let req = device.recv().await;
            assert_eq!(
                req.protocol_header.opcode,
                ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_REQ
            );
            let exchange = req.protocol_header.exchange_id;
            let counter = device
                .send(&report_data(exchange, FLAGS_RESPONDER, Some(7), &[(1, true)], true, false))
                .await;
            device.expect_status_response(ACK_AND_INITIATOR, counter).await;
            let counter = device
                .send(&report_data(exchange, FLAGS_RESPONDER, Some(7), &[(2, false)], false, false))
                .await;
            device.expect_status_response(ACK_AND_INITIATOR, counter).await;
            device.send(&subscribe_response(exchange, 7, 60)).await;

            // device-initiated update on a fresh exchange
            let counter = device
                .send(&report_data(0x4001, FLAGS_DEVICE_INITIATED, Some(7), &[(1, false)], false, false))
                .await;
            device
                .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
                .await;
            // Returned so the device (and its channel ends) lives until the task is awaited.
            device
        });

        let mut sub = conn.subscribe_attrs(Some(1), Some(6), Some(0), false).await.unwrap();
        assert_eq!(sub.subscription_id, 7);
        assert_eq!(sub.max_interval, 60);
        assert_eq!(sub.priming_attribute_reports.len(), 2);
        assert_eq!(sub.priming_attribute_reports[0].path.endpoint, Some(1));
        assert_eq!(sub.priming_attribute_reports[1].path.endpoint, Some(2));

        let update = sub.next().await.unwrap();
        assert_eq!(update.subscription_id, 7);
        assert_eq!(update.attribute_reports.len(), 1);
        assert_eq!(
            update.attribute_reports[0].data,
            im::AttributeData::Value(TlvItemValue::Bool(false))
        );
        task.await.unwrap();
    }

    /// A device-initiated report split into two chunks is acked chunk by chunk
    /// and delivered as one update holding both chunks' attribute reports.
    #[tokio::test]
    async fn test_chunked_unsolicited_report() {
        let (conn, mut device) = mock_pair();
        let task = tokio::spawn(async move {
            let req = device.recv().await;
            let exchange = req.protocol_header.exchange_id;
            let counter = device
                .send(&report_data(exchange, FLAGS_RESPONDER, Some(9), &[(1, true)], false, false))
                .await;
            device.expect_status_response(ACK_AND_INITIATOR, counter).await;
            device.send(&subscribe_response(exchange, 9, 60)).await;

            // chunked device-initiated update
            let counter = device
                .send(&report_data(0x4002, FLAGS_DEVICE_INITIATED, Some(9), &[(1, false)], true, false))
                .await;
            device
                .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
                .await;
            let counter = device
                .send(&report_data(0x4002, FLAGS_DEVICE_INITIATED, Some(9), &[(2, true)], false, false))
                .await;
            device
                .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
                .await;
        });

        let mut sub = conn.subscribe_attrs(Some(1), Some(6), Some(0), false).await.unwrap();
        let update = sub.next().await.unwrap();
        assert_eq!(update.attribute_reports.len(), 2);
        assert_eq!(update.attribute_reports[0].path.endpoint, Some(1));
        assert_eq!(update.attribute_reports[1].path.endpoint, Some(2));
        task.await.unwrap();
    }

    /// Reports for an unknown subscription id are handed to `recv_event`. They
    /// are status-acked only while automatic status responses are enabled.
    #[tokio::test]
    async fn test_unregistered_subscription_id() {
        let (conn, mut device) = mock_pair();

        let counter = device
            .send(&report_data(0x4003, FLAGS_DEVICE_INITIATED, Some(99), &[(1, true)], false, false))
            .await;
        device
            .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
            .await;
        let raw = conn.recv_event().await.unwrap();
        assert_eq!(
            raw.protocol_header.opcode,
            ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA
        );

        conn.set_auto_status_response(false);
        device
            .send(&report_data(0x4004, FLAGS_DEVICE_INITIATED, Some(99), &[(1, true)], false, false))
            .await;
        device.expect_silence().await;
        let raw = conn.recv_event().await.unwrap();
        assert_eq!(raw.protocol_header.exchange_id, 0x4004);
    }

    /// Re-sending the exact same encoded frame (same message counter) produces
    /// neither a StatusResponse nor a second event.
    #[tokio::test]
    async fn test_duplicate_message_dropped() {
        let (conn, mut device) = mock_pair();

        let payload =
            report_data(0x4005, FLAGS_DEVICE_INITIATED, Some(99), &[(1, true)], false, false);
        let encoded = device.session.encode_message(&payload).unwrap();
        let (header, _) = messages::MessageHeader::decode(&encoded).unwrap();
        device.tx.send(encoded.clone()).await.unwrap();
        device
            .expect_status_response(ProtocolMessageHeader::FLAG_ACK, header.message_counter)
            .await;
        let raw = conn.recv_event().await.unwrap();
        assert_eq!(raw.protocol_header.exchange_id, 0x4005);

        // replayed frame must be dropped: no status response, no event
        device.tx.send(encoded).await.unwrap();
        device.expect_silence().await;
        assert!(conn.try_recv_event().is_none());
    }

    /// A device-initiated message that reuses the pending exchange id is acked
    /// but must not complete the controller's outstanding read.
    #[tokio::test]
    async fn test_initiator_flag_not_misrouted() {
        let (conn, mut device) = mock_pair();
        let task = tokio::spawn(async move {
            let req = device.recv().await;
            let exchange = req.protocol_header.exchange_id;
            // device-initiated report colliding with the pending exchange id
            // must not resolve the pending read request
            let counter = device
                .send(&report_data(exchange, FLAGS_DEVICE_INITIATED, None, &[(5, false)], false, false))
                .await;
            device
                .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
                .await;
            device
                .send(&report_data(exchange, FLAGS_RESPONDER, None, &[(1, true)], false, true))
                .await;
        });
        let val = conn.read_request2(1, 6, 0).await.unwrap();
        assert_eq!(val, TlvItemValue::Bool(true));
        task.await.unwrap();
    }

    /// Unreliable transport, paused clock. The unanswered request is sent
    /// `MRP_MAX_TRANSMISSIONS` times with one message counter and backoff-shaped
    /// gaps, then the read fails and nothing more is sent.
    #[tokio::test(start_paused = true)]
    async fn test_retransmit_schedule_and_give_up() {
        let mrp = crate::mrp::MrpParameters::from_txt_ms(Some(5000), None, None);
        let (conn, mut device) = mock_pair_unreliable(mrp);
        let req = tokio::spawn(async move { conn.read_request2(1, 6, 0).await });

        let mut times = Vec::new();
        let mut counters = Vec::new();
        for i in 0..crate::mrp::MRP_MAX_TRANSMISSIONS {
            let msg = device
                .recv_within(Duration::from_secs(30))
                .await
                .unwrap_or_else(|| panic!("missing transmission {}", i));
            times.push(tokio::time::Instant::now());
            counters.push(msg.message_header.message_counter);
        }
        assert!(counters.iter().all(|c| *c == counters[0]));

        // gap n follows backoff: 5s * 1.1 * 1.6^max(0, n-1) plus up to 25% jitter
        for (n, w) in times.windows(2).enumerate() {
            let gap = (w[1] - w[0]).as_secs_f64();
            let lower = 5.0 * 1.1 * 1.6f64.powi(n.saturating_sub(1) as i32);
            let upper = lower * 1.25;
            assert!(
                gap >= lower - 0.01 && gap <= upper + 0.1,
                "gap {} = {} not in [{}, {}]",
                n, gap, lower, upper
            );
        }

        // after the final backoff period the exchange is dropped and the request fails
        let res = req.await.unwrap();
        assert!(res.is_err(), "request should fail after give-up");
        assert!(
            device.recv_within(Duration::from_secs(120)).await.is_none(),
            "no transmissions expected after give-up"
        );
    }

    /// After the device acks the request's message counter, no further
    /// retransmissions are sent.
    #[tokio::test(start_paused = true)]
    async fn test_retransmit_stops_after_ack() {
        let (conn, mut device) = mock_pair_unreliable(Default::default());
        let _req = tokio::spawn(async move { conn.read_request2(1, 6, 0).await });

        let msg = device.recv_within(Duration::from_secs(5)).await.expect("request");
        let retr = device.recv_within(Duration::from_secs(5)).await.expect("retransmit");
        assert_eq!(
            msg.message_header.message_counter,
            retr.message_header.message_counter
        );

        device
            .send(&messages::ack(
                msg.protocol_header.exchange_id,
                msg.message_header.message_counter as i64,
            ).unwrap())
            .await;
        assert!(
            device.recv_within(Duration::from_secs(60)).await.is_none(),
            "no retransmissions expected after ack"
        );
    }

    /// The retransmission still fires while a steady stream of unrelated
    /// inbound messages keeps the read loop from ever timing out.
    #[tokio::test(start_paused = true)]
    async fn test_retransmit_not_starved_by_inbound_traffic() {
        let (conn, mut device) = mock_pair_unreliable(Default::default());
        let _req = tokio::spawn(async move { conn.read_request2(1, 6, 0).await });

        let first = device.recv_within(Duration::from_secs(2)).await.expect("request");
        let counter = first.message_header.message_counter;

        // keep the read loop busy with inbound messages so it never hits a
        // receive timeout; the retransmit (due at ~550-690ms) must still fire
        let mut seen_retransmit = false;
        for _ in 0..10 {
            device.send(&messages::ack(0x7777, 999_999).unwrap()).await;
            tokio::time::sleep(Duration::from_millis(100)).await;
            while let Ok(data) = device.rx.try_recv() {
                let m = Message::decode(&data).unwrap();
                if m.message_header.message_counter == counter {
                    seen_retransmit = true;
                }
            }
        }
        assert!(seen_retransmit, "retransmit starved by continuous inbound traffic");
    }
}