use std::convert::TryFrom;
use std::sync::Arc;

use crate::{
    cert_matter, certmanager, commission, fabric,
    messages::{self, Message},
    retransmit, session, sigma, spake2p,
    tlv::TlvItemValue,
    transport,
    util::cryptoutil,
};
use anyhow::{Context, Result};
use byteorder::{LittleEndian, WriteBytesExt};
13
/// State owned by a Matter controller: its credentials source, a transport
/// handle and the fabric it commissions devices into / authenticates on.
pub struct Controller {
    /// Supplies the CA key and the per-controller keys/certificates used by
    /// commissioning and by sigma (CASE) authentication.
    certmanager: Arc<dyn certmanager::CertManager>,
    /// Shared transport handle stored at construction.
    /// NOTE(review): not read by any code visible here — confirm it is needed.
    transport: Arc<transport::Transport>,
    /// Fabric built in `Controller::new` from the fabric id, `CA_ID` and the
    /// CA public key reported by `certmanager`.
    fabric: fabric::Fabric,
}
19
/// A secured session with one device, bound to the transport connection it
/// was established over.
///
/// Produced by `Controller::commission` (SPAKE2+/PASE followed by
/// commissioning) or `Controller::auth_sigma` (sigma/CASE), and used to issue
/// interaction-model read and invoke requests.
pub struct Connection {
    /// Transport-level connection to the device.
    connection: Arc<transport::Connection>,
    /// Session state (keys, session id) used for messages on `connection`.
    session: session::Session,
}
/// CA identifier passed as the second argument to `fabric::Fabric::new` when
/// `Controller::new` builds the controller's fabric.
const CA_ID: u64 = 1;
28
29impl Controller {
30 pub fn new(
31 certmanager: &Arc<dyn certmanager::CertManager>,
32 transport: &Arc<transport::Transport>,
33 fabric_id: u64,
34 ) -> Result<Arc<Self>> {
35 let fabric = fabric::Fabric::new(fabric_id, CA_ID, &certmanager.get_ca_public_key()?);
36 Ok(Arc::new(Self {
37 certmanager: certmanager.clone(),
38 transport: transport.clone(),
39 fabric,
40 }))
41 }
42
43 pub async fn commission(
50 &self,
51 connection: &Arc<transport::Connection>,
52 pin: u32,
53 node_id: u64,
54 controller_id: u64,
55 ) -> Result<Connection> {
56 let mut session = auth_spake(connection, pin).await?;
57 let session = commission::commission(
58 connection,
59 &mut session,
60 &self.fabric,
61 self.certmanager.as_ref(),
62 node_id,
63 controller_id,
64 )
65 .await?;
66 Ok(Connection {
67 connection: connection.clone(),
68 session,
69 })
70 }
71
72 pub async fn auth_sigma(
74 &self,
75 connection: &Arc<transport::Connection>,
76 node_id: u64,
77 controller_id: u64,
78 ) -> Result<Connection> {
79 let session = auth_sigma(
80 connection,
81 &self.fabric,
82 self.certmanager.as_ref(),
83 node_id,
84 controller_id,
85 )
86 .await?;
87 Ok(Connection {
88 connection: connection.clone(),
89 session,
90 })
91 }
92}
93
94impl Connection {
96 pub async fn read_request(
98 &mut self,
99 endpoint: u16,
100 cluster: u32,
101 attr: u32,
102 ) -> Result<Message> {
103 read_request(&self.connection, &mut self.session, endpoint, cluster, attr).await
104 }
105
106 pub async fn read_request2(
108 &mut self,
109 endpoint: u16,
110 cluster: u32,
111 attr: u32,
112 ) -> Result<TlvItemValue> {
113 let res =
114 read_request(&self.connection, &mut self.session, endpoint, cluster, attr).await?;
115 if (res.protocol_header.protocol_id
116 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION)
117 || (res.protocol_header.opcode
118 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA)
119 {
120 Err(anyhow::anyhow!(
121 "response is not expected report_data {:?}",
122 res.protocol_header
123 ))
124 } else {
125 match res.tlv.get(&[1, 0, 1, 2]) {
126 Some(a) => Ok(a.clone()),
127 None => {
128 let s = res
129 .tlv
130 .get(&[1, 0, 0, 1, 0])
131 .context("report data format not recognized1")?;
132 if let TlvItemValue::Int(status) = s {
133 Err(anyhow::anyhow!("report data with status {}", status))
134 } else {
135 Err(anyhow::anyhow!("report data format not recognized2"))
136 }
137 }
138 }
139 }
140 }
141
142 pub async fn invoke_request(
144 &mut self,
145 endpoint: u16,
146 cluster: u32,
147 command: u32,
148 payload: &[u8],
149 ) -> Result<Message> {
150 invoke_request(
151 &self.connection,
152 &mut self.session,
153 endpoint,
154 cluster,
155 command,
156 payload,
157 )
158 .await
159 }
160
161 pub async fn invoke_request2(
163 &mut self,
164 endpoint: u16,
165 cluster: u32,
166 command: u32,
167 payload: &[u8],
168 ) -> Result<TlvItemValue> {
169 let res = invoke_request(
170 &self.connection,
171 &mut self.session,
172 endpoint,
173 cluster,
174 command,
175 payload,
176 )
177 .await?;
178 let o = res.tlv.get(&[1, 0, 1, 1]).context("result not found")?;
179 Ok(o.clone())
180 }
181
182 pub async fn invoke_request_timed(
183 &mut self,
184 endpoint: u16,
185 cluster: u32,
186 command: u32,
187 payload: &[u8],
188 timeout: u16,
189 ) -> Result<Message> {
190 invoke_request_timed(
191 &self.connection,
192 &mut self.session,
193 endpoint,
194 cluster,
195 command,
196 payload,
197 timeout,
198 )
199 .await
200 }
201}
202
203pub fn pin_to_passcode(pin: u32) -> Result<Vec<u8>> {
228 let mut out = Vec::new();
229 out.write_u32::<LittleEndian>(pin)?;
230 Ok(out)
231}
232
233async fn auth_spake(connection: &transport::Connection, pin: u32) -> Result<session::Session> {
234 let exchange = rand::random();
235 log::debug!("start auth_spake");
236 let mut session = session::Session::new();
237 let mut retrctx = retransmit::RetrContext::new(connection, &mut session);
238 log::debug!("send pbkdf request");
240 let pbkdf_req_protocol_message = messages::pbkdf_req(exchange)?;
241 retrctx.send(&pbkdf_req_protocol_message).await?;
242
243 let pbkdf_response = retrctx.get_next_message().await?;
245 if pbkdf_response.protocol_header.protocol_id
246 != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
247 || pbkdf_response.protocol_header.opcode
248 != messages::ProtocolMessageHeader::OPCODE_PBKDF_RESP
249 {
250 return Err(anyhow::anyhow!("pbkdf response not received"));
251 }
252
253 let iterations = pbkdf_response
254 .tlv
255 .get_int(&[4, 1])
256 .context("pbkdf_response - iterations missing")?;
257 let salt = pbkdf_response
258 .tlv
259 .get_octet_string(&[4, 2])
260 .context("pbkdf_response - salt missing")?;
261 let p_session = pbkdf_response
262 .tlv
263 .get_int(&[3])
264 .context("pbkdf_response - session missing")?;
265
266 let engine = spake2p::Engine::new()?;
268 let mut ctx = engine.start(&pin_to_passcode(pin)?, salt, iterations as u32)?;
269 log::debug!("send pake1 request");
270 let pake1_protocol_message = messages::pake1(exchange, ctx.x.as_bytes(), -1)?;
271 retrctx.send(&pake1_protocol_message).await?;
272
273 let pake2 = retrctx.get_next_message().await?;
275 if pake2.protocol_header.protocol_id
276 != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
277 || pake2.protocol_header.opcode != messages::ProtocolMessageHeader::OPCODE_PASE_PAKE2
278 {
279 return Err(anyhow::anyhow!("pake2 not received"));
280 }
281 let pake2_pb = pake2
282 .tlv
283 .get_octet_string(&[1])
284 .context("pake2 pb tlv missing")?;
285 ctx.y = p256::EncodedPoint::from_bytes(pake2_pb)?;
286
287 let mut hash_seed = "CHIP PAKE V1 Commissioning".as_bytes().to_vec();
289 hash_seed.extend_from_slice(&pbkdf_req_protocol_message[6..]);
290 hash_seed.extend_from_slice(&pbkdf_response.payload);
291 engine.finish(&mut ctx, &hash_seed)?;
292 let pake3_protocol_message = messages::pake3(
293 exchange,
294 &ctx.ca.context("ca value not present in context")?,
295 -1,
296 )?;
297 log::debug!("send pake3 request");
298 retrctx.send(&pake3_protocol_message).await?;
299
300 let pake3_resp = retrctx.get_next_message().await?;
301 match &pake3_resp.status_report_info {
302 Some(s) => {
303 if !s.is_ok() {
304 return Err(anyhow::anyhow!("pake3 resp not ok), got {:?}", pake3_resp));
305 }
306 }
307 None => {
308 return Err(anyhow::anyhow!(
309 "expecting status report (pake3 resp), got {:?}",
310 pake3_resp
311 ))
312 }
313 }
314
315 session.set_encrypt_key(&ctx.encrypt_key.context("encrypt key missing")?);
316 session.set_decrypt_key(&ctx.decrypt_key.context("decrypt key missing")?);
317 session.session_id = p_session as u16;
318 log::debug!("auth_spake ok; session: {}", session.session_id);
319 Ok(session)
320}
321
322pub(crate) async fn auth_sigma(
323 connection: &transport::Connection,
324 fabric: &fabric::Fabric,
325 cm: &dyn certmanager::CertManager,
326 node_id: u64,
327 controller_id: u64,
328) -> Result<session::Session> {
329 log::debug!("auth_sigma");
330 let exchange = rand::random();
331 let mut session = session::Session::new();
332 let mut retrctx = retransmit::RetrContext::new(connection, &mut session);
333 retrctx.subscribe_exchange(exchange);
334 let mut ctx = sigma::SigmaContext::new(node_id);
335 let ca_pubkey = cm.get_ca_key()?.public_key().to_sec1_bytes();
336 sigma::sigma1(fabric, &mut ctx, &ca_pubkey)?;
337 let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
338
339 log::debug!("send sigma1 {}", exchange);
340 retrctx.send(&s1).await?;
341
342 log::debug!("receive sigma2 {}", exchange);
344 let sigma2 = retrctx.get_next_message().await?;
345 log::debug!("sigma2 received {:?}", sigma2);
346 if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
347 && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
348 {
349 return Err(anyhow::anyhow!("sigma2 not received, status: {}", sigma2.status_report_info.context("status report info missing")?.to_string()));
350 }
351 ctx.sigma2_payload = sigma2.payload;
352 ctx.responder_session = sigma2
353 .tlv
354 .get_int(&[2])
355 .context("responder session tlv missing in sigma2")? as u16;
356 ctx.responder_public = sigma2
357 .tlv
358 .get_octet_string(&[3])
359 .context("responder public tlv missing in sigma2")?
360 .to_vec();
361
362 let controller_private = cm.get_user_key(controller_id)?;
363 let controller_x509 = cm.get_user_cert(controller_id)?;
364 let controller_matter_cert =
365 cert_matter::convert_x509_bytes_to_matter(&controller_x509, &ca_pubkey)?;
366
367 log::debug!("send sigma3 {}", exchange);
369 sigma::sigma3(
370 fabric,
371 &mut ctx,
372 &controller_private.to_sec1_der()?,
373 &controller_matter_cert,
374 )?;
375 let sigma3 = messages::sigma3(exchange, &ctx.sigma3_payload)?;
376 retrctx.send(&sigma3).await?;
377
378 log::debug!("receive result {}", exchange);
379 let status = retrctx.get_next_message().await?;
380 if !status
381 .status_report_info
382 .context("sigma3 status resp not received")?
383 .is_ok()
384 {
385 return Err(anyhow::anyhow!(format!(
386 "response to sigma3 does not contain status ok {:?}",
387 status
388 )));
389 }
390
391 let mut th = ctx.sigma1_payload.clone();
393 th.extend_from_slice(&ctx.sigma2_payload);
394
395 let mut transcript = th;
396 transcript.extend_from_slice(&ctx.sigma3_payload);
397 let transcript_hash = cryptoutil::sha256(&transcript);
398 let mut salt = fabric.signed_ipk()?;
399 salt.extend_from_slice(&transcript_hash);
400 let shared = ctx.shared.context("shared secret not in context")?;
401 let keypack = cryptoutil::hkdf_sha256(
402 &salt,
403 shared.raw_secret_bytes().as_slice(),
404 "SessionKeys".as_bytes(),
405 16 * 3,
406 )?;
407 let mut ses = session::Session::new();
408 ses.session_id = ctx.responder_session;
409 ses.set_decrypt_key(&keypack[16..32]);
410 ses.set_encrypt_key(&keypack[..16]);
411
412 let mut local_node = Vec::new();
413 local_node.write_u64::<LittleEndian>(controller_id)?;
414 ses.local_node = Some(local_node);
415
416 let mut remote_node = Vec::new();
417 remote_node.write_u64::<LittleEndian>(node_id)?;
418 ses.remote_node = Some(remote_node);
419
420 Ok(ses)
421}
422
423async fn read_request(
424 connection: &transport::Connection,
425 session: &mut session::Session,
426 endpoint: u16,
427 cluster: u32,
428 attr: u32,
429) -> Result<Message> {
430 let exchange = rand::random();
431 let mut retrctx = retransmit::RetrContext::new(connection, session);
432 let testm = messages::im_read_request(endpoint, cluster, attr, exchange)?;
433 retrctx.send(&testm).await?;
434 let result = retrctx.get_next_message().await?;
435 Ok(result)
436}
437
438async fn invoke_request(
439 connection: &transport::Connection,
440 session: &mut session::Session,
441 endpoint: u16,
442 cluster: u32,
443 command: u32,
444 payload: &[u8],
445) -> Result<Message> {
446 let exchange = rand::random();
447 let mut retrctx = retransmit::RetrContext::new(connection, session);
448 retrctx.subscribe_exchange(exchange);
449 log::debug!(
450 "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
451 exchange,
452 endpoint,
453 cluster,
454 command
455 );
456 let testm = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, false)?;
457 retrctx.send(&testm).await?;
458 let result = retrctx.get_next_message().await?;
459 Ok(result)
460}
461
462async fn invoke_request_timed(
463 connection: &transport::Connection,
464 session: &mut session::Session,
465 endpoint: u16,
466 cluster: u32,
467 command: u32,
468 payload: &[u8],
469 timeout: u16,
470) -> Result<Message> {
471 let exchange = rand::random();
472 let mut retrctx = retransmit::RetrContext::new(connection, session);
473 retrctx.subscribe_exchange(exchange);
474 let tr = messages::im_timed_request(exchange, timeout)?;
475 retrctx.send(&tr).await?;
476 let result = retrctx.get_next_message().await?;
477 if result.protocol_header.protocol_id
478 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
479 || result.protocol_header.opcode
480 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
481 {
482 return Err(anyhow::anyhow!(
483 "invoke_request_timed: unexpected response {:?}",
484 result
485 ));
486 }
487 let status = result
488 .tlv
489 .get_int(&[0])
490 .context("invoke_request_timed: status not found")?;
491 if status != 0 {
492 return Err(anyhow::anyhow!(
493 "invoke_request_timed: unexpected status {}",
494 status
495 ));
496 }
497 log::debug!(
498 "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
499 exchange,
500 endpoint,
501 cluster,
502 command
503 );
504 let testm = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, true)?;
505 retrctx.send(&testm).await?;
506 let result = retrctx.get_next_message().await?;
507 Ok(result)
508}