1use std::sync::Arc;
2
3use crate::{
4 active_connection::ActiveConnection,
5 cert_matter, certmanager, commission, fabric,
6 messages::{self, Message},
7 retransmit, session, sigma, spake2p,
8 tlv::TlvItemValue,
9 transport::{self, ConnectionTrait},
10 util::cryptoutil,
11};
12use anyhow::{Context, Result};
13use byteorder::{LittleEndian, WriteBytesExt};
14
15pub struct Controller {
16 certmanager: Arc<dyn certmanager::CertManager>,
17 #[allow(dead_code)]
18 transport: Arc<transport::Transport>,
19 fabric: fabric::Fabric,
20}
21
22pub struct Connection {
23 active: ActiveConnection,
24}
25const CA_ID: u64 = 1;
29
30impl Controller {
31 pub fn new(
32 certmanager: &Arc<dyn certmanager::CertManager>,
33 transport: &Arc<transport::Transport>,
34 fabric_id: u64,
35 ) -> Result<Arc<Self>> {
36 let fabric = fabric::Fabric::new(fabric_id, CA_ID, &certmanager.get_ca_public_key()?);
37 Ok(Arc::new(Self {
38 certmanager: certmanager.clone(),
39 transport: transport.clone(),
40 fabric,
41 }))
42 }
43
44 pub async fn commission(
51 &self,
52 connection: &Arc<dyn ConnectionTrait>,
53 pin: u32,
54 node_id: u64,
55 controller_id: u64,
56 ) -> Result<Connection> {
57 let mut session = auth_spake(connection.as_ref(), pin).await?;
58 let session = commission::commission(
59 connection.as_ref(),
60 &mut session,
61 &self.fabric,
62 self.certmanager.as_ref(),
63 node_id,
64 controller_id,
65 )
66 .await?;
67 Ok(Connection {
68 active: ActiveConnection::new(connection.clone(), session),
69 })
70 }
71
72 pub async fn auth_sigma(
74 &self,
75 connection: &Arc<dyn ConnectionTrait>,
76 node_id: u64,
77 controller_id: u64,
78 ) -> Result<Connection> {
79 let session = auth_sigma(
80 connection.as_ref(),
81 &self.fabric,
82 self.certmanager.as_ref(),
83 node_id,
84 controller_id,
85 )
86 .await?;
87 Ok(Connection {
88 active: ActiveConnection::new(connection.clone(), session),
89 })
90 }
91
92 #[cfg(feature = "ble")]
106 pub async fn commission_ble(
107 &self,
108 discriminator: u16,
109 short_discriminator: bool,
110 pin: u32,
111 node_id: u64,
112 controller_id: u64,
113 network_creds: commission::NetworkCreds,
114 mdns: &std::sync::Arc<crate::mdns2::MdnsService>,
115 mdns_receiver: &tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<crate::mdns2::MdnsEvent>>,
116 ) -> Result<Connection> {
117 use crate::{btp::BtpConnection, discover};
118
119 let peripheral = crate::ble::find_by_discriminator(discriminator, short_discriminator, std::time::Duration::from_secs(30))
121 .await
122 .context("BLE scan")?;
123 log::debug!("BLE device found: z2");
124 let btp_conn = BtpConnection::connect(peripheral).await.context("BTP connect")?;
125
126 let mut pase_session = auth_spake(btp_conn.as_ref(), pin).await.context("PASE over BLE")?;
128
129 commission::commission_ble_phase(
131 btp_conn.as_ref(),
132 &mut pase_session,
133 &self.fabric,
134 self.certmanager.as_ref(),
135 node_id,
136 controller_id,
137 &network_creds,
138 )
139 .await
140 .context("BLE commissioning phase")?;
141
142 drop(btp_conn);
144
145 let ca_pubkey = self.certmanager.get_ca_public_key()?;
147 let fabric_tmp = fabric::Fabric::new(self.fabric.id, 0, &ca_pubkey);
148 let compressed = fabric_tmp.compressed().context("compressed fabric ID")?;
149 let instance = format!("{}-{:016X}", hex::encode_upper(&compressed), node_id);
150 let expected_target = format!("{}._matter._tcp.local.", instance);
151
152 let mut addresses = Vec::new();
153 {
154 let mut rx = mdns_receiver.lock().await;
155 mdns.active_lookup("_matter._tcp.local", 0xff).await;
156 loop {
157 match tokio::time::timeout(std::time::Duration::from_secs(30), rx.recv()).await {
158 Ok(Some(crate::mdns2::MdnsEvent::ServiceDiscovered { name, records: _, target })) => {
159 if name != "_matter._tcp.local." || target != expected_target {
160 continue;
161 }
162 let info = discover::extract_matter_info(&target, mdns).await?;
163 log::debug!("Operational mDNS discovered device: {:?}", info);
164
165 let port = info.port.unwrap_or(5540);
166 for ip in &info.ips {
167 if ip.is_ipv6() {
168 addresses.push(format!("[{}]:{}", ip, port));
169 } else {
170 addresses.push(format!("{}:{}", ip, port));
171 }
172 }
173 break;
174 }
175 Ok(_) => continue,
176 Err(_) => anyhow::bail!("operational mDNS timeout for {}", instance),
177 }
178 }
179 };
180
181 log::info!("Device discovered at {}", addresses.join(", "));
182
183 for address in addresses {
185 log::debug!("Trying to commission over UDP at {}...", address);
186 let udp_conn = self.transport.create_connection(&address).await;
187 let ses = commission::commissioning_complete_udp(
188 udp_conn.as_ref(),
189 self.certmanager.as_ref(),
190 node_id,
191 controller_id,
192 &self.fabric,
193 )
194 .await;
195 if let Ok(ses) = ses {
196 return Ok(Connection {
197 active: ActiveConnection::new(udp_conn, ses),
198 });
199 } else {
200 log::debug!("Failed to commission over UDP at {}: {:?}", address, ses.err());
201 }
202 }
203 Err(anyhow::anyhow!("failed to commission device over UDP at any discovered address"))
204 }
205}
206
207impl Connection {
209 pub async fn read_request(
211 &self,
212 endpoint: u16,
213 cluster: u32,
214 attr: u32,
215 ) -> Result<Message> {
216 let exchange: u16 = rand::random();
217 let msg = messages::im_read_request(endpoint, cluster, attr, exchange)?;
218 self.active.request(exchange, &msg).await
219 }
220
221 pub async fn read_request2(
223 &self,
224 endpoint: u16,
225 cluster: u32,
226 attr: u32,
227 ) -> Result<TlvItemValue> {
228 let res = self.read_request(endpoint, cluster, attr).await?;
229 if (res.protocol_header.protocol_id
230 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION)
231 || (res.protocol_header.opcode
232 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA)
233 {
234 Err(anyhow::anyhow!(
235 "response is not expected report_data {:?}",
236 res.protocol_header
237 ))
238 } else {
239 match res.tlv.get(&[1, 0, 1, 2]) {
240 Some(a) => Ok(a.clone()),
241 None => {
242 let s = res
243 .tlv
244 .get(&[1, 0, 0, 1, 0])
245 .context("report data format not recognized1")?;
246 if let TlvItemValue::Int(status) = s {
247 Err(anyhow::anyhow!("report data with status {}", status))
248 } else {
249 Err(anyhow::anyhow!("report data format not recognized2"))
250 }
251 }
252 }
253 }
254 }
255
256 pub async fn invoke_request(
258 &self,
259 endpoint: u16,
260 cluster: u32,
261 command: u32,
262 payload: &[u8],
263 ) -> Result<Message> {
264 let exchange: u16 = rand::random();
265 log::debug!(
266 "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
267 exchange,
268 endpoint,
269 cluster,
270 command
271 );
272 let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, false)?;
273 self.active.request(exchange, &msg).await
274 }
275
276 pub async fn invoke_request2(
278 &self,
279 endpoint: u16,
280 cluster: u32,
281 command: u32,
282 payload: &[u8],
283 ) -> Result<TlvItemValue> {
284 let res = self.invoke_request(endpoint, cluster, command, payload).await?;
285 let o = res.tlv.get(&[1, 0, 1, 1]).context("result not found")?;
286 Ok(o.clone())
287 }
288
289 pub async fn im_subscribe_request(
290 &self,
291 endpoint: u16,
292 cluster: u32,
293 event: u32,
294 ) -> Result<Message> {
295 let exchange: u16 = rand::random();
296 log::debug!(
297 "im_subscribe_request exch:{} endpoint:{} cluster:{} event:{}",
298 exchange,
299 endpoint,
300 cluster,
301 event
302 );
303 let msg = messages::im_subscribe_request(endpoint, cluster, exchange, event)?;
304 self.active.request(exchange, &msg).await
305 }
306
307 pub async fn im_subscribe_request_attr(
311 &self,
312 endpoint: u16,
313 cluster: u32,
314 attr: u32,
315 keep_subscriptions: bool,
316 ) -> Result<Message> {
317 let exchange: u16 = rand::random();
318 log::debug!(
319 "im_subscribe_request_attr exch:{} endpoint:{} cluster:{} attr:{} keep:{}",
320 exchange, endpoint, cluster, attr, keep_subscriptions
321 );
322 let msg = messages::im_subscribe_request_attr(endpoint, cluster, attr, exchange, keep_subscriptions)?;
323 self.active.request(exchange, &msg).await
324 }
325
326 pub async fn im_unsubscribe_all(&self) -> Result<Message> {
329 let exchange: u16 = rand::random();
330 log::debug!("im_unsubscribe_all exch:{}", exchange);
331 let msg = messages::im_unsubscribe_all(exchange)?;
332 self.active.request(exchange, &msg).await
333 }
334
335 pub async fn im_status_response(
336 &self,
337 exchange: u16,
338 flags: u8,
339 ack: u32
340 ) -> Result<()> {
341 let msg = messages::im_status_response(exchange, flags, ack)?;
342 self.active.send(&msg).await
343 }
344
345 pub async fn invoke_request_timed(
347 &self,
348 endpoint: u16,
349 cluster: u32,
350 command: u32,
351 payload: &[u8],
352 timeout: u16,
353 ) -> Result<Message> {
354 let exchange: u16 = rand::random();
355
356 let tr = messages::im_timed_request(exchange, timeout)?;
358 let result = self.active.request(exchange, &tr).await?;
359
360 if result.protocol_header.protocol_id
361 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
362 || result.protocol_header.opcode
363 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
364 {
365 return Err(anyhow::anyhow!(
366 "invoke_request_timed: unexpected response {:?}",
367 result
368 ));
369 }
370 let status = result
371 .tlv
372 .get_int(&[0])
373 .context("invoke_request_timed: status not found")?;
374 if status != 0 {
375 return Err(anyhow::anyhow!(
376 "invoke_request_timed: unexpected status {}",
377 status
378 ));
379 }
380
381 log::debug!(
382 "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
383 exchange,
384 endpoint,
385 cluster,
386 command
387 );
388 let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, true)?;
389 self.active.request(exchange, &msg).await
390 }
391
392 pub async fn recv_event(&self) -> Option<Message> {
394 self.active.recv_event().await
395 }
396
397 pub fn try_recv_event(&self) -> Option<Message> {
399 self.active.try_recv_event()
400 }
401}
402
/// Encodes a setup `pin` as the 4-byte little-endian passcode fed to SPAKE2+.
///
/// Uses `u32::to_le_bytes` from the standard library. The encoding cannot fail; the `Result`
/// return type is kept so existing callers are unaffected.
pub fn pin_to_passcode(pin: u32) -> Result<Vec<u8>> {
    Ok(pin.to_le_bytes().to_vec())
}
408
409pub(crate) async fn auth_spake(connection: &dyn ConnectionTrait, pin: u32) -> Result<session::Session> {
410 let exchange = rand::random();
411 log::debug!("start auth_spake");
412 let mut session = session::Session::new();
413 session.my_session_id = 1;
414 let mut retrctx = retransmit::RetrContext::new(connection, &session);
415 log::debug!("send pbkdf request");
417 let pbkdf_req_protocol_message = messages::pbkdf_req(exchange)?;
418 retrctx.send(&pbkdf_req_protocol_message).await?;
419
420 let pbkdf_response = retrctx.get_next_message().await?;
422 if pbkdf_response.protocol_header.protocol_id
423 != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
424 || pbkdf_response.protocol_header.opcode
425 != messages::ProtocolMessageHeader::OPCODE_PBKDF_RESP
426 {
427 return Err(anyhow::anyhow!("pbkdf response not received"));
428 }
429
430 let iterations = pbkdf_response
431 .tlv
432 .get_int(&[4, 1])
433 .context("pbkdf_response - iterations missing")?;
434 let salt = pbkdf_response
435 .tlv
436 .get_octet_string(&[4, 2])
437 .context("pbkdf_response - salt missing")?;
438 let p_session = pbkdf_response
439 .tlv
440 .get_int(&[3])
441 .context("pbkdf_response - session missing")?;
442
443 let engine = spake2p::Engine::new()?;
445 let mut ctx = engine.start(&pin_to_passcode(pin)?, salt, iterations as u32)?;
446 log::debug!("send pake1 request");
447 let pake1_protocol_message = messages::pake1(exchange, ctx.x.as_bytes(), -1)?;
448 retrctx.send(&pake1_protocol_message).await?;
449
450 let pake2 = retrctx.get_next_message().await?;
452 if pake2.protocol_header.protocol_id
453 != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
454 || pake2.protocol_header.opcode != messages::ProtocolMessageHeader::OPCODE_PASE_PAKE2
455 {
456 return Err(anyhow::anyhow!("pake2 not received"));
457 }
458 let pake2_pb = pake2
459 .tlv
460 .get_octet_string(&[1])
461 .context("pake2 pb tlv missing")?;
462 ctx.y = p256::EncodedPoint::from_bytes(pake2_pb)?;
463
464 let pake2_cb = pake2
465 .tlv
466 .get_octet_string(&[2])
467 .context("pake2 cb tlv missing")?;
468
469 let mut hash_seed = "CHIP PAKE V1 Commissioning".as_bytes().to_vec();
471 hash_seed.extend_from_slice(&pbkdf_req_protocol_message[6..]);
472 hash_seed.extend_from_slice(&pbkdf_response.payload);
473 engine.finish(&mut ctx, &hash_seed, pake2_cb)?;
474 let pake3_protocol_message = messages::pake3(
475 exchange,
476 &ctx.ca.context("ca value not present in context")?,
477 -1,
478 )?;
479 log::debug!("send pake3 request");
480 retrctx.send(&pake3_protocol_message).await?;
481
482 let pake3_resp = retrctx.get_next_message().await?;
483 match &pake3_resp.status_report_info {
484 Some(s) => {
485 if !s.is_ok() {
486 return Err(anyhow::anyhow!("pake3 resp not ok), got {:?}", pake3_resp));
487 }
488 }
489 None => {
490 return Err(anyhow::anyhow!(
491 "expecting status report (pake3 resp), got {:?}",
492 pake3_resp
493 ))
494 }
495 }
496
497 session.set_encrypt_key(&ctx.encrypt_key.context("encrypt key missing")?);
498 session.set_decrypt_key(&ctx.decrypt_key.context("decrypt key missing")?);
499 session.session_id = p_session as u16;
500 log::debug!("auth_spake ok; session: {}", session.session_id);
501 Ok(session)
502}
503
504pub(crate) async fn auth_sigma(
505 connection: &dyn ConnectionTrait,
506 fabric: &fabric::Fabric,
507 cm: &dyn certmanager::CertManager,
508 node_id: u64,
509 controller_id: u64,
510) -> Result<session::Session> {
511 log::debug!("auth_sigma");
512 let exchange = rand::random();
513 let session = session::Session::new();
514 let mut retrctx = retransmit::RetrContext::new(connection, &session);
515 retrctx.subscribe_exchange(exchange);
516 let mut ctx = sigma::SigmaContext::new(node_id);
517 let ca_pubkey = cm.get_ca_key()?.public_key().to_sec1_bytes();
518 sigma::sigma1(fabric, &mut ctx, &ca_pubkey)?;
519 let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
520
521 log::debug!("send sigma1 {}", exchange);
522 retrctx.send(&s1).await?;
523
524 log::debug!("receive sigma2 {}", exchange);
526 let sigma2 = retrctx.get_next_message().await?;
527 log::debug!("sigma2 received {:?}", sigma2);
528 if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
529 && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
530 {
531 return Err(anyhow::anyhow!("sigma2 not received, status: {}", sigma2.status_report_info.context("status report info missing")?.to_string()));
532 }
533 ctx.sigma2_payload = sigma2.payload;
534 ctx.responder_session = sigma2
535 .tlv
536 .get_int(&[2])
537 .context("responder session tlv missing in sigma2")? as u16;
538 ctx.responder_public = sigma2
539 .tlv
540 .get_octet_string(&[3])
541 .context("responder public tlv missing in sigma2")?
542 .to_vec();
543
544 let controller_private = cm.get_user_key(controller_id)?;
545 let controller_x509 = cm.get_user_cert(controller_id)?;
546 let controller_matter_cert =
547 cert_matter::convert_x509_bytes_to_matter(&controller_x509, &ca_pubkey)?;
548
549 log::debug!("send sigma3 {} with piggyback ack for {}", exchange, sigma2.message_header.message_counter);
551 sigma::sigma3(
552 fabric,
553 &mut ctx,
554 &controller_private.to_sec1_der()?,
555 &controller_matter_cert,
556 )?;
557 let sigma3 = messages::sigma3(exchange, &ctx.sigma3_payload, sigma2.message_header.message_counter)?;
558 retrctx.send(&sigma3).await?;
559
560 log::debug!("receive result {}", exchange);
561 let status = retrctx.get_next_message().await?;
562 if !status
563 .status_report_info
564 .context("sigma3 status resp not received")?
565 .is_ok()
566 {
567 return Err(anyhow::anyhow!(format!(
568 "response to sigma3 does not contain status ok {:?}",
569 status
570 )));
571 }
572
573 let mut th = ctx.sigma1_payload.clone();
575 th.extend_from_slice(&ctx.sigma2_payload);
576
577 let mut transcript = th;
578 transcript.extend_from_slice(&ctx.sigma3_payload);
579 let transcript_hash = cryptoutil::sha256(&transcript);
580 let mut salt = fabric.signed_ipk()?;
581 salt.extend_from_slice(&transcript_hash);
582 let shared = ctx.shared.context("shared secret not in context")?;
583 let keypack = cryptoutil::hkdf_sha256(
584 &salt,
585 shared.raw_secret_bytes().as_slice(),
586 "SessionKeys".as_bytes(),
587 16 * 3,
588 )?;
589 let mut ses = session::Session::new();
590 ses.session_id = ctx.responder_session;
591 ses.my_session_id = ctx.session_id;
592 ses.set_decrypt_key(&keypack[16..32]);
593 ses.set_encrypt_key(&keypack[..16]);
594
595 let mut local_node = Vec::new();
596 local_node.write_u64::<LittleEndian>(controller_id)?;
597 ses.local_node = Some(local_node);
598
599 let mut remote_node = Vec::new();
600 remote_node.write_u64::<LittleEndian>(node_id)?;
601 ses.remote_node = Some(remote_node);
602
603 Ok(ses)
604}
605