matc/
controller.rs

1use std::sync::Arc;
2
3use crate::{
4    cert_matter, certmanager, commission, fabric,
5    messages::{self, Message},
6    retransmit, session, sigma, spake2p,
7    tlv::TlvItemValue,
8    transport,
9    util::cryptoutil,
10};
11use anyhow::{Context, Result};
12use byteorder::{LittleEndian, WriteBytesExt};
13
14pub struct Controller {
15    certmanager: Arc<dyn certmanager::CertManager>,
16    transport: Arc<transport::Transport>,
17    fabric: fabric::Fabric,
18}
19
20pub struct Connection {
21    connection: Arc<transport::Connection>,
22    session: session::Session,
23}
24//trait IsSync: Sync {}
25//impl IsSync for Controller {}
26
/// Certificate-authority identifier handed to `fabric::Fabric::new` when a
/// `Controller` builds its fabric.
const CA_ID: u64 = 1;
28
29impl Controller {
30    pub fn new(
31        certmanager: &Arc<dyn certmanager::CertManager>,
32        transport: &Arc<transport::Transport>,
33        fabric_id: u64,
34    ) -> Result<Arc<Self>> {
35        let fabric = fabric::Fabric::new(fabric_id, CA_ID, &certmanager.get_ca_public_key()?);
36        Ok(Arc::new(Self {
37            certmanager: certmanager.clone(),
38            transport: transport.clone(),
39            fabric,
40        }))
41    }
42
43    /// commission device
44    /// - authenticate using pin
45    /// - push CA certificate to device
46    /// - sign device's certificate
47    /// - set controller id - user which can control device
48    /// - return authenticated connection which can be used to send additional commands
49    pub async fn commission(
50        &self,
51        connection: &Arc<transport::Connection>,
52        pin: u32,
53        node_id: u64,
54        controller_id: u64,
55    ) -> Result<Connection> {
56        let mut session = auth_spake(connection, pin).await?;
57        let session = commission::commission(
58            connection,
59            &mut session,
60            &self.fabric,
61            self.certmanager.as_ref(),
62            node_id,
63            controller_id,
64        )
65        .await?;
66        Ok(Connection {
67            connection: connection.clone(),
68            session,
69        })
70    }
71
72    /// create authenticated connection to control device
73    pub async fn auth_sigma(
74        &self,
75        connection: &Arc<transport::Connection>,
76        node_id: u64,
77        controller_id: u64,
78    ) -> Result<Connection> {
79        let session = auth_sigma(
80            connection,
81            &self.fabric,
82            self.certmanager.as_ref(),
83            node_id,
84            controller_id,
85        )
86        .await?;
87        Ok(Connection {
88            connection: connection.clone(),
89            session,
90        })
91    }
92}
93
94/// Authenticated virtual connection can bse used to send commands to device.
95impl Connection {
96    /// Read attribute from device and return parsed matter protocol response.
97    pub async fn read_request(
98        &mut self,
99        endpoint: u16,
100        cluster: u32,
101        attr: u32,
102    ) -> Result<Message> {
103        read_request(&self.connection, &mut self.session, endpoint, cluster, attr).await
104    }
105
106    /// Read attribute from device and return tlv with attribute value.
107    pub async fn read_request2(
108        &mut self,
109        endpoint: u16,
110        cluster: u32,
111        attr: u32,
112    ) -> Result<TlvItemValue> {
113        let res =
114            read_request(&self.connection, &mut self.session, endpoint, cluster, attr).await?;
115        if (res.protocol_header.protocol_id
116            != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION)
117            || (res.protocol_header.opcode
118                != messages::ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA)
119        {
120            Err(anyhow::anyhow!(
121                "response is not expected report_data {:?}",
122                res.protocol_header
123            ))
124        } else {
125            match res.tlv.get(&[1, 0, 1, 2]) {
126                Some(a) => Ok(a.clone()),
127                None => {
128                    let s = res
129                        .tlv
130                        .get(&[1, 0, 0, 1, 0])
131                        .context("report data format not recognized1")?;
132                    if let TlvItemValue::Int(status) = s {
133                        Err(anyhow::anyhow!("report data with status {}", status))
134                    } else {
135                        Err(anyhow::anyhow!("report data format not recognized2"))
136                    }
137                }
138            }
139        }
140    }
141
142    /// Invoke command
143    pub async fn invoke_request(
144        &mut self,
145        endpoint: u16,
146        cluster: u32,
147        command: u32,
148        payload: &[u8],
149    ) -> Result<Message> {
150        invoke_request(
151            &self.connection,
152            &mut self.session,
153            endpoint,
154            cluster,
155            command,
156            payload,
157        )
158        .await
159    }
160
161    /// Invoke command
162    pub async fn invoke_request2(
163        &mut self,
164        endpoint: u16,
165        cluster: u32,
166        command: u32,
167        payload: &[u8],
168    ) -> Result<TlvItemValue> {
169        let res = invoke_request(
170            &self.connection,
171            &mut self.session,
172            endpoint,
173            cluster,
174            command,
175            payload,
176        )
177        .await?;
178        let o = res.tlv.get(&[1, 0, 1, 1]).context("result not found")?;
179        Ok(o.clone())
180    }
181
182    pub async fn invoke_request_timed(
183        &mut self,
184        endpoint: u16,
185        cluster: u32,
186        command: u32,
187        payload: &[u8],
188        timeout: u16,
189    ) -> Result<Message> {
190        invoke_request_timed(
191            &self.connection,
192            &mut self.session,
193            endpoint,
194            cluster,
195            command,
196            payload,
197            timeout,
198        )
199        .await
200    }
201}
202
203/*async fn get_next_message(
204    connection: &transport::Connection,
205    session: &mut session::Session,
206) -> Result<messages::Message> {
207    loop {
208        let resp = connection.receive(Duration::from_secs(3)).await?;
209        let resp = session.decode_message(&resp)?;
210        let decoded = messages::Message::decode(&resp)?;
211        if decoded.protocol_header.protocol_id
212            == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
213            && decoded.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_ACK
214        {
215            continue;
216        }
217        let ack = messages::ack(
218            decoded.protocol_header.exchange_id,
219            decoded.message_header.message_counter as i64,
220        )?;
221        let out = session.encode_message(&ack)?;
222        connection.send(&out).await?;
223        return Ok(decoded);
224    }
225}*/
226
227pub fn pin_to_passcode(pin: u32) -> Result<Vec<u8>> {
228    let mut out = Vec::new();
229    out.write_u32::<LittleEndian>(pin)?;
230    Ok(out)
231}
232
233async fn auth_spake(connection: &transport::Connection, pin: u32) -> Result<session::Session> {
234    let exchange = rand::random();
235    log::debug!("start auth_spake");
236    let mut session = session::Session::new();
237    let mut retrctx = retransmit::RetrContext::new(connection, &mut session);
238    // send pbkdf
239    log::debug!("send pbkdf request");
240    let pbkdf_req_protocol_message = messages::pbkdf_req(exchange)?;
241    retrctx.send(&pbkdf_req_protocol_message).await?;
242
243    // get pbkdf response
244    let pbkdf_response = retrctx.get_next_message().await?;
245    if pbkdf_response.protocol_header.protocol_id
246        != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
247        || pbkdf_response.protocol_header.opcode
248            != messages::ProtocolMessageHeader::OPCODE_PBKDF_RESP
249    {
250        return Err(anyhow::anyhow!("pbkdf response not received"));
251    }
252
253    let iterations = pbkdf_response
254        .tlv
255        .get_int(&[4, 1])
256        .context("pbkdf_response - iterations missing")?;
257    let salt = pbkdf_response
258        .tlv
259        .get_octet_string(&[4, 2])
260        .context("pbkdf_response - salt missing")?;
261    let p_session = pbkdf_response
262        .tlv
263        .get_int(&[3])
264        .context("pbkdf_response - session missing")?;
265
266    // send pake1
267    let engine = spake2p::Engine::new()?;
268    let mut ctx = engine.start(&pin_to_passcode(pin)?, salt, iterations as u32)?;
269    log::debug!("send pake1 request");
270    let pake1_protocol_message = messages::pake1(exchange, ctx.x.as_bytes(), -1)?;
271    retrctx.send(&pake1_protocol_message).await?;
272
273    // receive pake2
274    let pake2 = retrctx.get_next_message().await?;
275    if pake2.protocol_header.protocol_id
276        != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
277        || pake2.protocol_header.opcode != messages::ProtocolMessageHeader::OPCODE_PASE_PAKE2
278    {
279        return Err(anyhow::anyhow!("pake2 not received"));
280    }
281    let pake2_pb = pake2
282        .tlv
283        .get_octet_string(&[1])
284        .context("pake2 pb tlv missing")?;
285    ctx.y = p256::EncodedPoint::from_bytes(pake2_pb)?;
286
287    // send pake3
288    let mut hash_seed = "CHIP PAKE V1 Commissioning".as_bytes().to_vec();
289    hash_seed.extend_from_slice(&pbkdf_req_protocol_message[6..]);
290    hash_seed.extend_from_slice(&pbkdf_response.payload);
291    engine.finish(&mut ctx, &hash_seed)?;
292    let pake3_protocol_message = messages::pake3(
293        exchange,
294        &ctx.ca.context("ca value not present in context")?,
295        -1,
296    )?;
297    log::debug!("send pake3 request");
298    retrctx.send(&pake3_protocol_message).await?;
299
300    let pake3_resp = retrctx.get_next_message().await?;
301    match &pake3_resp.status_report_info {
302        Some(s) => {
303            if !s.is_ok() {
304                return Err(anyhow::anyhow!("pake3 resp not ok), got {:?}", pake3_resp));
305            }
306        }
307        None => {
308            return Err(anyhow::anyhow!(
309                "expecting status report (pake3 resp), got {:?}",
310                pake3_resp
311            ))
312        }
313    }
314
315    session.set_encrypt_key(&ctx.encrypt_key.context("encrypt key missing")?);
316    session.set_decrypt_key(&ctx.decrypt_key.context("decrypt key missing")?);
317    session.session_id = p_session as u16;
318    log::debug!("auth_spake ok; session: {}", session.session_id);
319    Ok(session)
320}
321
322pub(crate) async fn auth_sigma(
323    connection: &transport::Connection,
324    fabric: &fabric::Fabric,
325    cm: &dyn certmanager::CertManager,
326    node_id: u64,
327    controller_id: u64,
328) -> Result<session::Session> {
329    log::debug!("auth_sigma");
330    let exchange = rand::random();
331    let mut session = session::Session::new();
332    let mut retrctx = retransmit::RetrContext::new(connection, &mut session);
333    retrctx.subscribe_exchange(exchange);
334    let mut ctx = sigma::SigmaContext::new(node_id);
335    let ca_pubkey = cm.get_ca_key()?.public_key().to_sec1_bytes();
336    sigma::sigma1(fabric, &mut ctx, &ca_pubkey)?;
337    let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
338
339    log::debug!("send sigma1 {}", exchange);
340    retrctx.send(&s1).await?;
341
342    // receive sigma2
343    log::debug!("receive sigma2 {}", exchange);
344    let sigma2 = retrctx.get_next_message().await?;
345    log::debug!("sigma2 received {:?}", sigma2);
346    if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
347        && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
348    {
349        return Err(anyhow::anyhow!("sigma2 not received, status: {}", sigma2.status_report_info.context("status report info missing")?.to_string()));
350    }
351    ctx.sigma2_payload = sigma2.payload;
352    ctx.responder_session = sigma2
353        .tlv
354        .get_int(&[2])
355        .context("responder session tlv missing in sigma2")? as u16;
356    ctx.responder_public = sigma2
357        .tlv
358        .get_octet_string(&[3])
359        .context("responder public tlv missing in sigma2")?
360        .to_vec();
361
362    let controller_private = cm.get_user_key(controller_id)?;
363    let controller_x509 = cm.get_user_cert(controller_id)?;
364    let controller_matter_cert =
365        cert_matter::convert_x509_bytes_to_matter(&controller_x509, &ca_pubkey)?;
366
367    // send sigma3
368    log::debug!("send sigma3 {}", exchange);
369    sigma::sigma3(
370        fabric,
371        &mut ctx,
372        &controller_private.to_sec1_der()?,
373        &controller_matter_cert,
374    )?;
375    let sigma3 = messages::sigma3(exchange, &ctx.sigma3_payload)?;
376    retrctx.send(&sigma3).await?;
377
378    log::debug!("receive result {}", exchange);
379    let status = retrctx.get_next_message().await?;
380    if !status
381        .status_report_info
382        .context("sigma3 status resp not received")?
383        .is_ok()
384    {
385        return Err(anyhow::anyhow!(format!(
386            "response to sigma3 does not contain status ok {:?}",
387            status
388        )));
389    }
390
391    //session keys
392    let mut th = ctx.sigma1_payload.clone();
393    th.extend_from_slice(&ctx.sigma2_payload);
394
395    let mut transcript = th;
396    transcript.extend_from_slice(&ctx.sigma3_payload);
397    let transcript_hash = cryptoutil::sha256(&transcript);
398    let mut salt = fabric.signed_ipk()?;
399    salt.extend_from_slice(&transcript_hash);
400    let shared = ctx.shared.context("shared secret not in context")?;
401    let keypack = cryptoutil::hkdf_sha256(
402        &salt,
403        shared.raw_secret_bytes().as_slice(),
404        "SessionKeys".as_bytes(),
405        16 * 3,
406    )?;
407    let mut ses = session::Session::new();
408    ses.session_id = ctx.responder_session;
409    ses.set_decrypt_key(&keypack[16..32]);
410    ses.set_encrypt_key(&keypack[..16]);
411
412    let mut local_node = Vec::new();
413    local_node.write_u64::<LittleEndian>(controller_id)?;
414    ses.local_node = Some(local_node);
415
416    let mut remote_node = Vec::new();
417    remote_node.write_u64::<LittleEndian>(node_id)?;
418    ses.remote_node = Some(remote_node);
419
420    Ok(ses)
421}
422
423async fn read_request(
424    connection: &transport::Connection,
425    session: &mut session::Session,
426    endpoint: u16,
427    cluster: u32,
428    attr: u32,
429) -> Result<Message> {
430    let exchange = rand::random();
431    let mut retrctx = retransmit::RetrContext::new(connection, session);
432    let testm = messages::im_read_request(endpoint, cluster, attr, exchange)?;
433    retrctx.send(&testm).await?;
434    let result = retrctx.get_next_message().await?;
435    Ok(result)
436}
437
438async fn invoke_request(
439    connection: &transport::Connection,
440    session: &mut session::Session,
441    endpoint: u16,
442    cluster: u32,
443    command: u32,
444    payload: &[u8],
445) -> Result<Message> {
446    let exchange = rand::random();
447    let mut retrctx = retransmit::RetrContext::new(connection, session);
448    retrctx.subscribe_exchange(exchange);
449    log::debug!(
450        "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
451        exchange,
452        endpoint,
453        cluster,
454        command
455    );
456    let testm = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, false)?;
457    retrctx.send(&testm).await?;
458    let result = retrctx.get_next_message().await?;
459    Ok(result)
460}
461
462async fn invoke_request_timed(
463    connection: &transport::Connection,
464    session: &mut session::Session,
465    endpoint: u16,
466    cluster: u32,
467    command: u32,
468    payload: &[u8],
469    timeout: u16,
470) -> Result<Message> {
471    let exchange = rand::random();
472    let mut retrctx = retransmit::RetrContext::new(connection, session);
473    retrctx.subscribe_exchange(exchange);
474    let tr = messages::im_timed_request(exchange, timeout)?;
475    retrctx.send(&tr).await?;
476    let result = retrctx.get_next_message().await?;
477    if result.protocol_header.protocol_id
478        != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
479        || result.protocol_header.opcode
480            != messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
481    {
482        return Err(anyhow::anyhow!(
483            "invoke_request_timed: unexpected response {:?}",
484            result
485        ));
486    }
487    let status = result
488        .tlv
489        .get_int(&[0])
490        .context("invoke_request_timed: status not found")?;
491    if status != 0 {
492        return Err(anyhow::anyhow!(
493            "invoke_request_timed: unexpected status {}",
494            status
495        ));
496    }
497    log::debug!(
498        "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
499        exchange,
500        endpoint,
501        cluster,
502        command
503    );
504    let testm = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, true)?;
505    retrctx.send(&testm).await?;
506    let result = retrctx.get_next_message().await?;
507    Ok(result)
508}