1mod attributes;
4mod case_handler;
5mod commissioning;
6mod crypto;
7mod interaction;
8mod pase;
9mod persist;
10mod send;
11mod types;
12
13pub use types::DeviceConfig;
14pub use attributes::AttrContext;
15use types::{ActiveSubscription, CaseState, FabricInfo, PaseState, PendingChunkState, SubscribeState};
16
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU32, Ordering};
21
22use anyhow::{Ok, Result};
23use tokio::net::UdpSocket;
24
25use crate::{messages, session, tlv};
26
/// Outcome an [`AppHandler`] reports for a single endpoint/cluster/command invocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandResult {
    /// The command was handled successfully.
    Success,
    /// The command was handled but failed with the given status code.
    ///
    /// NOTE(review): presumably an Interaction Model status code reported back to the
    /// invoker — confirm in `handle_invoke_request`.
    Error(u16),
    /// The application does not handle this endpoint/cluster/command combination.
    Unhandled,
}
36
/// Application hook for cluster commands.
///
/// An implementation is passed to `Device::run`, which forwards it to the invoke-request
/// handler. Implementors must be `Send`.
pub trait AppHandler: Send {
    /// Handles one command addressed to `endpoint` / `cluster` / `command`.
    ///
    /// `payload` is the command's TLV-encoded fields. `attrs` gives the handler access to
    /// attribute state (NOTE(review): inferred from the type name — confirm the exact
    /// capabilities of `AttrContext`). Return [`CommandResult::Unhandled`] for commands the
    /// application does not implement.
    fn handle_command(
        &mut self,
        endpoint: u16,
        cluster: u32,
        command: u32,
        payload: &tlv::TlvItem,
        attrs: &mut AttrContext,
    ) -> CommandResult;
}
47
/// Runtime state of a single device node: its UDP endpoint, secure-session state,
/// subscriptions, fabrics, and attribute store.
///
/// NOTE(review): field declaration order is also drop order; keep it stable.
pub struct Device {
    /// Configuration supplied to `Device::new`.
    pub(crate) config: DeviceConfig,
    /// UDP socket bound to `config.listen_address`; `run` receives all traffic here.
    pub(crate) socket: UdpSocket,
    /// Random 32-byte salt generated in `new` (presumably the PASE PBKDF salt — confirm).
    pub(crate) salt: Vec<u8>,
    /// PBKDF iteration count; `new` sets it to 1000.
    pub(crate) pbkdf_iterations: u32,
    /// Random P-256 operational private key generated in `new`.
    pub(crate) operational_key: p256::SecretKey,
    /// Outgoing message counter; starts at a random value and is advanced by `next_counter`.
    pub(crate) message_counter: AtomicU32,
    /// In-progress PASE handshake state, if any.
    pub(crate) pase_state: Option<PaseState>,
    /// Established PASE session, used for incoming messages whose session id matches no
    /// CASE session.
    pub(crate) pase_session: Option<session::Session>,
    /// In-progress CASE handshake state keyed by a `u16` (presumably a session id — confirm).
    pub(crate) case_states: HashMap<u16, CaseState>,
    /// Established CASE sessions, matched against incoming messages by `my_session_id`.
    pub(crate) case_sessions: Vec<session::Session>,
    /// Subscription negotiation state (NOTE(review): inferred from the name — confirm).
    pub(crate) subscribe_states: Vec<SubscribeState>,
    /// Active subscriptions; the smallest `max_interval_secs` sets the keepalive period in
    /// `run`.
    pub(crate) active_subscriptions: Vec<ActiveSubscription>,
    /// Pending chunked-message state (presumably multi-message responses — confirm).
    pub(crate) pending_chunks: Vec<PendingChunkState>,
    /// Fabric records held by the device.
    pub(crate) fabrics: Vec<FabricInfo>,
    /// Index to assign to the next fabric; `new` starts it at 1.
    pub(crate) next_fabric_index: u8,
    /// Root certificate awaiting use, if any (NOTE(review): inferred from the name — confirm
    /// where it is set and consumed).
    pub(crate) pending_root_cert: Option<Vec<u8>>,
    /// Message counters already seen; `handle_packet` drops messages whose counter is in
    /// this set. Shared by all sessions; nothing in this file removes entries.
    pub(crate) received_counters: HashSet<u32>,
    /// Endpoint ids exposed by the device; `new` starts with endpoint 0 only.
    pub(crate) endpoints: Vec<u16>,
    /// Raw attribute values keyed by a `(u16, u32, u32)` triple (presumably endpoint,
    /// cluster id, attribute id — confirm).
    pub(crate) attributes: HashMap<(u16, u32, u32), Vec<u8>>,
    /// Attribute keys marked dirty; while non-empty, `run` schedules a subscription report.
    pub(crate) dirty_attributes: HashSet<(u16, u32, u32)>,
    /// Shared mDNS responder used to advertise the device.
    pub(crate) mdns: Arc<crate::mdns2::MdnsService>,
    /// Additional attribute keys to persist (NOTE(review): inferred from the name — confirm
    /// in the `persist` module).
    pub(crate) extra_persisted: Vec<(u16, u32, u32)>,
}
81
82impl Device {
83 pub async fn new(config: DeviceConfig, mdns: Arc<crate::mdns2::MdnsService>) -> Result<Self> {
84 let socket = UdpSocket::bind(&config.listen_address).await?;
85 let mut salt = vec![0u8; 32];
86 rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut salt);
87 let operational_key = p256::SecretKey::random(&mut rand::thread_rng());
88 let mut device = Self {
89 config,
90 socket,
91 salt,
92 pbkdf_iterations: 1000,
93 operational_key,
94 message_counter: AtomicU32::new(rand::random()),
95 pase_state: None,
96 pase_session: None,
97 case_states: HashMap::new(),
98 case_sessions: Vec::new(),
99 subscribe_states: Vec::new(),
100 active_subscriptions: Vec::new(),
101 pending_chunks: Vec::new(),
102 fabrics: Vec::new(),
103 next_fabric_index: 1,
104 pending_root_cert: None,
105 received_counters: HashSet::new(),
106 endpoints: vec![0],
107 attributes: HashMap::new(),
108 dirty_attributes: HashSet::new(),
109 mdns,
110 extra_persisted: Vec::new(),
111 };
112 device.setup_default_attributes()?;
113 device.dirty_attributes.clear();
114
115 let port: u16 = device
117 .config
118 .listen_address
119 .rsplit(':')
120 .next()
121 .and_then(|p| p.parse().ok())
122 .unwrap_or(5540);
123 let short_disc = device.config.discriminator >> 8;
124 let instance_name = format!("{:016X}", rand::random::<u64>());
125 let (adv_v4, adv_v6) = device.config.split_advertise_ips();
126 let svc = crate::mdns2::ServiceRegistration {
127 instance_name,
128 service_type: "_matterc._udp.local".to_string(),
129 port,
130 txt_records: vec![
131 ("DN".to_string(), device.config.product_name.clone()),
132 ("D".to_string(), device.config.discriminator.to_string()),
133 (
134 "VP".to_string(),
135 format!("{}+{}", device.config.vendor_id, device.config.product_id),
136 ),
137 ("CM".to_string(), "1".to_string()),
138 ("PH".to_string(), "33".to_string()),
139 ("DT".to_string(), "256".to_string()),
140 ],
141 hostname: device.config.hostname.clone(),
142 ttl: 120,
143 subtypes: vec![format!("_S{}", short_disc)],
144 ips_v4: adv_v4,
145 ips_v6: adv_v6,
146 };
147 device.mdns.register_service(svc).await;
148
149 Ok(device)
150 }
151
152 pub(crate) fn next_counter(&self) -> u32 {
153 self.message_counter.fetch_add(1, Ordering::Relaxed)
154 }
155
156 pub async fn run(&mut self, handler: &mut dyn AppHandler) -> Result<()> {
157 let mut buf = [0u8; 4096];
158 log::info!(
159 "Device listening on {} (PIN: {})",
160 self.config.listen_address,
161 self.config.pin
162 );
163 loop {
164 let max_interval = self
165 .active_subscriptions
166 .iter()
167 .map(|sub| sub.max_interval_secs as u64)
168 .min()
169 .map(std::time::Duration::from_secs)
170 .unwrap_or_else(|| std::time::Duration::from_secs(3600));
171 let has_dirty = !self.dirty_attributes.is_empty();
172 tokio::select! {
173 result = self.socket.recv_from(&mut buf) => {
174 let (len, addr) = result?;
175 let data = buf[..len].to_vec();
176 if let Err(e) = self.handle_packet(&data, &addr, handler).await {
177 log::warn!("Error handling packet from {}: {:?}", addr, e);
178 }
179 }
180 _ = tokio::time::sleep(max_interval) => {
181 if let Err(e) = self.send_subscription_report().await {
182 log::warn!("Error sending subscription keepalive: {:?}", e);
183 }
184 }
185 _ = tokio::time::sleep(std::time::Duration::from_secs(1)), if has_dirty => {
186 if let Err(e) = self.send_subscription_report().await {
187 log::warn!("Error sending dirty subscription report: {:?}", e);
188 }
189 }
190 }
191 }
192 }
193
194 async fn handle_packet(&mut self, data: &[u8], addr: &std::net::SocketAddr, handler: &mut dyn AppHandler) -> Result<()> {
195 let (msg_header, rest) = messages::MessageHeader::decode(data)?;
196 log::debug!(
197 "Received message: session={} counter={} from {}",
198 msg_header.session_id,
199 msg_header.message_counter,
200 addr
201 );
202
203 if self.received_counters.contains(&msg_header.message_counter) {
205 log::debug!(
206 "Dropping duplicate message counter={}",
207 msg_header.message_counter
208 );
209 return Ok(());
210 }
211 self.received_counters.insert(msg_header.message_counter);
212
213 let payload = if msg_header.session_id != 0 {
215 let session = self
217 .case_sessions
218 .iter()
219 .find(|s| s.my_session_id == msg_header.session_id)
220 .or_else(|| {
221 self.pase_session
222 .as_ref()
223 .filter(|s| s.my_session_id == msg_header.session_id)
224 });
225 match session {
226 Some(ses) => {
227 let decrypted = ses.decode_message(data)?;
228 let (_, proto_rest) = messages::MessageHeader::decode(&decrypted)?;
229 proto_rest
230 }
231 None => {
232 log::debug!(
233 "No session for session_id={}, dropping",
234 msg_header.session_id
235 );
236 return Ok(());
237 }
238 }
239 } else {
240 rest
241 };
242
243 let (proto_header, proto_payload) = messages::ProtocolMessageHeader::decode(&payload)?;
244 log::debug!(
245 "Protocol: opcode=0x{:02x} protocol={} exchange={} flags=0x{:02x}",
246 proto_header.opcode,
247 proto_header.protocol_id,
248 proto_header.exchange_id,
249 proto_header.exchange_flags
250 );
251
252 if proto_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
254 && proto_header.opcode == messages::ProtocolMessageHeader::OPCODE_ACK
255 {
256 return Ok(());
257 }
258
259 match (proto_header.protocol_id, proto_header.opcode) {
260 (
261 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
262 messages::ProtocolMessageHeader::OPCODE_PBKDF_REQ,
263 ) => {
264 self.handle_pbkdf_req(addr, &msg_header, &proto_header, &proto_payload, &payload)
265 .await
266 }
267 (
268 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
269 messages::ProtocolMessageHeader::OPCODE_PASE_PAKE1,
270 ) => {
271 self.handle_pake1(addr, &msg_header, &proto_header, &proto_payload)
272 .await
273 }
274 (
275 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
276 messages::ProtocolMessageHeader::OPCODE_PASE_PAKE3,
277 ) => {
278 self.handle_pake3(addr, &msg_header, &proto_header, &proto_payload)
279 .await
280 }
281 (
282 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
283 messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA1,
284 ) => {
285 self.handle_sigma1(addr, &msg_header, &proto_header, &proto_payload)
286 .await
287 }
288 (
289 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
290 messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA3,
291 ) => {
292 self.handle_sigma3(addr, &msg_header, &proto_header, &proto_payload)
293 .await
294 }
295 (
296 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
297 messages::ProtocolMessageHeader::OPCODE_STATUS,
298 ) => {
299 self.handle_status_report(&proto_payload).await
300 }
301 (
302 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
303 messages::ProtocolMessageHeader::INTERACTION_OPCODE_INVOKE_REQ,
304 ) => {
305 self.handle_invoke_request(addr, &msg_header, &proto_header, &proto_payload, handler)
306 .await
307 }
308 (
309 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
310 messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP,
311 ) => {
312 self.handle_status_response(addr, &msg_header, &proto_header)
313 .await
314 }
315 (
316 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
317 messages::ProtocolMessageHeader::INTERACTION_OPCODE_READ_REQ,
318 ) => {
319 log::debug!("Received IM read request");
320 self.handle_read_request(addr, &msg_header, &proto_header, &proto_payload)
321 .await
322 }
323 (
324 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
325 messages::ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_REQ,
326 ) => {
327 log::debug!("Received IM subscribe request");
328 self.handle_subscribe_request(addr, &msg_header, &proto_header, &proto_payload)
329 .await
330 }
331 (
332 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
333 messages::ProtocolMessageHeader::INTERACTION_OPCODE_WRITE_REQ,
334 ) => {
335 log::debug!("Received IM write request");
336 self.handle_write_request(addr, &msg_header, &proto_header, &proto_payload)
337 .await
338 }
339
340 _ => {
341 log::warn!(
342 "Unhandled opcode: protocol={} opcode=0x{:02x}",
343 proto_header.protocol_id,
344 proto_header.opcode
345 );
346 Ok(())
347 }
348 }
349 }
350}