1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use crate::{
4 active_connection::{ActiveConnection, Exchange},
5 cert_matter, certmanager, commission, fabric, im,
6 messages::{self, Message},
7 retransmit, session, sigma, spake2p,
8 tlv::TlvItemValue,
9 transport::{self, ConnectionTrait},
10 util::cryptoutil,
11};
12use anyhow::{Context, Result};
13use tokio::sync::mpsc;
14use byteorder::{LittleEndian, WriteBytesExt};
15
16pub struct Controller {
17 certmanager: Arc<dyn certmanager::CertManager>,
18 #[allow(dead_code)]
19 transport: Arc<transport::Transport>,
20 fabric: fabric::Fabric,
21 resumption: Arc<tokio::sync::Mutex<HashMap<u64, sigma::ResumptionRecord>>>,
23}
24
/// A connection to a node: a transport connection paired with a session,
/// wrapped in an [`ActiveConnection`] that all request methods go through.
pub struct Connection {
    active: ActiveConnection,
}
/// CA identifier passed to `fabric::Fabric::new` when `Controller::new` builds the fabric.
const CA_ID: u64 = 1;
32
/// Error produced by `auth_sigma` when the CASE responder answers Sigma1 with a
/// BUSY status report instead of Sigma2.
///
/// `wait_ms` is the minimum wait the responder asked for, when it supplied one.
/// Callers detect this case with `anyhow::Error::downcast_ref::<SigmaBusy>()`.
#[derive(Debug, Clone, Copy)]
pub struct SigmaBusy {
    pub wait_ms: Option<u32>,
}

impl std::fmt::Display for SigmaBusy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("responder BUSY")?;
        // Only mention the wait when the responder actually provided one.
        if let Some(ms) = self.wait_ms {
            write!(f, " (min wait {} ms)", ms)?;
        }
        Ok(())
    }
}

impl std::error::Error for SigmaBusy {}
46
47impl Controller {
48 pub fn new(
49 certmanager: &Arc<dyn certmanager::CertManager>,
50 transport: &Arc<transport::Transport>,
51 fabric_id: u64,
52 ) -> Result<Arc<Self>> {
53 let fabric = fabric::Fabric::new(
54 fabric_id,
55 CA_ID,
56 &certmanager.get_ca_public_key()?,
57 &certmanager.get_ipk_epoch_key(),
58 );
59 Ok(Arc::new(Self {
60 certmanager: certmanager.clone(),
61 transport: transport.clone(),
62 fabric,
63 resumption: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
64 }))
65 }
66
67 pub async fn commission(
74 &self,
75 connection: &Arc<dyn ConnectionTrait>,
76 pin: u32,
77 node_id: u64,
78 controller_id: u64,
79 ) -> Result<Connection> {
80 let mut session = auth_spake(connection.as_ref(), pin).await?;
81 let session = commission::commission(
82 connection.as_ref(),
83 &mut session,
84 &self.fabric,
85 self.certmanager.as_ref(),
86 node_id,
87 controller_id,
88 )
89 .await?;
90 Ok(Connection {
91 active: ActiveConnection::new(connection.clone(), session),
92 })
93 }
94
95 pub async fn auth_sigma(
97 &self,
98 connection: &Arc<dyn ConnectionTrait>,
99 node_id: u64,
100 controller_id: u64,
101 ) -> Result<Connection> {
102 let (session, resumption) = auth_sigma(
103 connection.as_ref(),
104 &self.fabric,
105 self.certmanager.as_ref(),
106 node_id,
107 controller_id,
108 )
109 .await?;
110 if let Some(record) = resumption {
111 self.resumption.lock().await.insert(node_id, record);
112 }
113 Ok(Connection {
114 active: ActiveConnection::new(connection.clone(), session),
115 })
116 }
117
118 pub async fn auth_sigma_with_busy_retry(
122 &self,
123 connection: &Arc<dyn ConnectionTrait>,
124 node_id: u64,
125 controller_id: u64,
126 ) -> Result<session::Session> {
127 if let Some(ses) = self.try_auth_sigma_resume(connection, node_id, controller_id).await? {
128 return Ok(ses);
129 }
130
131 const MAX_BUSY_RETRIES: u32 = 5;
132 const DEFAULT_BUSY_WAIT: Duration = Duration::from_millis(3000);
133 const MAX_BUSY_WAIT: Duration = Duration::from_secs(60);
134
135 let mut busy_retries = 0u32;
136 loop {
137 match auth_sigma(connection.as_ref(), &self.fabric, self.certmanager.as_ref(), node_id, controller_id).await {
138 Ok((ses, resumption)) => {
139 if let Some(record) = resumption {
140 self.resumption.lock().await.insert(node_id, record);
141 }
142 return Ok(ses);
143 }
144 Err(e) => {
145 if let Some(busy) = e.downcast_ref::<SigmaBusy>() {
146 if busy_retries < MAX_BUSY_RETRIES {
147 let wait = busy.wait_ms
148 .map(|ms| Duration::from_millis(ms.into()))
149 .unwrap_or(DEFAULT_BUSY_WAIT)
150 .min(MAX_BUSY_WAIT);
151 log::info!(
152 "CASE responder BUSY, waiting {:?} before retry ({}/{})",
153 wait, busy_retries + 1, MAX_BUSY_RETRIES
154 );
155 tokio::time::sleep(wait).await;
156 busy_retries += 1;
157 continue;
158 }
159 return Err(e).context(format!(
160 "still BUSY after {} retries", MAX_BUSY_RETRIES
161 ));
162 }
163 return Err(e);
164 }
165 }
166 }
167 }
168
169 async fn try_auth_sigma_resume(
170 &self,
171 connection: &Arc<dyn ConnectionTrait>,
172 node_id: u64,
173 controller_id: u64,
174 ) -> Result<Option<session::Session>> {
175 let record = {
176 let map = self.resumption.lock().await;
177 map.get(&node_id).cloned()
178 };
179 let record = match record {
180 Some(r) => r,
181 None => return Ok(None),
182 };
183
184 let exchange: u16 = rand::random();
185 let session = session::Session::new();
186 let mut retrctx = retransmit::RetrContext::new(connection.as_ref(), &session);
187 retrctx.subscribe_exchange(exchange);
188
189 let mut ctx = sigma::SigmaContext::new(node_id);
190 let ca_pubkey = self.certmanager.get_ca_key()?.public_key().to_sec1_bytes();
191 sigma::sigma1_resume(&self.fabric, &mut ctx, &ca_pubkey, &record)?;
192 let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
193
194 log::debug!("CASE resume: send Sigma1Resume exchange:{}", exchange);
195 retrctx.send(&s1).await?;
196
197 let sigma2 = retrctx.get_next_message().await?;
198
199 if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
202 && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
203 {
204 log::debug!(
205 "CASE resume: responder rejected with status report, falling back to full SIGMA (exchange:{} {:?})",
206 exchange,
207 sigma2.status_report_info
208 );
209 return Ok(None);
210 }
211
212 if !sigma::is_sigma2_resume(&sigma2.payload) {
213 log::debug!("CASE resume: responder sent full Sigma2, falling back");
216 self.resumption.lock().await.remove(&node_id);
217 return Ok(None);
218 }
219
220 let parsed = match sigma::parse_sigma2_resume(&sigma2.payload) {
221 Ok(p) => p,
222 Err(e) => {
223 log::debug!("CASE resume: malformed Sigma2Resume ({:?}), falling back to full SIGMA", e);
224 self.resumption.lock().await.remove(&node_id);
225 return Ok(None);
226 }
227 };
228
229 if let Err(e) = sigma::verify_sigma2_resume_mic(
230 &record.shared_secret,
231 &ctx.initiator_random,
232 &parsed.new_resumption_id,
233 &parsed.sigma2_resume_mic,
234 ) {
235 log::debug!("CASE resume: MIC verification failed: {:?}, falling back to full SIGMA", e);
236 self.resumption.lock().await.remove(&node_id);
237 return Ok(None);
238 }
239
240 let sr = messages::status_report_success(exchange)?;
241 if let Err(e) = retrctx.send(&sr).await {
242 log::debug!("CASE resume: failed to send StatusReport ({:?}), falling back to full SIGMA", e);
243 self.resumption.lock().await.remove(&node_id);
244 return Ok(None);
245 }
246
247 let keypack = sigma::derive_resumed_session_keys(
248 &record.shared_secret,
249 &ctx.initiator_random,
250 &record.resumption_id,
251 )?;
252
253 let mut ses = session::Session::new();
254 ses.session_id = parsed.responder_session_id;
255 ses.my_session_id = ctx.session_id;
256 ses.set_decrypt_key(&keypack[16..32]);
257 ses.set_encrypt_key(&keypack[..16]);
258
259 let mut local_node = Vec::new();
260 local_node.write_u64::<LittleEndian>(controller_id)?;
261 ses.local_node = Some(local_node);
262
263 let mut remote_node = Vec::new();
264 remote_node.write_u64::<LittleEndian>(node_id)?;
265 ses.remote_node = Some(remote_node);
266
267 {
269 let mut map = self.resumption.lock().await;
270 if let Some(entry) = map.get_mut(&node_id) {
271 entry.resumption_id = parsed.new_resumption_id;
272 }
273 }
274
275 log::info!("CASE session resumed for node_id={}", node_id);
276 Ok(Some(ses))
277 }
278
/// Commissions a device over BLE, then completes commissioning over UDP.
///
/// 1. Scans (up to 30 s) for a peripheral matching `discriminator`
///    (`short_discriminator` selects the short form) and opens a BTP connection.
/// 2. Runs PASE with `pin` over BTP, then `commission::commission_ble_phase`,
///    which receives `network_creds`.
/// 3. Sleeps 5 s and closes the BTP connection.
/// 4. Up to `OPERATIONAL_ATTEMPTS` rounds: discovers the node's operational
///    addresses via mDNS and tries `commission::commissioning_complete_udp` at
///    each one. The first success is returned as a [`Connection`].
///
/// # Errors
/// Fails if any BLE step fails, or if no discovered address completes
/// commissioning within the attempt budget.
#[cfg(feature = "ble")]
pub async fn commission_ble(
    &self,
    discriminator: u16,
    short_discriminator: bool,
    pin: u32,
    node_id: u64,
    controller_id: u64,
    network_creds: commission::NetworkCreds,
    mdns: &std::sync::Arc<crate::mdns2::MdnsService>,
) -> Result<Connection> {
    use crate::btp::BtpConnection;

    // Number of discover-then-connect rounds for the operational (UDP) phase.
    const OPERATIONAL_ATTEMPTS: u32 = 5;

    let peripheral = crate::ble::find_by_discriminator(
        discriminator,
        short_discriminator,
        std::time::Duration::from_secs(30),
    )
    .await
    .context("BLE scan")?;
    log::debug!("BLE device found (discriminator {})", discriminator);
    let btp_conn = BtpConnection::connect(peripheral).await.context("BTP connect")?;

    let mut pase_session = auth_spake(btp_conn.as_ref(), pin).await.context("PASE over BLE")?;

    commission::commission_ble_phase(
        btp_conn.as_ref(),
        &mut pase_session,
        &self.fabric,
        self.certmanager.as_ref(),
        node_id,
        controller_id,
        &network_creds,
    )
    .await
    .context("BLE commissioning phase")?;

    // Keep the BLE link up briefly before closing it.
    // NOTE(review): presumably this gives the device time to finish the BLE-side
    // exchange / move to the operational network; confirm the required delay.
    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    drop(btp_conn);

    for attempt in 1..=OPERATIONAL_ATTEMPTS {
        let addresses = match self.discover_operational_addresses(node_id, mdns).await {
            Ok(a) => a,
            Err(e) => {
                log::debug!(
                    "mDNS discovery failed (attempt {}/{}): {:?}",
                    attempt, OPERATIONAL_ATTEMPTS, e
                );
                continue;
            }
        };
        for address in &addresses {
            log::debug!(
                "Trying to commission over UDP at {}... (attempt {}/{})",
                address, attempt, OPERATIONAL_ATTEMPTS
            );
            let udp_conn = self.transport.create_connection(&address).await;
            let outcome = commission::commissioning_complete_udp(
                udp_conn.as_ref(),
                self.certmanager.as_ref(),
                node_id,
                controller_id,
                &self.fabric,
            )
            .await;
            match outcome {
                Ok(ses) => {
                    return Ok(Connection {
                        active: ActiveConnection::new(udp_conn, ses),
                    });
                }
                Err(e) => {
                    log::debug!("Failed to commission over UDP at {}: {:?}", address, e);
                }
            }
        }
    }
    Err(anyhow::anyhow!("failed to commission device over UDP at any discovered address"))
}
363
/// Resolves `node_id`'s operational addresses on this fabric via mDNS.
///
/// The instance name is `<compressed fabric id>-<node id>`, both in upper-case
/// hex with the node id zero-padded to 16 digits. This waits up to 120 s for a
/// `_matter._tcp` record whose target matches that instance. Each advertised IP
/// is formatted with the advertised port (5540 when absent) via
/// `discover::addr_string`.
///
/// # Errors
/// Fails if the CA public key is unavailable, the compressed fabric ID cannot be
/// computed, or discovery times out.
#[cfg(feature = "ble")]
async fn discover_operational_addresses(
    &self,
    node_id: u64,
    mdns: &std::sync::Arc<crate::mdns2::MdnsService>,
) -> Result<Vec<String>> {
    use crate::discover;

    let ca_pubkey = self.certmanager.get_ca_public_key()?;
    let ipk_epoch_key = self.certmanager.get_ipk_epoch_key();
    // Throwaway fabric (CA id 0) used only to compute the compressed fabric ID.
    let scratch_fabric = fabric::Fabric::new(self.fabric.id, 0, &ca_pubkey, &ipk_epoch_key);
    let compressed = scratch_fabric.compressed().context("compressed fabric ID")?;

    let instance = format!("{}-{:016X}", hex::encode_upper(&compressed), node_id);
    let expected_target = format!("{}._matter._tcp.local.", instance);
    log::debug!("Discovering operational device via mDNS with target {}", expected_target);

    let (_, info) = discover::discover_one(
        mdns,
        "_matter._tcp.local",
        "_matter._tcp.local.",
        std::time::Duration::from_secs(120),
        move |target, _| target == expected_target,
    )
    .await
    .context(format!("operational mDNS timeout for {}", instance))?;
    log::debug!("Operational mDNS discovered device: {:?}", info);

    let port = info.port.unwrap_or(5540);
    let mut addresses: Vec<String> = Vec::new();
    for ip in info.ips.iter() {
        addresses.push(discover::addr_string(ip, port, info.scope_id));
    }
    log::info!("Device discovered at {}", addresses.join(", "));
    Ok(addresses)
}
397}
398
399impl Connection {
401 pub(crate) fn from_parts(conn: Arc<dyn ConnectionTrait>, session: session::Session) -> Self {
403 Self { active: ActiveConnection::new(conn, session) }
404 }
405
406 pub async fn read_request(
408 &self,
409 endpoint: u16,
410 cluster: u32,
411 attr: u32,
412 ) -> Result<Message> {
413 let exchange: u16 = rand::random();
414 let msg = messages::im_read_request(endpoint, cluster, attr, exchange)?;
415 self.active.request(exchange, &msg).await
416 }
417
418 pub async fn read_request2(
421 &self,
422 endpoint: u16,
423 cluster: u32,
424 attr: u32,
425 ) -> Result<TlvItemValue> {
426 let exchange: u16 = rand::random();
427 let msg = messages::im_read_request(endpoint, cluster, attr, exchange)?;
428 let mut ex = self.active.open_exchange(exchange);
429 ex.send(&msg).await?;
430 let report = self.collect_reports(&mut ex).await?;
431 let first = report
432 .attribute_reports
433 .into_iter()
434 .next()
435 .context("report data contains no attribute reports")?;
436 match first.data {
437 im::AttributeData::Value(v) => Ok(v),
438 im::AttributeData::Status { status, .. } => {
439 Err(anyhow::anyhow!("report data with status {}", status))
440 }
441 }
442 }
443
444 async fn collect_reports(&self, exchange: &mut Exchange<'_>) -> Result<im::ReportData> {
449 let mut merged: Option<im::ReportData> = None;
450 loop {
451 let msg = exchange.recv().await?;
452 if let Some(status) = &msg.status_report_info {
453 return Err(anyhow::anyhow!(
454 "status report while waiting for report data: {:?}",
455 status
456 ));
457 }
458 if msg.protocol_header.protocol_id
459 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
460 || msg.protocol_header.opcode
461 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA
462 {
463 return Err(anyhow::anyhow!(
464 "response is not expected report_data {:?}",
465 msg.protocol_header
466 ));
467 }
468 let report = im::ReportData::parse(&msg.tlv)?;
469 let more = report.more_chunks;
470 let respond = more || !report.suppress_response;
471 match merged.as_mut() {
472 Some(m) => m.merge(report),
473 None => merged = Some(report),
474 }
475 if respond {
476 let flags = messages::im_status_flags_for(msg.protocol_header.exchange_flags);
477 let resp = messages::im_status_response(
478 exchange.id,
479 flags,
480 msg.message_header.message_counter,
481 )?;
482 exchange.send(&resp).await?;
483 }
484 if !more {
485 return merged.context("no report data received");
486 }
487 }
488 }
489
490 pub async fn invoke_request(
492 &self,
493 endpoint: u16,
494 cluster: u32,
495 command: u32,
496 payload: &[u8],
497 ) -> Result<Message> {
498 let exchange: u16 = rand::random();
499 log::debug!(
500 "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
501 exchange,
502 endpoint,
503 cluster,
504 command
505 );
506 let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, false)?;
507 self.active.request(exchange, &msg).await
508 }
509
510 pub async fn invoke_request2(
512 &self,
513 endpoint: u16,
514 cluster: u32,
515 command: u32,
516 payload: &[u8],
517 ) -> Result<TlvItemValue> {
518 let res = self.invoke_request(endpoint, cluster, command, payload).await?;
519 let o = res.tlv.get(&[1, 0, 1, 1]).context("result not found")?;
520 Ok(o.clone())
521 }
522
523 pub async fn write_request(
524 &self,
525 endpoint: u16,
526 cluster: u32,
527 attr: u32,
528 payload: &[u8],
529 ) -> Result<()> {
530 let exchange: u16 = rand::random();
531 log::debug!(
532 "write_request exch:{} endpoint:{} cluster:{} attr:{}",
533 exchange,
534 endpoint,
535 cluster,
536 attr,
537 );
538
539 let msg = messages::im_write_request(endpoint, cluster, attr, exchange, payload)?;
540 let res = self.active.request(exchange, &msg).await?;
541 if res.status_report_info.is_some() {
542 return Err(anyhow::anyhow!(
543 "write_request failed with status {:?}",
544 res.status_report_info
545 ))
546 };
547 if res.protocol_header.protocol_id
548 == messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
549 && res.protocol_header.opcode
550 == messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
551 {
552 let stat = res
553 .tlv
554 .get_int(&[0])
555 .context("status not found in status response")?;
556 res.tlv.dump(1);
557 return Err(anyhow::anyhow!(
558 "response is not expected status_resp 0x{:x}",
559 stat
560 ))
561 };
562 if res.protocol_header.protocol_id
563 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
564 || res.protocol_header.opcode
565 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_WRITE_RESP
566 {
567 return Err(anyhow::anyhow!(
568 "response is not expected write_resp {:?}",
569 res.protocol_header
570 ))
571 };
572 let stat = res.tlv.get_int(&[0, 0, 1, 0]).context("status not found in write response")?;
573 if stat != 0 {
574 return Err(anyhow::anyhow!("write failed with status 0x{:x}", stat));
575 }
576 Ok(())
577 }
578
579 pub async fn subscribe_attrs(
588 &self,
589 endpoint: Option<u16>,
590 cluster: Option<u32>,
591 attr: Option<u32>,
592 keep_subscriptions: bool,
593 ) -> Result<Subscription> {
594 let exchange: u16 = rand::random();
595 log::debug!(
596 "subscribe_attrs exch:{} endpoint:{:?} cluster:{:?} attr:{:?} keep:{}",
597 exchange, endpoint, cluster, attr, keep_subscriptions
598 );
599 let msg = messages::im_subscribe_request_attr(endpoint, cluster, attr, exchange, keep_subscriptions)?;
600 self.subscribe_internal(exchange, &msg).await
601 }
602
603 pub async fn subscribe_events(
606 &self,
607 endpoint: Option<u16>,
608 cluster: Option<u32>,
609 event: Option<u32>,
610 keep_subscriptions: bool,
611 ) -> Result<Subscription> {
612 let exchange: u16 = rand::random();
613 log::debug!(
614 "subscribe_events exch:{} endpoint:{:?} cluster:{:?} event:{:?} keep:{}",
615 exchange, endpoint, cluster, event, keep_subscriptions
616 );
617 let msg = messages::im_subscribe_request_event(endpoint, cluster, event, exchange, keep_subscriptions)?;
618 self.subscribe_internal(exchange, &msg).await
619 }
620
621 async fn subscribe_internal(&self, exchange_id: u16, msg: &[u8]) -> Result<Subscription> {
622 let mut exchange = self.active.open_exchange(exchange_id);
623 exchange.send(msg).await?;
624 let priming = self.collect_reports(&mut exchange).await?;
625 let subscription_id = priming
626 .subscription_id
627 .context("priming report missing subscription id")?;
628
629 let rx = self.active.register_subscription(subscription_id);
632 let registry = self.active.subscriptions_handle();
633
634 let response = async {
635 let resp = exchange.recv().await?;
636 if resp.protocol_header.protocol_id
637 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
638 || resp.protocol_header.opcode
639 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_RESP
640 {
641 anyhow::bail!(
642 "response is not expected subscribe_resp {:?}",
643 resp.protocol_header
644 );
645 }
646 let sr = im::SubscribeResponse::parse(&resp.tlv)?;
647 if sr.subscription_id != subscription_id {
648 anyhow::bail!(
649 "subscribe response id {} does not match priming report id {}",
650 sr.subscription_id,
651 subscription_id
652 );
653 }
654 Ok(sr)
655 }
656 .await;
657
658 match response {
659 Ok(sr) => Ok(Subscription {
660 subscription_id,
661 max_interval: sr.max_interval,
662 priming_attribute_reports: priming.attribute_reports,
663 priming_event_reports: priming.event_reports,
664 rx,
665 registry,
666 }),
667 Err(e) => {
668 registry.lock().unwrap().remove(&subscription_id);
669 Err(e)
670 }
671 }
672 }
673
674 pub async fn im_unsubscribe_all(&self) -> Result<Message> {
677 let exchange: u16 = rand::random();
678 log::debug!("im_unsubscribe_all exch:{}", exchange);
679 let msg = messages::im_unsubscribe_all(exchange)?;
680 self.active.request(exchange, &msg).await
681 }
682
683 pub fn set_auto_status_response(&self, enabled: bool) {
687 self.active.set_auto_status_response(enabled);
688 }
689
690 pub async fn invoke_request_timed(
692 &self,
693 endpoint: u16,
694 cluster: u32,
695 command: u32,
696 payload: &[u8],
697 timeout: u16,
698 ) -> Result<Message> {
699 let exchange: u16 = rand::random();
700
701 let tr = messages::im_timed_request(exchange, timeout)?;
703 let result = self.active.request(exchange, &tr).await?;
704
705 if result.protocol_header.protocol_id
706 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
707 || result.protocol_header.opcode
708 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
709 {
710 return Err(anyhow::anyhow!(
711 "invoke_request_timed: unexpected response {:?}",
712 result
713 ));
714 }
715 let status = result
716 .tlv
717 .get_int(&[0])
718 .context("invoke_request_timed: status not found")?;
719 if status != 0 {
720 return Err(anyhow::anyhow!(
721 "invoke_request_timed: unexpected status {}",
722 status
723 ));
724 }
725
726 log::debug!(
727 "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
728 exchange,
729 endpoint,
730 cluster,
731 command
732 );
733 let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, true)?;
734 self.active.request(exchange, &msg).await
735 }
736
737 pub async fn recv_event(&self) -> Option<Message> {
743 self.active.recv_event().await
744 }
745
746 pub fn try_recv_event(&self) -> Option<Message> {
748 self.active.try_recv_event()
749 }
750
751 pub async fn reauth(
755 &self,
756 controller: &Controller,
757 node_id: u64,
758 controller_id: u64,
759 ) -> Result<()> {
760 self.active.pause_read_loop().await;
761 let new_session = controller
762 .auth_sigma_with_busy_retry(&self.active.transport_conn, node_id, controller_id)
763 .await?;
764 self.active.reauth_with_session(new_session).await
765 }
766}
767
/// Handle to an established IM subscription, returned by
/// `Connection::subscribe_attrs` / `Connection::subscribe_events`.
///
/// Dropping it removes its entry from the connection's subscription registry.
pub struct Subscription {
    /// Subscription id taken from the priming report and checked against the SubscribeResponse.
    pub subscription_id: u32,
    /// `max_interval` from the device's SubscribeResponse.
    pub max_interval: u16,
    /// Attribute reports delivered in the priming ReportData.
    pub priming_attribute_reports: Vec<im::AttributeReport>,
    /// Event reports delivered in the priming ReportData.
    pub priming_event_reports: Vec<im::EventReport>,
    // Receiver obtained from `register_subscription` for this subscription id.
    rx: mpsc::Receiver<im::ReportUpdate>,
    // Shared map of subscription id -> sender; used to deregister on drop.
    registry: Arc<std::sync::Mutex<HashMap<u32, mpsc::Sender<im::ReportUpdate>>>>,
}
784
785impl Subscription {
786 pub async fn next(&mut self) -> Option<im::ReportUpdate> {
789 self.rx.recv().await
790 }
791}
792
793impl Drop for Subscription {
794 fn drop(&mut self) {
795 self.registry.lock().unwrap().remove(&self.subscription_id);
796 }
797}
798
/// Encodes a setup PIN as the SPAKE2+ passcode input used by `auth_spake`:
/// the PIN as a 4-byte little-endian integer.
///
/// Never fails; the `Result` return type is kept for API compatibility.
pub fn pin_to_passcode(pin: u32) -> Result<Vec<u8>> {
    Ok(pin.to_le_bytes().to_vec())
}
804
805pub(crate) async fn auth_spake(connection: &dyn ConnectionTrait, pin: u32) -> Result<session::Session> {
806 let exchange = rand::random();
807 log::debug!("start auth_spake");
808 let mut session = session::Session::new();
809 session.my_session_id = 1;
810 let mut retrctx = retransmit::RetrContext::new(connection, &session);
811 log::debug!("send pbkdf request");
813 let pbkdf_req_protocol_message = messages::pbkdf_req(exchange)?;
814 retrctx.send(&pbkdf_req_protocol_message).await?;
815
816 let pbkdf_response = retrctx.get_next_message().await?;
818 if pbkdf_response.protocol_header.protocol_id
819 != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
820 || pbkdf_response.protocol_header.opcode
821 != messages::ProtocolMessageHeader::OPCODE_PBKDF_RESP
822 {
823 return Err(anyhow::anyhow!("pbkdf response not received"));
824 }
825
826 let iterations = pbkdf_response
827 .tlv
828 .get_int(&[4, 1])
829 .context("pbkdf_response - iterations missing")?;
830 let salt = pbkdf_response
831 .tlv
832 .get_octet_string(&[4, 2])
833 .context("pbkdf_response - salt missing")?;
834 let p_session = pbkdf_response
835 .tlv
836 .get_int(&[3])
837 .context("pbkdf_response - session missing")?;
838
839 let engine = spake2p::Engine::new()?;
841 let mut ctx = engine.start(&pin_to_passcode(pin)?, salt, iterations as u32)?;
842 log::debug!("send pake1 request");
843 let pake1_protocol_message = messages::pake1(exchange, ctx.x.as_bytes(), -1)?;
844 retrctx.send(&pake1_protocol_message).await?;
845
846 let pake2 = retrctx.get_next_message().await?;
848 if pake2.protocol_header.protocol_id
849 != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
850 || pake2.protocol_header.opcode != messages::ProtocolMessageHeader::OPCODE_PASE_PAKE2
851 {
852 return Err(anyhow::anyhow!("pake2 not received"));
853 }
854 let pake2_pb = pake2
855 .tlv
856 .get_octet_string(&[1])
857 .context("pake2 pb tlv missing")?;
858 ctx.y = p256::EncodedPoint::from_bytes(pake2_pb)?;
859
860 let pake2_cb = pake2
861 .tlv
862 .get_octet_string(&[2])
863 .context("pake2 cb tlv missing")?;
864
865 let mut hash_seed = "CHIP PAKE V1 Commissioning".as_bytes().to_vec();
867 hash_seed.extend_from_slice(&pbkdf_req_protocol_message[6..]);
868 hash_seed.extend_from_slice(&pbkdf_response.payload);
869 engine.finish(&mut ctx, &hash_seed, pake2_cb)?;
870 let pake3_protocol_message = messages::pake3(
871 exchange,
872 &ctx.ca.context("ca value not present in context")?,
873 -1,
874 )?;
875 log::debug!("send pake3 request");
876 retrctx.send(&pake3_protocol_message).await?;
877
878 let pake3_resp = retrctx.get_next_message().await?;
879 match &pake3_resp.status_report_info {
880 Some(s) => {
881 if !s.is_ok() {
882 return Err(anyhow::anyhow!("pake3 resp not ok), got {:?}", pake3_resp));
883 }
884 }
885 None => {
886 return Err(anyhow::anyhow!(
887 "expecting status report (pake3 resp), got {:?}",
888 pake3_resp
889 ))
890 }
891 }
892
893 session.set_encrypt_key(&ctx.encrypt_key.context("encrypt key missing")?);
894 session.set_decrypt_key(&ctx.decrypt_key.context("decrypt key missing")?);
895 session.session_id = p_session as u16;
896 log::debug!("auth_spake ok; session: {}", session.session_id);
897 Ok(session)
898}
899
900pub(crate) async fn auth_sigma(
901 connection: &dyn ConnectionTrait,
902 fabric: &fabric::Fabric,
903 cm: &dyn certmanager::CertManager,
904 node_id: u64,
905 controller_id: u64,
906) -> Result<(session::Session, Option<sigma::ResumptionRecord>)> {
907 log::debug!("auth_sigma");
908 let exchange = rand::random();
909 let session = session::Session::new();
910 let mut retrctx = retransmit::RetrContext::new(connection, &session);
911 retrctx.subscribe_exchange(exchange);
912 let mut ctx = sigma::SigmaContext::new(node_id);
913 let ca_pubkey = cm.get_ca_key()?.public_key().to_sec1_bytes();
914 sigma::sigma1(fabric, &mut ctx, &ca_pubkey)?;
915 let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
916
917 log::debug!("send sigma1 {}", exchange);
918 retrctx.send(&s1).await?;
919
920 log::debug!("receive sigma2 {}", exchange);
922 let sigma2 = retrctx.get_next_message().await?;
923 log::debug!("sigma2 received {:?}", sigma2);
924 if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
925 && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
926 {
927 let sri = sigma2.status_report_info.context("status report info missing")?;
928 if sri.is_busy() {
929 return Err(anyhow::Error::new(SigmaBusy { wait_ms: sri.minimum_wait_time_ms() }));
930 }
931 return Err(anyhow::anyhow!("sigma2 not received, status: {}", sri));
932 }
933 ctx.sigma2_payload = sigma2.payload;
934 ctx.responder_session = sigma2
935 .tlv
936 .get_int(&[2])
937 .context("responder session tlv missing in sigma2")? as u16;
938 ctx.responder_public = sigma2
939 .tlv
940 .get_octet_string(&[3])
941 .context("responder public tlv missing in sigma2")?
942 .to_vec();
943
944 log::debug!("verify sigma2 {}", exchange);
945 let resumption_id =
946 sigma::verify_sigma2(fabric, &ctx, &ca_pubkey).context("sigma2 verification failed")?;
947
948 let controller_private = cm.get_user_key(controller_id)?;
949 let controller_x509 = cm.get_user_cert(controller_id)?;
950 let controller_matter_cert =
951 cert_matter::convert_x509_bytes_to_matter(&controller_x509, &ca_pubkey)?;
952
953 log::debug!("send sigma3 {} with piggyback ack for {}", exchange, sigma2.message_header.message_counter);
955 sigma::sigma3(
956 fabric,
957 &mut ctx,
958 &controller_private.to_sec1_der()?,
959 &controller_matter_cert,
960 )?;
961 let sigma3 = messages::sigma3(exchange, &ctx.sigma3_payload, sigma2.message_header.message_counter)?;
962 retrctx.send(&sigma3).await?;
963
964 log::debug!("receive result {}", exchange);
965 let status = retrctx.get_next_message().await?;
966 if !status
967 .status_report_info
968 .as_ref()
969 .context("sigma3 status resp not received")?
970 .is_ok()
971 {
972 return Err(anyhow::anyhow!(format!(
973 "response to sigma3 does not contain status ok {:?}",
974 status
975 )));
976 }
977
978 let mut th = ctx.sigma1_payload.clone();
980 th.extend_from_slice(&ctx.sigma2_payload);
981
982 let mut transcript = th;
983 transcript.extend_from_slice(&ctx.sigma3_payload);
984 let transcript_hash = cryptoutil::sha256(&transcript);
985 let mut salt = fabric.signed_ipk()?;
986 salt.extend_from_slice(&transcript_hash);
987 let shared = ctx.shared.context("shared secret not in context")?;
988 let shared_bytes: [u8; 32] = shared.raw_secret_bytes().as_slice()
989 .try_into()
990 .map_err(|_| anyhow::anyhow!("shared secret wrong length"))?;
991 let keypack = cryptoutil::hkdf_sha256(
992 &salt,
993 &shared_bytes,
994 "SessionKeys".as_bytes(),
995 16 * 3,
996 )?;
997 let mut ses = session::Session::new();
998 ses.session_id = ctx.responder_session;
999 ses.my_session_id = ctx.session_id;
1000 ses.set_decrypt_key(&keypack[16..32]);
1001 ses.set_encrypt_key(&keypack[..16]);
1002
1003 let mut local_node = Vec::new();
1004 local_node.write_u64::<LittleEndian>(controller_id)?;
1005 ses.local_node = Some(local_node);
1006
1007 let mut remote_node = Vec::new();
1008 remote_node.write_u64::<LittleEndian>(node_id)?;
1009 ses.remote_node = Some(remote_node);
1010
1011 let resumption = resumption_id
1012 .map(|id| sigma::ResumptionRecord { resumption_id: id, shared_secret: shared_bytes });
1013
1014 if resumption.is_none() {
1015 log::debug!("auth_sigma: responder did not include a NewResumptionID - resumption unavailable for node {}", node_id);
1016 }
1017
1018 Ok((ses, resumption))
1019}
1020
1021#[cfg(test)]
1022mod tests {
1023 use super::*;
1024 use crate::messages::ProtocolMessageHeader;
1025 use crate::tlv;
1026 use std::time::Duration;
1027
// Controller-side mock transport used by the tests: frames the mock device
// pushes arrive on `inbound`, and frames the controller sends go to `outbound`.
struct MockConn {
    // Frames from the mock device; async mutex so `receive` can hold it across the await.
    inbound: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
    // Frames sent by the controller towards the mock device.
    outbound: mpsc::UnboundedSender<Vec<u8>>,
    // Value reported by `is_reliable`.
    reliable: bool,
    // MRP parameters returned by `mrp_params` / replaced by `set_mrp_params`.
    mrp: std::sync::Mutex<crate::mrp::MrpParameters>,
}
1035
1036 #[async_trait::async_trait]
1037 impl ConnectionTrait for MockConn {
1038 async fn send(&self, data: &[u8]) -> Result<()> {
1039 self.outbound
1040 .send(data.to_vec())
1041 .map_err(|_| anyhow::anyhow!("mock closed"))
1042 }
1043 async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
1044 let mut rx = self.inbound.lock().await;
1045 match tokio::time::timeout(timeout, rx.recv()).await {
1046 Ok(Some(d)) => Ok(d),
1047 Ok(None) => Err(anyhow::anyhow!("mock closed")),
1048 Err(_) => Err(anyhow::anyhow!("timeout")),
1049 }
1050 }
1051 fn is_reliable(&self) -> bool {
1052 self.reliable
1053 }
1054 fn mrp_params(&self) -> crate::mrp::MrpParameters {
1055 *self.mrp.lock().unwrap()
1056 }
1057 fn set_mrp_params(&self, params: crate::mrp::MrpParameters) {
1058 *self.mrp.lock().unwrap() = params;
1059 }
1060 }
1061
// Device side of the mock link, driven directly by tests.
struct MockDevice {
    // Frames the controller sent.
    rx: mpsc::UnboundedReceiver<Vec<u8>>,
    // Sends encoded frames to the controller's `MockConn::inbound`.
    tx: mpsc::Sender<Vec<u8>>,
    // Session used to encode frames sent to the controller.
    session: session::Session,
}
1067
1068 impl MockDevice {
1069 async fn recv(&mut self) -> Message {
1070 let data = tokio::time::timeout(Duration::from_secs(2), self.rx.recv())
1071 .await
1072 .expect("timeout waiting for controller message")
1073 .expect("mock closed");
1074 Message::decode(&data).unwrap()
1075 }
1076
1077 async fn expect_status_response(&mut self, want_flags: u8, want_ack: u32) {
1078 let msg = self.recv().await;
1079 assert_eq!(
1080 msg.protocol_header.protocol_id,
1081 ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
1082 );
1083 assert_eq!(
1084 msg.protocol_header.opcode,
1085 ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
1086 );
1087 assert_eq!(
1088 msg.protocol_header.exchange_flags,
1089 ProtocolMessageHeader::FLAG_RELIABILITY | want_flags
1090 );
1091 assert_eq!(msg.protocol_header.ack_counter, want_ack);
1092 assert_eq!(msg.tlv.get_int(&[0]), Some(0));
1093 }
1094
1095 async fn expect_silence(&mut self) {
1096 assert!(
1097 tokio::time::timeout(Duration::from_millis(200), self.rx.recv())
1098 .await
1099 .is_err(),
1100 "unexpected message from controller"
1101 );
1102 }
1103
1104 async fn send(&self, payload: &[u8]) -> u32 {
1105 let encoded = self.session.encode_message(payload).unwrap();
1106 let (header, _) = messages::MessageHeader::decode(&encoded).unwrap();
1107 self.tx.send(encoded).await.unwrap();
1108 header.message_counter
1109 }
1110
1111 async fn recv_within(&mut self, d: Duration) -> Option<Message> {
1112 match tokio::time::timeout(d, self.rx.recv()).await {
1113 Ok(Some(data)) => Some(Message::decode(&data).unwrap()),
1114 _ => None,
1115 }
1116 }
1117 }
1118
1119 fn mock_pair() -> (Connection, MockDevice) {
1120 mock_pair_with(true, Default::default())
1121 }
1122
1123 fn mock_pair_unreliable(mrp: crate::mrp::MrpParameters) -> (Connection, MockDevice) {
1124 mock_pair_with(false, mrp)
1125 }
1126
1127 fn mock_pair_with(reliable: bool, mrp: crate::mrp::MrpParameters) -> (Connection, MockDevice) {
1128 let (to_ctrl_tx, to_ctrl_rx) = mpsc::channel(32);
1129 let (to_dev_tx, to_dev_rx) = mpsc::unbounded_channel();
1130 let mock = Arc::new(MockConn {
1131 inbound: tokio::sync::Mutex::new(to_ctrl_rx),
1132 outbound: to_dev_tx,
1133 reliable,
1134 mrp: std::sync::Mutex::new(mrp),
1135 });
1136 let conn = Connection::from_parts(mock, session::Session::new());
1137 let device = MockDevice {
1138 rx: to_dev_rx,
1139 tx: to_ctrl_tx,
1140 session: session::Session::new(),
1141 };
1142 (conn, device)
1143 }
1144
1145 fn report_data(
1146 exchange: u16,
1147 flags: u8,
1148 sub_id: Option<u32>,
1149 values: &[(u16, bool)],
1150 more: bool,
1151 suppress: bool,
1152 ) -> Vec<u8> {
1153 let b = ProtocolMessageHeader {
1154 exchange_flags: flags,
1155 opcode: ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA,
1156 exchange_id: exchange,
1157 protocol_id: ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
1158 ack_counter: 0,
1159 }
1160 .encode()
1161 .unwrap();
1162 let mut t = tlv::TlvBuffer::from_vec(b);
1163 t.write_anon_struct().unwrap();
1164 if let Some(id) = sub_id {
1165 t.write_uint32(0, id).unwrap();
1166 }
1167 t.write_array(1).unwrap();
1168 for (endpoint, value) in values {
1169 t.write_anon_struct().unwrap();
1170 t.write_struct(1).unwrap();
1171 t.write_uint32(0, 0).unwrap();
1172 t.write_list(1).unwrap();
1173 t.write_uint16(2, *endpoint).unwrap();
1174 t.write_uint32(3, 6).unwrap();
1175 t.write_uint32(4, 0).unwrap();
1176 t.write_struct_end().unwrap();
1177 t.write_bool(2, *value).unwrap();
1178 t.write_struct_end().unwrap();
1179 t.write_struct_end().unwrap();
1180 }
1181 t.write_struct_end().unwrap();
1182 if more {
1183 t.write_bool(3, true).unwrap();
1184 }
1185 if suppress {
1186 t.write_bool(4, true).unwrap();
1187 }
1188 t.write_struct_end().unwrap();
1189 t.data
1190 }
1191
1192 fn subscribe_response(exchange: u16, sub_id: u32, max_interval: u16) -> Vec<u8> {
1193 let b = ProtocolMessageHeader {
1194 exchange_flags: 0,
1195 opcode: ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_RESP,
1196 exchange_id: exchange,
1197 protocol_id: ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
1198 ack_counter: 0,
1199 }
1200 .encode()
1201 .unwrap();
1202 let mut t = tlv::TlvBuffer::from_vec(b);
1203 t.write_anon_struct().unwrap();
1204 t.write_uint32(0, sub_id).unwrap();
1205 t.write_uint16(2, max_interval).unwrap();
1206 t.write_struct_end().unwrap();
1207 t.data
1208 }
1209
1210 const FLAGS_RESPONDER: u8 = 0;
1211 const FLAGS_DEVICE_INITIATED: u8 = ProtocolMessageHeader::FLAG_INITIATOR;
1212 const ACK_AND_INITIATOR: u8 =
1213 ProtocolMessageHeader::FLAG_INITIATOR | ProtocolMessageHeader::FLAG_ACK;
1214
1215 #[tokio::test]
1216 async fn test_read_request2_single_chunk() {
1217 let (conn, mut device) = mock_pair();
1218 let task = tokio::spawn(async move {
1219 let req = device.recv().await;
1220 assert_eq!(
1221 req.protocol_header.opcode,
1222 ProtocolMessageHeader::INTERACTION_OPCODE_READ_REQ
1223 );
1224 let exchange = req.protocol_header.exchange_id;
1225 device
1226 .send(&report_data(exchange, FLAGS_RESPONDER, None, &[(1, true)], false, true))
1227 .await;
1228 device.expect_silence().await;
1229 });
1230 let val = conn.read_request2(1, 6, 0).await.unwrap();
1231 assert_eq!(val, TlvItemValue::Bool(true));
1232 task.await.unwrap();
1233 }
1234
1235 #[tokio::test]
1236 async fn test_read_request2_chunked() {
1237 let (conn, mut device) = mock_pair();
1238 let task = tokio::spawn(async move {
1239 let req = device.recv().await;
1240 let exchange = req.protocol_header.exchange_id;
1241 let counter = device
1242 .send(&report_data(exchange, FLAGS_RESPONDER, None, &[(1, true)], true, false))
1243 .await;
1244 device.expect_status_response(ACK_AND_INITIATOR, counter).await;
1245 device
1246 .send(&report_data(exchange, FLAGS_RESPONDER, None, &[(2, false)], false, true))
1247 .await;
1248 device.expect_silence().await;
1249 });
1250 let val = conn.read_request2(1, 6, 0).await.unwrap();
1251 assert_eq!(val, TlvItemValue::Bool(true));
1252 task.await.unwrap();
1253 }
1254
1255 #[tokio::test]
1256 async fn test_subscribe_and_updates() {
1257 let (conn, mut device) = mock_pair();
1258 let task = tokio::spawn(async move {
1259 let req = device.recv().await;
1260 assert_eq!(
1261 req.protocol_header.opcode,
1262 ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_REQ
1263 );
1264 let exchange = req.protocol_header.exchange_id;
1265 let counter = device
1266 .send(&report_data(exchange, FLAGS_RESPONDER, Some(7), &[(1, true)], true, false))
1267 .await;
1268 device.expect_status_response(ACK_AND_INITIATOR, counter).await;
1269 let counter = device
1270 .send(&report_data(exchange, FLAGS_RESPONDER, Some(7), &[(2, false)], false, false))
1271 .await;
1272 device.expect_status_response(ACK_AND_INITIATOR, counter).await;
1273 device.send(&subscribe_response(exchange, 7, 60)).await;
1274
1275 let counter = device
1277 .send(&report_data(0x4001, FLAGS_DEVICE_INITIATED, Some(7), &[(1, false)], false, false))
1278 .await;
1279 device
1280 .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
1281 .await;
1282 device
1283 });
1284
1285 let mut sub = conn.subscribe_attrs(Some(1), Some(6), Some(0), false).await.unwrap();
1286 assert_eq!(sub.subscription_id, 7);
1287 assert_eq!(sub.max_interval, 60);
1288 assert_eq!(sub.priming_attribute_reports.len(), 2);
1289 assert_eq!(sub.priming_attribute_reports[0].path.endpoint, Some(1));
1290 assert_eq!(sub.priming_attribute_reports[1].path.endpoint, Some(2));
1291
1292 let update = sub.next().await.unwrap();
1293 assert_eq!(update.subscription_id, 7);
1294 assert_eq!(update.attribute_reports.len(), 1);
1295 assert_eq!(
1296 update.attribute_reports[0].data,
1297 im::AttributeData::Value(TlvItemValue::Bool(false))
1298 );
1299 task.await.unwrap();
1300 }
1301
1302 #[tokio::test]
1303 async fn test_chunked_unsolicited_report() {
1304 let (conn, mut device) = mock_pair();
1305 let task = tokio::spawn(async move {
1306 let req = device.recv().await;
1307 let exchange = req.protocol_header.exchange_id;
1308 let counter = device
1309 .send(&report_data(exchange, FLAGS_RESPONDER, Some(9), &[(1, true)], false, false))
1310 .await;
1311 device.expect_status_response(ACK_AND_INITIATOR, counter).await;
1312 device.send(&subscribe_response(exchange, 9, 60)).await;
1313
1314 let counter = device
1316 .send(&report_data(0x4002, FLAGS_DEVICE_INITIATED, Some(9), &[(1, false)], true, false))
1317 .await;
1318 device
1319 .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
1320 .await;
1321 let counter = device
1322 .send(&report_data(0x4002, FLAGS_DEVICE_INITIATED, Some(9), &[(2, true)], false, false))
1323 .await;
1324 device
1325 .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
1326 .await;
1327 });
1328
1329 let mut sub = conn.subscribe_attrs(Some(1), Some(6), Some(0), false).await.unwrap();
1330 let update = sub.next().await.unwrap();
1331 assert_eq!(update.attribute_reports.len(), 2);
1332 assert_eq!(update.attribute_reports[0].path.endpoint, Some(1));
1333 assert_eq!(update.attribute_reports[1].path.endpoint, Some(2));
1334 task.await.unwrap();
1335 }
1336
1337 #[tokio::test]
1338 async fn test_unregistered_subscription_id() {
1339 let (conn, mut device) = mock_pair();
1340
1341 let counter = device
1342 .send(&report_data(0x4003, FLAGS_DEVICE_INITIATED, Some(99), &[(1, true)], false, false))
1343 .await;
1344 device
1345 .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
1346 .await;
1347 let raw = conn.recv_event().await.unwrap();
1348 assert_eq!(
1349 raw.protocol_header.opcode,
1350 ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA
1351 );
1352
1353 conn.set_auto_status_response(false);
1354 device
1355 .send(&report_data(0x4004, FLAGS_DEVICE_INITIATED, Some(99), &[(1, true)], false, false))
1356 .await;
1357 device.expect_silence().await;
1358 let raw = conn.recv_event().await.unwrap();
1359 assert_eq!(raw.protocol_header.exchange_id, 0x4004);
1360 }
1361
1362 #[tokio::test]
1363 async fn test_duplicate_message_dropped() {
1364 let (conn, mut device) = mock_pair();
1365
1366 let payload =
1367 report_data(0x4005, FLAGS_DEVICE_INITIATED, Some(99), &[(1, true)], false, false);
1368 let encoded = device.session.encode_message(&payload).unwrap();
1369 let (header, _) = messages::MessageHeader::decode(&encoded).unwrap();
1370 device.tx.send(encoded.clone()).await.unwrap();
1371 device
1372 .expect_status_response(ProtocolMessageHeader::FLAG_ACK, header.message_counter)
1373 .await;
1374 let raw = conn.recv_event().await.unwrap();
1375 assert_eq!(raw.protocol_header.exchange_id, 0x4005);
1376
1377 device.tx.send(encoded).await.unwrap();
1379 device.expect_silence().await;
1380 assert!(conn.try_recv_event().is_none());
1381 }
1382
1383 #[tokio::test]
1384 async fn test_initiator_flag_not_misrouted() {
1385 let (conn, mut device) = mock_pair();
1386 let task = tokio::spawn(async move {
1387 let req = device.recv().await;
1388 let exchange = req.protocol_header.exchange_id;
1389 let counter = device
1392 .send(&report_data(exchange, FLAGS_DEVICE_INITIATED, None, &[(5, false)], false, false))
1393 .await;
1394 device
1395 .expect_status_response(ProtocolMessageHeader::FLAG_ACK, counter)
1396 .await;
1397 device
1398 .send(&report_data(exchange, FLAGS_RESPONDER, None, &[(1, true)], false, true))
1399 .await;
1400 });
1401 let val = conn.read_request2(1, 6, 0).await.unwrap();
1402 assert_eq!(val, TlvItemValue::Bool(true));
1403 task.await.unwrap();
1404 }
1405
1406 #[tokio::test(start_paused = true)]
1407 async fn test_retransmit_schedule_and_give_up() {
1408 let mrp = crate::mrp::MrpParameters::from_txt_ms(Some(5000), None, None);
1409 let (conn, mut device) = mock_pair_unreliable(mrp);
1410 let req = tokio::spawn(async move { conn.read_request2(1, 6, 0).await });
1411
1412 let mut times = Vec::new();
1413 let mut counters = Vec::new();
1414 for i in 0..crate::mrp::MRP_MAX_TRANSMISSIONS {
1415 let msg = device
1416 .recv_within(Duration::from_secs(30))
1417 .await
1418 .unwrap_or_else(|| panic!("missing transmission {}", i));
1419 times.push(tokio::time::Instant::now());
1420 counters.push(msg.message_header.message_counter);
1421 }
1422 assert!(counters.iter().all(|c| *c == counters[0]));
1423
1424 for (n, w) in times.windows(2).enumerate() {
1426 let gap = (w[1] - w[0]).as_secs_f64();
1427 let lower = 5.0 * 1.1 * 1.6f64.powi(n.saturating_sub(1) as i32);
1428 let upper = lower * 1.25;
1429 assert!(
1430 gap >= lower - 0.01 && gap <= upper + 0.1,
1431 "gap {} = {} not in [{}, {}]",
1432 n, gap, lower, upper
1433 );
1434 }
1435
1436 let res = req.await.unwrap();
1438 assert!(res.is_err(), "request should fail after give-up");
1439 assert!(
1440 device.recv_within(Duration::from_secs(120)).await.is_none(),
1441 "no transmissions expected after give-up"
1442 );
1443 }
1444
1445 #[tokio::test(start_paused = true)]
1446 async fn test_retransmit_stops_after_ack() {
1447 let (conn, mut device) = mock_pair_unreliable(Default::default());
1448 let _req = tokio::spawn(async move { conn.read_request2(1, 6, 0).await });
1449
1450 let msg = device.recv_within(Duration::from_secs(5)).await.expect("request");
1451 let retr = device.recv_within(Duration::from_secs(5)).await.expect("retransmit");
1452 assert_eq!(
1453 msg.message_header.message_counter,
1454 retr.message_header.message_counter
1455 );
1456
1457 device
1458 .send(&messages::ack(
1459 msg.protocol_header.exchange_id,
1460 msg.message_header.message_counter as i64,
1461 ).unwrap())
1462 .await;
1463 assert!(
1464 device.recv_within(Duration::from_secs(60)).await.is_none(),
1465 "no retransmissions expected after ack"
1466 );
1467 }
1468
1469 #[tokio::test(start_paused = true)]
1470 async fn test_retransmit_not_starved_by_inbound_traffic() {
1471 let (conn, mut device) = mock_pair_unreliable(Default::default());
1472 let _req = tokio::spawn(async move { conn.read_request2(1, 6, 0).await });
1473
1474 let first = device.recv_within(Duration::from_secs(2)).await.expect("request");
1475 let counter = first.message_header.message_counter;
1476
1477 let mut seen_retransmit = false;
1480 for _ in 0..10 {
1481 device.send(&messages::ack(0x7777, 999_999).unwrap()).await;
1482 tokio::time::sleep(Duration::from_millis(100)).await;
1483 while let Ok(data) = device.rx.try_recv() {
1484 let m = Message::decode(&data).unwrap();
1485 if m.message_header.message_counter == counter {
1486 seen_retransmit = true;
1487 }
1488 }
1489 }
1490 assert!(seen_retransmit, "retransmit starved by continuous inbound traffic");
1491 }
1492}