matc/
controller.rs

1use std::sync::Arc;
2
3use crate::{
4    active_connection::ActiveConnection,
5    cert_matter, certmanager, commission, fabric,
6    messages::{self, Message},
7    retransmit, session, sigma, spake2p,
8    tlv::TlvItemValue,
9    transport::{self, ConnectionTrait},
10    util::cryptoutil,
11};
12use anyhow::{Context, Result};
13use byteorder::{LittleEndian, WriteBytesExt};
14
15pub struct Controller {
16    certmanager: Arc<dyn certmanager::CertManager>,
17    #[allow(dead_code)]
18    transport: Arc<transport::Transport>,
19    fabric: fabric::Fabric,
20}
21
22pub struct Connection {
23    active: ActiveConnection,
24}
// Disabled compile-time assertion that `Controller` is `Sync`
// (implementing a trait with a `Sync` supertrait only compiles if it is):
// trait IsSync: Sync {}
// impl IsSync for Controller {}

/// CA identifier passed to `fabric::Fabric::new` when the controller builds its fabric.
const CA_ID: u64 = 0x1;
29
30impl Controller {
31    pub fn new(
32        certmanager: &Arc<dyn certmanager::CertManager>,
33        transport: &Arc<transport::Transport>,
34        fabric_id: u64,
35    ) -> Result<Arc<Self>> {
36        let fabric = fabric::Fabric::new(fabric_id, CA_ID, &certmanager.get_ca_public_key()?);
37        Ok(Arc::new(Self {
38            certmanager: certmanager.clone(),
39            transport: transport.clone(),
40            fabric,
41        }))
42    }
43
44    /// commission device
45    /// - authenticate using pin
46    /// - push CA certificate to device
47    /// - sign device's certificate
48    /// - set controller id - user which can control device
49    /// - return authenticated connection which can be used to send additional commands
50    pub async fn commission(
51        &self,
52        connection: &Arc<dyn ConnectionTrait>,
53        pin: u32,
54        node_id: u64,
55        controller_id: u64,
56    ) -> Result<Connection> {
57        let mut session = auth_spake(connection.as_ref(), pin).await?;
58        let session = commission::commission(
59            connection.as_ref(),
60            &mut session,
61            &self.fabric,
62            self.certmanager.as_ref(),
63            node_id,
64            controller_id,
65        )
66        .await?;
67        Ok(Connection {
68            active: ActiveConnection::new(connection.clone(), session),
69        })
70    }
71
72    /// create authenticated connection to control device
73    pub async fn auth_sigma(
74        &self,
75        connection: &Arc<dyn ConnectionTrait>,
76        node_id: u64,
77        controller_id: u64,
78    ) -> Result<Connection> {
79        let session = auth_sigma(
80            connection.as_ref(),
81            &self.fabric,
82            self.certmanager.as_ref(),
83            node_id,
84            controller_id,
85        )
86        .await?;
87        Ok(Connection {
88            active: ActiveConnection::new(connection.clone(), session),
89        })
90    }
91
92    /// Commission a device that is advertising over BLE.
93    ///
94    /// 1. Scans for a commissionable BLE device with the given `discriminator`.
95    /// 2. Runs PASE over BTP (BLE transport protocol).
96    /// 3. Pushes the CA cert, signs the device cert (AddNOC).
97    /// 4. Sends ArmFailSafe + SetRegulatoryConfig.
98    /// 5. Optionally provisions network credentials (Wi-Fi / Thread).
99    /// 6. Drops the BLE connection.
100    /// 7. Discovers the device on the IP network via mDNS.
101    /// 8. Establishes CASE + sends CommissioningComplete over UDP.
102    /// 9. Returns an authenticated [`Connection`] ready for commands.
103    ///
104    /// Requires the `ble` Cargo feature.
105    #[cfg(feature = "ble")]
106    pub async fn commission_ble(
107        &self,
108        discriminator: u16,
109        short_discriminator: bool,
110        pin: u32,
111        node_id: u64,
112        controller_id: u64,
113        network_creds: commission::NetworkCreds,
114        mdns: &std::sync::Arc<crate::mdns2::MdnsService>,
115        mdns_receiver: &tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<crate::mdns2::MdnsEvent>>,
116    ) -> Result<Connection> {
117        use crate::{btp::BtpConnection, discover};
118
119        // 1. BLE scan + GATT connect + BTP handshake
120        let peripheral = crate::ble::find_by_discriminator(discriminator, short_discriminator, std::time::Duration::from_secs(30))
121            .await
122            .context("BLE scan")?;
123        log::debug!("BLE device found: z2");
124        let btp_conn = BtpConnection::connect(peripheral).await.context("BTP connect")?;
125
126        // 2. PASE
127        let mut pase_session = auth_spake(btp_conn.as_ref(), pin).await.context("PASE over BLE")?;
128
129        // 3. BLE-side commissioning phase
130        commission::commission_ble_phase(
131            btp_conn.as_ref(),
132            &mut pase_session,
133            &self.fabric,
134            self.certmanager.as_ref(),
135            node_id,
136            controller_id,
137            &network_creds,
138        )
139        .await
140        .context("BLE commissioning phase")?;
141
142        // 4. Drop BTP (BLE connection closes when btp_conn is dropped)
143        drop(btp_conn);
144
145        // 5. Rediscover device via operational mDNS
146        let ca_pubkey = self.certmanager.get_ca_public_key()?;
147        let fabric_tmp = fabric::Fabric::new(self.fabric.id, 0, &ca_pubkey);
148        let compressed = fabric_tmp.compressed().context("compressed fabric ID")?;
149        let instance = format!("{}-{:016X}", hex::encode_upper(&compressed), node_id);
150        let expected_target = format!("{}._matter._tcp.local.", instance);
151
152        let mut addresses = Vec::new();
153        {
154            let mut rx = mdns_receiver.lock().await;
155            mdns.active_lookup("_matter._tcp.local", 0xff).await;
156            loop {
157                match tokio::time::timeout(std::time::Duration::from_secs(30), rx.recv()).await {
158                    Ok(Some(crate::mdns2::MdnsEvent::ServiceDiscovered { name, records: _, target })) => {
159                        if name != "_matter._tcp.local." || target != expected_target {
160                            continue;
161                        }
162                        let info = discover::extract_matter_info(&target, mdns).await?;
163                        log::debug!("Operational mDNS discovered device: {:?}", info);
164
165                        let port = info.port.unwrap_or(5540);
166                        for ip in &info.ips {
167                            if ip.is_ipv6() {
168                                addresses.push(format!("[{}]:{}", ip, port));
169                            } else {
170                                addresses.push(format!("{}:{}", ip, port));
171                            }
172                        }
173                        break;
174                    }
175                    Ok(_) => continue,
176                    Err(_) => anyhow::bail!("operational mDNS timeout for {}", instance),
177                }
178            }
179        };
180
181        log::info!("Device discovered at {}", addresses.join(", "));
182
183        // 6. UDP connection + CASE + CommissioningComplete
184        for address in addresses {
185            log::debug!("Trying to commission over UDP at {}...", address);
186            let udp_conn = self.transport.create_connection(&address).await;
187            let ses = commission::commissioning_complete_udp(
188                udp_conn.as_ref(),
189                self.certmanager.as_ref(),
190                node_id,
191                controller_id,
192                &self.fabric,
193            )
194            .await;
195            if let Ok(ses) = ses {
196                return Ok(Connection {
197                    active: ActiveConnection::new(udp_conn, ses),
198                });
199            } else {
200                log::debug!("Failed to commission over UDP at {}: {:?}", address, ses.err());
201            }
202        }
203        Err(anyhow::anyhow!("failed to commission device over UDP at any discovered address"))
204    }
205}
206
207/// Authenticated virtual connection can be used to send commands to device.
208impl Connection {
209    /// Read attribute from device and return parsed matter protocol response.
210    pub async fn read_request(
211        &self,
212        endpoint: u16,
213        cluster: u32,
214        attr: u32,
215    ) -> Result<Message> {
216        let exchange: u16 = rand::random();
217        let msg = messages::im_read_request(endpoint, cluster, attr, exchange)?;
218        self.active.request(exchange, &msg).await
219    }
220
221    /// Read attribute from device and return tlv with attribute value.
222    pub async fn read_request2(
223        &self,
224        endpoint: u16,
225        cluster: u32,
226        attr: u32,
227    ) -> Result<TlvItemValue> {
228        let res = self.read_request(endpoint, cluster, attr).await?;
229        if (res.protocol_header.protocol_id
230            != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION)
231            || (res.protocol_header.opcode
232                != messages::ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA)
233        {
234            Err(anyhow::anyhow!(
235                "response is not expected report_data {:?}",
236                res.protocol_header
237            ))
238        } else {
239            match res.tlv.get(&[1, 0, 1, 2]) {
240                Some(a) => Ok(a.clone()),
241                None => {
242                    let s = res
243                        .tlv
244                        .get(&[1, 0, 0, 1, 0])
245                        .context("report data format not recognized1")?;
246                    if let TlvItemValue::Int(status) = s {
247                        Err(anyhow::anyhow!("report data with status {}", status))
248                    } else {
249                        Err(anyhow::anyhow!("report data format not recognized2"))
250                    }
251                }
252            }
253        }
254    }
255
256    /// Invoke command
257    pub async fn invoke_request(
258        &self,
259        endpoint: u16,
260        cluster: u32,
261        command: u32,
262        payload: &[u8],
263    ) -> Result<Message> {
264        let exchange: u16 = rand::random();
265        log::debug!(
266            "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
267            exchange,
268            endpoint,
269            cluster,
270            command
271        );
272        let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, false)?;
273        self.active.request(exchange, &msg).await
274    }
275
276    /// Invoke command and return result TLV
277    pub async fn invoke_request2(
278        &self,
279        endpoint: u16,
280        cluster: u32,
281        command: u32,
282        payload: &[u8],
283    ) -> Result<TlvItemValue> {
284        let res = self.invoke_request(endpoint, cluster, command, payload).await?;
285        let o = res.tlv.get(&[1, 0, 1, 1]).context("result not found")?;
286        Ok(o.clone())
287    }
288
289    pub async fn im_subscribe_request(
290        &self,
291        endpoint: u16,
292        cluster: u32,
293        event: u32,
294    ) -> Result<Message> {
295        let exchange: u16 = rand::random();
296        log::debug!(
297            "im_subscribe_request exch:{} endpoint:{} cluster:{} event:{}",
298            exchange,
299            endpoint,
300            cluster,
301            event
302        );
303        let msg = messages::im_subscribe_request(endpoint, cluster, exchange, event)?;
304        self.active.request(exchange, &msg).await
305    }
306
307    /// Subscribe to attribute changes. Returns the initial ReportData message.
308    /// Set `keep_subscriptions = true` when adding a second subscription on the same
309    /// connection so the device does not cancel the first one.
310    pub async fn im_subscribe_request_attr(
311        &self,
312        endpoint: u16,
313        cluster: u32,
314        attr: u32,
315        keep_subscriptions: bool,
316    ) -> Result<Message> {
317        let exchange: u16 = rand::random();
318        log::debug!(
319            "im_subscribe_request_attr exch:{} endpoint:{} cluster:{} attr:{} keep:{}",
320            exchange, endpoint, cluster, attr, keep_subscriptions
321        );
322        let msg = messages::im_subscribe_request_attr(endpoint, cluster, attr, exchange, keep_subscriptions)?;
323        self.active.request(exchange, &msg).await
324    }
325
326    /// Cancel all subscriptions on this session by sending a SubscribeRequest with
327    /// `KeepSubscriptions = false` and no paths. The device drops all prior subscriptions.
328    pub async fn im_unsubscribe_all(&self) -> Result<Message> {
329        let exchange: u16 = rand::random();
330        log::debug!("im_unsubscribe_all exch:{}", exchange);
331        let msg = messages::im_unsubscribe_all(exchange)?;
332        self.active.request(exchange, &msg).await
333    }
334
335    pub async fn im_status_response(
336        &self,
337        exchange: u16,
338        flags: u8,
339        ack: u32
340    ) -> Result<()> {
341        let msg = messages::im_status_response(exchange, flags, ack)?;
342        self.active.send(&msg).await
343    }
344
345    /// Invoke command with timed interaction
346    pub async fn invoke_request_timed(
347        &self,
348        endpoint: u16,
349        cluster: u32,
350        command: u32,
351        payload: &[u8],
352        timeout: u16,
353    ) -> Result<Message> {
354        let exchange: u16 = rand::random();
355
356        // Send timed request first
357        let tr = messages::im_timed_request(exchange, timeout)?;
358        let result = self.active.request(exchange, &tr).await?;
359
360        if result.protocol_header.protocol_id
361            != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
362            || result.protocol_header.opcode
363                != messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
364        {
365            return Err(anyhow::anyhow!(
366                "invoke_request_timed: unexpected response {:?}",
367                result
368            ));
369        }
370        let status = result
371            .tlv
372            .get_int(&[0])
373            .context("invoke_request_timed: status not found")?;
374        if status != 0 {
375            return Err(anyhow::anyhow!(
376                "invoke_request_timed: unexpected status {}",
377                status
378            ));
379        }
380
381        log::debug!(
382            "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
383            exchange,
384            endpoint,
385            cluster,
386            command
387        );
388        let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, true)?;
389        self.active.request(exchange, &msg).await
390    }
391
392    /// Receive next event (for subscriptions). Returns None when connection is closed.
393    pub async fn recv_event(&self) -> Option<Message> {
394        self.active.recv_event().await
395    }
396
397    /// Try receive event without blocking.
398    pub fn try_recv_event(&self) -> Option<Message> {
399        self.active.try_recv_event()
400    }
401}
402
403pub fn pin_to_passcode(pin: u32) -> Result<Vec<u8>> {
404    let mut out = Vec::new();
405    out.write_u32::<LittleEndian>(pin)?;
406    Ok(out)
407}
408
409pub(crate) async fn auth_spake(connection: &dyn ConnectionTrait, pin: u32) -> Result<session::Session> {
410    let exchange = rand::random();
411    log::debug!("start auth_spake");
412    let mut session = session::Session::new();
413    session.my_session_id = 1;
414    let mut retrctx = retransmit::RetrContext::new(connection, &session);
415    // send pbkdf
416    log::debug!("send pbkdf request");
417    let pbkdf_req_protocol_message = messages::pbkdf_req(exchange)?;
418    retrctx.send(&pbkdf_req_protocol_message).await?;
419
420    // get pbkdf response
421    let pbkdf_response = retrctx.get_next_message().await?;
422    if pbkdf_response.protocol_header.protocol_id
423        != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
424        || pbkdf_response.protocol_header.opcode
425            != messages::ProtocolMessageHeader::OPCODE_PBKDF_RESP
426    {
427        return Err(anyhow::anyhow!("pbkdf response not received"));
428    }
429
430    let iterations = pbkdf_response
431        .tlv
432        .get_int(&[4, 1])
433        .context("pbkdf_response - iterations missing")?;
434    let salt = pbkdf_response
435        .tlv
436        .get_octet_string(&[4, 2])
437        .context("pbkdf_response - salt missing")?;
438    let p_session = pbkdf_response
439        .tlv
440        .get_int(&[3])
441        .context("pbkdf_response - session missing")?;
442
443    // send pake1
444    let engine = spake2p::Engine::new()?;
445    let mut ctx = engine.start(&pin_to_passcode(pin)?, salt, iterations as u32)?;
446    log::debug!("send pake1 request");
447    let pake1_protocol_message = messages::pake1(exchange, ctx.x.as_bytes(), -1)?;
448    retrctx.send(&pake1_protocol_message).await?;
449
450    // receive pake2
451    let pake2 = retrctx.get_next_message().await?;
452    if pake2.protocol_header.protocol_id
453        != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
454        || pake2.protocol_header.opcode != messages::ProtocolMessageHeader::OPCODE_PASE_PAKE2
455    {
456        return Err(anyhow::anyhow!("pake2 not received"));
457    }
458    let pake2_pb = pake2
459        .tlv
460        .get_octet_string(&[1])
461        .context("pake2 pb tlv missing")?;
462    ctx.y = p256::EncodedPoint::from_bytes(pake2_pb)?;
463
464    let pake2_cb = pake2
465        .tlv
466        .get_octet_string(&[2])
467        .context("pake2 cb tlv missing")?;
468
469    // send pake3
470    let mut hash_seed = "CHIP PAKE V1 Commissioning".as_bytes().to_vec();
471    hash_seed.extend_from_slice(&pbkdf_req_protocol_message[6..]);
472    hash_seed.extend_from_slice(&pbkdf_response.payload);
473    engine.finish(&mut ctx, &hash_seed, pake2_cb)?;
474    let pake3_protocol_message = messages::pake3(
475        exchange,
476        &ctx.ca.context("ca value not present in context")?,
477        -1,
478    )?;
479    log::debug!("send pake3 request");
480    retrctx.send(&pake3_protocol_message).await?;
481
482    let pake3_resp = retrctx.get_next_message().await?;
483    match &pake3_resp.status_report_info {
484        Some(s) => {
485            if !s.is_ok() {
486                return Err(anyhow::anyhow!("pake3 resp not ok), got {:?}", pake3_resp));
487            }
488        }
489        None => {
490            return Err(anyhow::anyhow!(
491                "expecting status report (pake3 resp), got {:?}",
492                pake3_resp
493            ))
494        }
495    }
496
497    session.set_encrypt_key(&ctx.encrypt_key.context("encrypt key missing")?);
498    session.set_decrypt_key(&ctx.decrypt_key.context("decrypt key missing")?);
499    session.session_id = p_session as u16;
500    log::debug!("auth_spake ok; session: {}", session.session_id);
501    Ok(session)
502}
503
504pub(crate) async fn auth_sigma(
505    connection: &dyn ConnectionTrait,
506    fabric: &fabric::Fabric,
507    cm: &dyn certmanager::CertManager,
508    node_id: u64,
509    controller_id: u64,
510) -> Result<session::Session> {
511    log::debug!("auth_sigma");
512    let exchange = rand::random();
513    let session = session::Session::new();
514    let mut retrctx = retransmit::RetrContext::new(connection, &session);
515    retrctx.subscribe_exchange(exchange);
516    let mut ctx = sigma::SigmaContext::new(node_id);
517    let ca_pubkey = cm.get_ca_key()?.public_key().to_sec1_bytes();
518    sigma::sigma1(fabric, &mut ctx, &ca_pubkey)?;
519    let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
520
521    log::debug!("send sigma1 {}", exchange);
522    retrctx.send(&s1).await?;
523
524    // receive sigma2
525    log::debug!("receive sigma2 {}", exchange);
526    let sigma2 = retrctx.get_next_message().await?;
527    log::debug!("sigma2 received {:?}", sigma2);
528    if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
529        && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
530    {
531        return Err(anyhow::anyhow!("sigma2 not received, status: {}", sigma2.status_report_info.context("status report info missing")?.to_string()));
532    }
533    ctx.sigma2_payload = sigma2.payload;
534    ctx.responder_session = sigma2
535        .tlv
536        .get_int(&[2])
537        .context("responder session tlv missing in sigma2")? as u16;
538    ctx.responder_public = sigma2
539        .tlv
540        .get_octet_string(&[3])
541        .context("responder public tlv missing in sigma2")?
542        .to_vec();
543
544    let controller_private = cm.get_user_key(controller_id)?;
545    let controller_x509 = cm.get_user_cert(controller_id)?;
546    let controller_matter_cert =
547        cert_matter::convert_x509_bytes_to_matter(&controller_x509, &ca_pubkey)?;
548
549    // send sigma3
550    log::debug!("send sigma3 {} with piggyback ack for {}", exchange, sigma2.message_header.message_counter);
551    sigma::sigma3(
552        fabric,
553        &mut ctx,
554        &controller_private.to_sec1_der()?,
555        &controller_matter_cert,
556    )?;
557    let sigma3 = messages::sigma3(exchange, &ctx.sigma3_payload, sigma2.message_header.message_counter)?;
558    retrctx.send(&sigma3).await?;
559
560    log::debug!("receive result {}", exchange);
561    let status = retrctx.get_next_message().await?;
562    if !status
563        .status_report_info
564        .context("sigma3 status resp not received")?
565        .is_ok()
566    {
567        return Err(anyhow::anyhow!(format!(
568            "response to sigma3 does not contain status ok {:?}",
569            status
570        )));
571    }
572
573    //session keys
574    let mut th = ctx.sigma1_payload.clone();
575    th.extend_from_slice(&ctx.sigma2_payload);
576
577    let mut transcript = th;
578    transcript.extend_from_slice(&ctx.sigma3_payload);
579    let transcript_hash = cryptoutil::sha256(&transcript);
580    let mut salt = fabric.signed_ipk()?;
581    salt.extend_from_slice(&transcript_hash);
582    let shared = ctx.shared.context("shared secret not in context")?;
583    let keypack = cryptoutil::hkdf_sha256(
584        &salt,
585        shared.raw_secret_bytes().as_slice(),
586        "SessionKeys".as_bytes(),
587        16 * 3,
588    )?;
589    let mut ses = session::Session::new();
590    ses.session_id = ctx.responder_session;
591    ses.my_session_id = ctx.session_id;
592    ses.set_decrypt_key(&keypack[16..32]);
593    ses.set_encrypt_key(&keypack[..16]);
594
595    let mut local_node = Vec::new();
596    local_node.write_u64::<LittleEndian>(controller_id)?;
597    ses.local_node = Some(local_node);
598
599    let mut remote_node = Vec::new();
600    remote_node.write_u64::<LittleEndian>(node_id)?;
601    ses.remote_node = Some(remote_node);
602
603    Ok(ses)
604}
605