use std::convert::TryFrom;
use std::sync::Arc;

use crate::{
    active_connection::ActiveConnection,
    cert_matter, certmanager, commission, fabric,
    messages::{self, Message},
    retransmit, session, sigma, spake2p,
    tlv::TlvItemValue,
    transport,
    util::cryptoutil,
};
use anyhow::{Context, Result};
use byteorder::{LittleEndian, WriteBytesExt};

15pub struct Controller {
16 certmanager: Arc<dyn certmanager::CertManager>,
17 #[allow(dead_code)]
18 transport: Arc<transport::Transport>,
19 fabric: fabric::Fabric,
20}
21
22pub struct Connection {
23 active: ActiveConnection,
24}
25const CA_ID: u64 = 1;
29
30impl Controller {
31 pub fn new(
32 certmanager: &Arc<dyn certmanager::CertManager>,
33 transport: &Arc<transport::Transport>,
34 fabric_id: u64,
35 ) -> Result<Arc<Self>> {
36 let fabric = fabric::Fabric::new(fabric_id, CA_ID, &certmanager.get_ca_public_key()?);
37 Ok(Arc::new(Self {
38 certmanager: certmanager.clone(),
39 transport: transport.clone(),
40 fabric,
41 }))
42 }
43
44 pub async fn commission(
51 &self,
52 connection: &Arc<transport::Connection>,
53 pin: u32,
54 node_id: u64,
55 controller_id: u64,
56 ) -> Result<Connection> {
57 let mut session = auth_spake(connection, pin).await?;
58 let session = commission::commission(
59 connection,
60 &mut session,
61 &self.fabric,
62 self.certmanager.as_ref(),
63 node_id,
64 controller_id,
65 )
66 .await?;
67 Ok(Connection {
68 active: ActiveConnection::new(connection.clone(), session),
69 })
70 }
71
72 pub async fn auth_sigma(
74 &self,
75 connection: &Arc<transport::Connection>,
76 node_id: u64,
77 controller_id: u64,
78 ) -> Result<Connection> {
79 let session = auth_sigma(
80 connection,
81 &self.fabric,
82 self.certmanager.as_ref(),
83 node_id,
84 controller_id,
85 )
86 .await?;
87 Ok(Connection {
88 active: ActiveConnection::new(connection.clone(), session),
89 })
90 }
91}
92
93impl Connection {
95 pub async fn read_request(
97 &self,
98 endpoint: u16,
99 cluster: u32,
100 attr: u32,
101 ) -> Result<Message> {
102 let exchange: u16 = rand::random();
103 let msg = messages::im_read_request(endpoint, cluster, attr, exchange)?;
104 self.active.request(exchange, &msg).await
105 }
106
107 pub async fn read_request2(
109 &self,
110 endpoint: u16,
111 cluster: u32,
112 attr: u32,
113 ) -> Result<TlvItemValue> {
114 let res = self.read_request(endpoint, cluster, attr).await?;
115 if (res.protocol_header.protocol_id
116 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION)
117 || (res.protocol_header.opcode
118 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_REPORT_DATA)
119 {
120 Err(anyhow::anyhow!(
121 "response is not expected report_data {:?}",
122 res.protocol_header
123 ))
124 } else {
125 match res.tlv.get(&[1, 0, 1, 2]) {
126 Some(a) => Ok(a.clone()),
127 None => {
128 let s = res
129 .tlv
130 .get(&[1, 0, 0, 1, 0])
131 .context("report data format not recognized1")?;
132 if let TlvItemValue::Int(status) = s {
133 Err(anyhow::anyhow!("report data with status {}", status))
134 } else {
135 Err(anyhow::anyhow!("report data format not recognized2"))
136 }
137 }
138 }
139 }
140 }
141
142 pub async fn invoke_request(
144 &self,
145 endpoint: u16,
146 cluster: u32,
147 command: u32,
148 payload: &[u8],
149 ) -> Result<Message> {
150 let exchange: u16 = rand::random();
151 log::debug!(
152 "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
153 exchange,
154 endpoint,
155 cluster,
156 command
157 );
158 let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, false)?;
159 self.active.request(exchange, &msg).await
160 }
161
162 pub async fn invoke_request2(
164 &self,
165 endpoint: u16,
166 cluster: u32,
167 command: u32,
168 payload: &[u8],
169 ) -> Result<TlvItemValue> {
170 let res = self.invoke_request(endpoint, cluster, command, payload).await?;
171 let o = res.tlv.get(&[1, 0, 1, 1]).context("result not found")?;
172 Ok(o.clone())
173 }
174
175 pub async fn im_subscribe_request(
176 &self,
177 endpoint: u16,
178 cluster: u32,
179 event: u32,
180 ) -> Result<Message> {
181 let exchange: u16 = rand::random();
182 log::debug!(
183 "im_subscribe_request exch:{} endpoint:{} cluster:{} event:{}",
184 exchange,
185 endpoint,
186 cluster,
187 event
188 );
189 let msg = messages::im_subscribe_request(endpoint, cluster, exchange, event)?;
190 self.active.request(exchange, &msg).await
191 }
192
193 pub async fn im_status_response(
194 &self,
195 exchange: u16,
196 flags: u8,
197 ack: u32
198 ) -> Result<()> {
199 let msg = messages::im_status_response(exchange, flags, ack)?;
200 self.active.send(&msg).await
201 }
202
203 pub async fn invoke_request_timed(
205 &self,
206 endpoint: u16,
207 cluster: u32,
208 command: u32,
209 payload: &[u8],
210 timeout: u16,
211 ) -> Result<Message> {
212 let exchange: u16 = rand::random();
213
214 let tr = messages::im_timed_request(exchange, timeout)?;
216 let result = self.active.request(exchange, &tr).await?;
217
218 if result.protocol_header.protocol_id
219 != messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION
220 || result.protocol_header.opcode
221 != messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP
222 {
223 return Err(anyhow::anyhow!(
224 "invoke_request_timed: unexpected response {:?}",
225 result
226 ));
227 }
228 let status = result
229 .tlv
230 .get_int(&[0])
231 .context("invoke_request_timed: status not found")?;
232 if status != 0 {
233 return Err(anyhow::anyhow!(
234 "invoke_request_timed: unexpected status {}",
235 status
236 ));
237 }
238
239 log::debug!(
240 "invoke_request exch:{} endpoint:{} cluster:{} command:{}",
241 exchange,
242 endpoint,
243 cluster,
244 command
245 );
246 let msg = messages::im_invoke_request(endpoint, cluster, command, exchange, payload, true)?;
247 self.active.request(exchange, &msg).await
248 }
249
250 pub async fn recv_event(&self) -> Option<Message> {
252 self.active.recv_event().await
253 }
254
255 pub fn try_recv_event(&self) -> Option<Message> {
257 self.active.try_recv_event()
258 }
259}
260
/// Encodes a setup PIN as the SPAKE2+ passcode bytes.
///
/// The output is the PIN's four little-endian bytes. This never fails; the
/// `Result` return type is kept for API compatibility with existing callers.
pub fn pin_to_passcode(pin: u32) -> Result<Vec<u8>> {
    let le_bytes: [u8; 4] = pin.to_le_bytes();
    Ok(le_bytes.to_vec())
}

267async fn auth_spake(connection: &transport::Connection, pin: u32) -> Result<session::Session> {
268 let exchange = rand::random();
269 log::debug!("start auth_spake");
270 let mut session = session::Session::new();
271 session.my_session_id = 1;
272 let mut retrctx = retransmit::RetrContext::new(connection, &session);
273 log::debug!("send pbkdf request");
275 let pbkdf_req_protocol_message = messages::pbkdf_req(exchange)?;
276 retrctx.send(&pbkdf_req_protocol_message).await?;
277
278 let pbkdf_response = retrctx.get_next_message().await?;
280 if pbkdf_response.protocol_header.protocol_id
281 != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
282 || pbkdf_response.protocol_header.opcode
283 != messages::ProtocolMessageHeader::OPCODE_PBKDF_RESP
284 {
285 return Err(anyhow::anyhow!("pbkdf response not received"));
286 }
287
288 let iterations = pbkdf_response
289 .tlv
290 .get_int(&[4, 1])
291 .context("pbkdf_response - iterations missing")?;
292 let salt = pbkdf_response
293 .tlv
294 .get_octet_string(&[4, 2])
295 .context("pbkdf_response - salt missing")?;
296 let p_session = pbkdf_response
297 .tlv
298 .get_int(&[3])
299 .context("pbkdf_response - session missing")?;
300
301 let engine = spake2p::Engine::new()?;
303 let mut ctx = engine.start(&pin_to_passcode(pin)?, salt, iterations as u32)?;
304 log::debug!("send pake1 request");
305 let pake1_protocol_message = messages::pake1(exchange, ctx.x.as_bytes(), -1)?;
306 retrctx.send(&pake1_protocol_message).await?;
307
308 let pake2 = retrctx.get_next_message().await?;
310 if pake2.protocol_header.protocol_id
311 != messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
312 || pake2.protocol_header.opcode != messages::ProtocolMessageHeader::OPCODE_PASE_PAKE2
313 {
314 return Err(anyhow::anyhow!("pake2 not received"));
315 }
316 let pake2_pb = pake2
317 .tlv
318 .get_octet_string(&[1])
319 .context("pake2 pb tlv missing")?;
320 ctx.y = p256::EncodedPoint::from_bytes(pake2_pb)?;
321
322 let mut hash_seed = "CHIP PAKE V1 Commissioning".as_bytes().to_vec();
324 hash_seed.extend_from_slice(&pbkdf_req_protocol_message[6..]);
325 hash_seed.extend_from_slice(&pbkdf_response.payload);
326 engine.finish(&mut ctx, &hash_seed)?;
327 let pake3_protocol_message = messages::pake3(
328 exchange,
329 &ctx.ca.context("ca value not present in context")?,
330 -1,
331 )?;
332 log::debug!("send pake3 request");
333 retrctx.send(&pake3_protocol_message).await?;
334
335 let pake3_resp = retrctx.get_next_message().await?;
336 match &pake3_resp.status_report_info {
337 Some(s) => {
338 if !s.is_ok() {
339 return Err(anyhow::anyhow!("pake3 resp not ok), got {:?}", pake3_resp));
340 }
341 }
342 None => {
343 return Err(anyhow::anyhow!(
344 "expecting status report (pake3 resp), got {:?}",
345 pake3_resp
346 ))
347 }
348 }
349
350 session.set_encrypt_key(&ctx.encrypt_key.context("encrypt key missing")?);
351 session.set_decrypt_key(&ctx.decrypt_key.context("decrypt key missing")?);
352 session.session_id = p_session as u16;
353 log::debug!("auth_spake ok; session: {}", session.session_id);
354 Ok(session)
355}
356
357pub(crate) async fn auth_sigma(
358 connection: &transport::Connection,
359 fabric: &fabric::Fabric,
360 cm: &dyn certmanager::CertManager,
361 node_id: u64,
362 controller_id: u64,
363) -> Result<session::Session> {
364 log::debug!("auth_sigma");
365 let exchange = rand::random();
366 let mut session = session::Session::new();
367 let mut retrctx = retransmit::RetrContext::new(connection, &mut session);
368 retrctx.subscribe_exchange(exchange);
369 let mut ctx = sigma::SigmaContext::new(node_id);
370 let ca_pubkey = cm.get_ca_key()?.public_key().to_sec1_bytes();
371 sigma::sigma1(fabric, &mut ctx, &ca_pubkey)?;
372 let s1 = messages::sigma1(exchange, &ctx.sigma1_payload)?;
373
374 log::debug!("send sigma1 {}", exchange);
375 retrctx.send(&s1).await?;
376
377 log::debug!("receive sigma2 {}", exchange);
379 let sigma2 = retrctx.get_next_message().await?;
380 log::debug!("sigma2 received {:?}", sigma2);
381 if sigma2.protocol_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
382 && sigma2.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_STATUS
383 {
384 return Err(anyhow::anyhow!("sigma2 not received, status: {}", sigma2.status_report_info.context("status report info missing")?.to_string()));
385 }
386 ctx.sigma2_payload = sigma2.payload;
387 ctx.responder_session = sigma2
388 .tlv
389 .get_int(&[2])
390 .context("responder session tlv missing in sigma2")? as u16;
391 ctx.responder_public = sigma2
392 .tlv
393 .get_octet_string(&[3])
394 .context("responder public tlv missing in sigma2")?
395 .to_vec();
396
397 let controller_private = cm.get_user_key(controller_id)?;
398 let controller_x509 = cm.get_user_cert(controller_id)?;
399 let controller_matter_cert =
400 cert_matter::convert_x509_bytes_to_matter(&controller_x509, &ca_pubkey)?;
401
402 log::debug!("send sigma3 {} with piggyback ack for {}", exchange, sigma2.message_header.message_counter);
404 sigma::sigma3(
405 fabric,
406 &mut ctx,
407 &controller_private.to_sec1_der()?,
408 &controller_matter_cert,
409 )?;
410 let sigma3 = messages::sigma3(exchange, &ctx.sigma3_payload, sigma2.message_header.message_counter)?;
411 retrctx.send(&sigma3).await?;
412
413 log::debug!("receive result {}", exchange);
414 let status = retrctx.get_next_message().await?;
415 if !status
416 .status_report_info
417 .context("sigma3 status resp not received")?
418 .is_ok()
419 {
420 return Err(anyhow::anyhow!(format!(
421 "response to sigma3 does not contain status ok {:?}",
422 status
423 )));
424 }
425
426 let mut th = ctx.sigma1_payload.clone();
428 th.extend_from_slice(&ctx.sigma2_payload);
429
430 let mut transcript = th;
431 transcript.extend_from_slice(&ctx.sigma3_payload);
432 let transcript_hash = cryptoutil::sha256(&transcript);
433 let mut salt = fabric.signed_ipk()?;
434 salt.extend_from_slice(&transcript_hash);
435 let shared = ctx.shared.context("shared secret not in context")?;
436 let keypack = cryptoutil::hkdf_sha256(
437 &salt,
438 shared.raw_secret_bytes().as_slice(),
439 "SessionKeys".as_bytes(),
440 16 * 3,
441 )?;
442 let mut ses = session::Session::new();
443 ses.session_id = ctx.responder_session;
444 ses.my_session_id = ctx.session_id;
445 ses.set_decrypt_key(&keypack[16..32]);
446 ses.set_encrypt_key(&keypack[..16]);
447
448 let mut local_node = Vec::new();
449 local_node.write_u64::<LittleEndian>(controller_id)?;
450 ses.local_node = Some(local_node);
451
452 let mut remote_node = Vec::new();
453 remote_node.write_u64::<LittleEndian>(node_id)?;
454 ses.remote_node = Some(remote_node);
455
456 Ok(ses)
457}
458