matc/
controller.rs

1use std::sync::Arc;
2
3use crate::{
4    active_connection::ActiveConnection,
5    cert_matter, certmanager, commission, fabric,
6    messages::{self, Message},
7    retransmit, session, sigma, spake2p,
8    tlv::TlvItemValue,
9    transport,
10    util::cryptoutil,
11};
12use anyhow::{Context, Result};
13use byteorder::{LittleEndian, WriteBytesExt};
14
15pub struct Controller {
16    certmanager: Arc<dyn certmanager::CertManager>,
17    #[allow(dead_code)]
18    transport: Arc<transport::Transport>,
19    fabric: fabric::Fabric,
20}
21
22pub struct Connection {
23    active: ActiveConnection,
24}
25//trait IsSync: Sync {}
26//impl IsSync for Controller {}
27
/// Certificate-authority identifier that [`Controller::new`] passes to `fabric::Fabric::new`.
const CA_ID: u64 = 1;
29
30impl Controller {
31    pub fn new(
32        certmanager: &Arc<dyn certmanager::CertManager>,
33        transport: &Arc<transport::Transport>,
34        fabric_id: u64,
35    ) -> Result<Arc<Self>> {
36        let fabric = fabric::Fabric::new(fabric_id, CA_ID, &certmanager.get_ca_public_key()?);
37        Ok(Arc::new(Self {
38            certmanager: certmanager.clone(),
39            transport: transport.clone(),
40            fabric,
41        }))
42    }
43
44    /// commission device
45    /// - authenticate using pin
46    /// - push CA certificate to device
47    /// - sign device's certificate
48    /// - set controller id - user which can control device
49    /// - return authenticated connection which can be used to send additional commands
50    pub async fn commission(
51        &self,
52        connection: &Arc<transport::Connection>,
53        pin: u32,
54        node_id: u64,
55        controller_id: u64,
56    ) -> Result<Connection> {
57        let mut session = auth_spake(connection, pin).await?;
58        let session = commission::commission(
59            connection,
60            &mut session,
61            &self.fabric,
62            self.certmanager.as_ref(),
63            node_id,
64            controller_id,
65        )
66        .await?;
67        Ok(Connection {
68            active: ActiveConnection::new(connection.clone(), session),
69        })
70    }
71
72    /// create authenticated connection to control device
73    pub async fn auth_sigma(
74        &self,
75        connection: &Arc<transport::Connection>,
76        node_id: u64,
77        controller_id: u64,
78    ) -> Result<Connection> {
79        let session = auth_sigma(
80            connection,
81            &self.fabric,
82            self.certmanager.as_ref(),
83            node_id,
84            controller_id,
85        )
86        .await?;
87        Ok(Connection {
88            active: ActiveConnection::new(connection.clone(), session),
89        })
90    }
91}
92
93/// Authenticated virtual connection can be used to send commands to device.
94impl Connection {
95    /// Read attribute from device and return parsed matter protocol response.
96    pub async fn read_request(
97        &self,
98        endpoint: u16,
99        cluster: u32,
100        attr: u32,
101    ) -> Result<Message> {
102        let exchange: u16 = rand::random();
103        let msg = messages::im_read_request(endpoint, cluster, attr, exchange)?;
104        self.active.request(exchange, &msg).await
105    }
106
107    /// Read attribute from device and return tlv with attribute value.
108    pub async fn read_request2(
109        &self,
110        endpoint: u16,
111        cluster: u32,
112        attr: u32,
113    ) -> Result<TlvItemValue> {
114        let res = self.read_request(endpoint, cluster, attr).await?;
115        if (res.protocol_header.protocol_id
116            != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION)
117            || (res.protocol_header.opcode
118                != messages::ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA)
119        {
120            Err(anyhow::anyhow!(
121                "response is not expected report_data {:?}",
122                res.protocol_header
123            ))
124        } else {
125            match res.tlv.get(&[1, 0, 1, 2]) {
126                Some(a) => Ok(a.clone()),
127                None => {
128                    let s = res
129                        .tlv
130                        .get(&[1, 0, 0, 1, 0])
131                        .context("report data format not recognized1")?;
132                    if let TlvItemValue::Int(status) = s {
133                        Err(anyhow::anyhow!("report data with status {}", status))
134                    } else {
135                        Err(anyhow::anyhow!("report data format not recognized2"))
136                    }
137                }
138            }
139        }
140    }
141
142    /// Invoke command
143    pub async fn invoke_request(
144        &self,
145        endpoint: u16,
146        cluster: u32,
147        command: u32,
148        payload: &[u8],
149    ) -> Result<Message> {
150        let exchange: u16 = rand::random();
151        log::debug!(
152            "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
153            exchange,
154            endpoint,
155            cluster,
156            command
157        );
158        let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, false)?;
159        self.active.request(exchange, &msg).await
160    }
161
162    /// Invoke command and return result TLV
163    pub async fn invoke_request2(
164        &self,
165        endpoint: u16,
166        cluster: u32,
167        command: u32,
168        payload: &[u8],
169    ) -> Result<TlvItemValue> {
170        let res = self.invoke_request(endpoint, cluster, command, payload).await?;
171        let o = res.tlv.get(&[1, 0, 1, 1]).context("result not found")?;
172        Ok(o.clone())
173    }
174
175    pub async fn im_subscribe_request(
176        &self,
177        endpoint: u16,
178        cluster: u32,
179        event: u32,
180    ) -> Result<Message> {
181        let exchange: u16 = rand::random();
182        log::debug!(
183            "im_subscribe_request exch:{} endpoint:{} cluster:{} event:{}",
184            exchange,
185            endpoint,
186            cluster,
187            event
188        );
189        let msg = messages::im_subscribe_request(endpoint, cluster, exchange, event)?;
190        self.active.request(exchange, &msg).await
191    }
192
193    pub async fn im_status_response(
194        &self,
195        exchange: u16,
196        flags: u8,
197        ack: u32
198    ) -> Result<()> {
199        let msg = messages::im_status_response(exchange, flags, ack)?;
200        self.active.send(&msg).await
201    }
202
203    /// Invoke command with timed interaction
204    pub async fn invoke_request_timed(
205        &self,
206        endpoint: u16,
207        cluster: u32,
208        command: u32,
209        payload: &[u8],
210        timeout: u16,
211    ) -> Result<Message> {
212        let exchange: u16 = rand::random();
213
214        // Send timed request first
215        let tr = messages::im_timed_request(exchange, timeout)?;
216        let result = self.active.request(exchange, &tr).await?;
217
218        if result.protocol_header.protocol_id
219            != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
220            || result.protocol_header.opcode
221                != messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
222        {
223            return Err(anyhow::anyhow!(
224                "invoke_request_timed: unexpected response {:?}",
225                result
226            ));
227        }
228        let status = result
229            .tlv
230            .get_int(&[0])
231            .context("invoke_request_timed: status not found")?;
232        if status != 0 {
233            return Err(anyhow::anyhow!(
234                "invoke_request_timed: unexpected status {}",
235                status
236            ));
237        }
238
239        log::debug!(
240            "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
241            exchange,
242            endpoint,
243            cluster,
244            command
245        );
246        let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, true)?;
247        self.active.request(exchange, &msg).await
248    }
249
250    /// Receive next event (for subscriptions). Returns None when connection is closed.
251    pub async fn recv_event(&self) -> Option<Message> {
252        self.active.recv_event().await
253    }
254
255    /// Try receive event without blocking.
256    pub fn try_recv_event(&self) -> Option<Message> {
257        self.active.try_recv_event()
258    }
259}
260
261pub fn pin_to_passcode(pin: u32) -> Result<Vec<u8>> {
262    let mut out = Vec::new();
263    out.write_u32::<LittleEndian>(pin)?;
264    Ok(out)
265}
266
267async fn auth_spake(connection: &transport::Connection, pin: u32) -> Result<session::Session> {
268    let exchange = rand::random();
269    log::debug!("start auth_spake");
270    let mut session = session::Session::new();
271    session.my_session_id = 1;
272    let mut retrctx = retransmit::RetrContext::new(connection, &session);
273    // send pbkdf
274    log::debug!("send pbkdf request");
275    let pbkdf_req_protocol_message = messages::pbkdf_req(exchange)?;
276    retrctx.send(&pbkdf_req_protocol_message).await?;
277
278    // get pbkdf response
279    let pbkdf_response = retrctx.get_next_message().await?;
280    if pbkdf_response.protocol_header.protocol_id
281        != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
282        || pbkdf_response.protocol_header.opcode
283            != messages::ProtocolMessageHeader::OPCODE_PBKDF_RESP
284    {
285        return Err(anyhow::anyhow!("pbkdf response not received"));
286    }
287
288    let iterations = pbkdf_response
289        .tlv
290        .get_int(&[4, 1])
291        .context("pbkdf_response - iterations missing")?;
292    let salt = pbkdf_response
293        .tlv
294        .get_octet_string(&[4, 2])
295        .context("pbkdf_response - salt missing")?;
296    let p_session = pbkdf_response
297        .tlv
298        .get_int(&[3])
299        .context("pbkdf_response - session missing")?;
300
301    // send pake1
302    let engine = spake2p::Engine::new()?;
303    let mut ctx = engine.start(&pin_to_passcode(pin)?, salt, iterations as u32)?;
304    log::debug!("send pake1 request");
305    let pake1_protocol_message = messages::pake1(exchange, ctx.x.as_bytes(), -1)?;
306    retrctx.send(&pake1_protocol_message).await?;
307
308    // receive pake2
309    let pake2 = retrctx.get_next_message().await?;
310    if pake2.protocol_header.protocol_id
311        != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
312        || pake2.protocol_header.opcode != messages::ProtocolMessageHeader::OPCODE_PASE_PAKE2
313    {
314        return Err(anyhow::anyhow!("pake2 not received"));
315    }
316    let pake2_pb = pake2
317        .tlv
318        .get_octet_string(&[1])
319        .context("pake2 pb tlv missing")?;
320    ctx.y = p256::EncodedPoint::from_bytes(pake2_pb)?;
321
322    // send pake3
323    let mut hash_seed = "CHIP PAKE V1 Commissioning".as_bytes().to_vec();
324    hash_seed.extend_from_slice(&pbkdf_req_protocol_message[6..]);
325    hash_seed.extend_from_slice(&pbkdf_response.payload);
326    engine.finish(&mut ctx, &hash_seed)?;
327    let pake3_protocol_message = messages::pake3(
328        exchange,
329        &ctx.ca.context("ca value not present in context")?,
330        -1,
331    )?;
332    log::debug!("send pake3 request");
333    retrctx.send(&pake3_protocol_message).await?;
334
335    let pake3_resp = retrctx.get_next_message().await?;
336    match &pake3_resp.status_report_info {
337        Some(s) => {
338            if !s.is_ok() {
339                return Err(anyhow::anyhow!("pake3 resp not ok), got {:?}", pake3_resp));
340            }
341        }
342        None => {
343            return Err(anyhow::anyhow!(
344                "expecting status report (pake3 resp), got {:?}",
345                pake3_resp
346            ))
347        }
348    }
349
350    session.set_encrypt_key(&ctx.encrypt_key.context("encrypt key missing")?);
351    session.set_decrypt_key(&ctx.decrypt_key.context("decrypt key missing")?);
352    session.session_id = p_session as u16;
353    log::debug!("auth_spake ok; session: {}", session.session_id);
354    Ok(session)
355}
356
357pub(crate) async fn auth_sigma(
358    connection: &transport::Connection,
359    fabric: &fabric::Fabric,
360    cm: &dyn certmanager::CertManager,
361    node_id: u64,
362    controller_id: u64,
363) -> Result<session::Session> {
364    log::debug!("auth_sigma");
365    let exchange = rand::random();
366    let mut session = session::Session::new();
367    let mut retrctx = retransmit::RetrContext::new(connection, &mut session);
368    retrctx.subscribe_exchange(exchange);
369    let mut ctx = sigma::SigmaContext::new(node_id);
370    let ca_pubkey = cm.get_ca_key()?.public_key().to_sec1_bytes();
371    sigma::sigma1(fabric, &mut ctx, &ca_pubkey)?;
372    let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
373
374    log::debug!("send sigma1 {}", exchange);
375    retrctx.send(&s1).await?;
376
377    // receive sigma2
378    log::debug!("receive sigma2 {}", exchange);
379    let sigma2 = retrctx.get_next_message().await?;
380    log::debug!("sigma2 received {:?}", sigma2);
381    if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
382        && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
383    {
384        return Err(anyhow::anyhow!("sigma2 not received, status: {}", sigma2.status_report_info.context("status report info missing")?.to_string()));
385    }
386    ctx.sigma2_payload = sigma2.payload;
387    ctx.responder_session = sigma2
388        .tlv
389        .get_int(&[2])
390        .context("responder session tlv missing in sigma2")? as u16;
391    ctx.responder_public = sigma2
392        .tlv
393        .get_octet_string(&[3])
394        .context("responder public tlv missing in sigma2")?
395        .to_vec();
396
397    let controller_private = cm.get_user_key(controller_id)?;
398    let controller_x509 = cm.get_user_cert(controller_id)?;
399    let controller_matter_cert =
400        cert_matter::convert_x509_bytes_to_matter(&controller_x509, &ca_pubkey)?;
401
402    // send sigma3
403    log::debug!("send sigma3 {} with piggyback ack for {}", exchange, sigma2.message_header.message_counter);
404    sigma::sigma3(
405        fabric,
406        &mut ctx,
407        &controller_private.to_sec1_der()?,
408        &controller_matter_cert,
409    )?;
410    let sigma3 = messages::sigma3(exchange, &ctx.sigma3_payload, sigma2.message_header.message_counter)?;
411    retrctx.send(&sigma3).await?;
412
413    log::debug!("receive result {}", exchange);
414    let status = retrctx.get_next_message().await?;
415    if !status
416        .status_report_info
417        .context("sigma3 status resp not received")?
418        .is_ok()
419    {
420        return Err(anyhow::anyhow!(format!(
421            "response to sigma3 does not contain status ok {:?}",
422            status
423        )));
424    }
425
426    //session keys
427    let mut th = ctx.sigma1_payload.clone();
428    th.extend_from_slice(&ctx.sigma2_payload);
429
430    let mut transcript = th;
431    transcript.extend_from_slice(&ctx.sigma3_payload);
432    let transcript_hash = cryptoutil::sha256(&transcript);
433    let mut salt = fabric.signed_ipk()?;
434    salt.extend_from_slice(&transcript_hash);
435    let shared = ctx.shared.context("shared secret not in context")?;
436    let keypack = cryptoutil::hkdf_sha256(
437        &salt,
438        shared.raw_secret_bytes().as_slice(),
439        "SessionKeys".as_bytes(),
440        16 * 3,
441    )?;
442    let mut ses = session::Session::new();
443    ses.session_id = ctx.responder_session;
444    ses.my_session_id = ctx.session_id;
445    ses.set_decrypt_key(&keypack[16..32]);
446    ses.set_encrypt_key(&keypack[..16]);
447
448    let mut local_node = Vec::new();
449    local_node.write_u64::<LittleEndian>(controller_id)?;
450    ses.local_node = Some(local_node);
451
452    let mut remote_node = Vec::new();
453    remote_node.write_u64::<LittleEndian>(node_id)?;
454    ses.remote_node = Some(remote_node);
455
456    Ok(ses)
457}
458