1mod attributes;
4mod case_handler;
5mod commissioning;
6mod crypto;
7mod interaction;
8mod pase;
9mod persist;
10mod send;
11mod types;
12
13pub use types::DeviceConfig;
14pub use attributes::{attr_get_bool, attr_set_bool};
15use types::{ActiveSubscription, CaseState, FabricInfo, PaseState, PendingChunkState, SubscribeState};
16
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU32, Ordering};
21
22use anyhow::{Ok, Result};
23use tokio::net::UdpSocket;
24
25use crate::{messages, session, tlv};
26
/// Outcome returned by [`AppHandler::handle_command`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CommandResult {
    /// The command was handled successfully.
    Success,
    /// The command failed with the given status code.
    ///
    /// NOTE(review): presumably an Interaction Model status code that ends up in the invoke
    /// response — confirm in `interaction.rs`.
    Error(u16),
    /// The handler does not implement this command.
    ///
    /// NOTE(review): what the device answers in this case is decided by the caller
    /// (`handle_invoke_request`), which is not visible here.
    Unhandled,
}
36
/// Application-specific behaviour plugged into [`Device::run`].
///
/// Within this module the handler is only forwarded to `handle_invoke_request`, so it is
/// consulted for Interaction Model invoke (command) requests.
pub trait AppHandler: Send {
    /// Handles one command addressed to `endpoint` / `cluster` / `command`.
    ///
    /// `payload` is presumably the decoded TLV fields of the command — confirm in
    /// `interaction.rs`. The handler may read and modify the device's attribute table through
    /// `attributes`. Keys it inserts into `dirty_attributes` make `Device::run` schedule a
    /// subscription report. NOTE(review): the key tuple appears to be
    /// `(endpoint, cluster, attribute id)` — confirm in `attributes.rs`.
    ///
    /// Returns how the command was handled; see [`CommandResult`].
    fn handle_command(
        &mut self,
        endpoint: u16,
        cluster: u32,
        command: u32,
        payload: &tlv::TlvItem,
        attributes: &mut HashMap<(u16, u32, u32), Vec<u8>>,
        dirty_attributes: &mut HashSet<(u16, u32, u32)>,
    ) -> CommandResult;
}
48
/// A device endpoint: one UDP socket, its session/handshake state and its attribute table.
pub struct Device {
    /// Static configuration (listen address, PIN, discriminator, product identity, hostname).
    pub(crate) config: DeviceConfig,
    /// UDP socket bound to `config.listen_address`; `run` receives every packet from it.
    pub(crate) socket: UdpSocket,
    /// 32 random bytes generated in `new`.
    /// NOTE(review): presumably the PBKDF salt offered during PASE — confirm in `pase.rs`.
    pub(crate) salt: Vec<u8>,
    /// PBKDF iteration count; `new` sets it to 1000.
    pub(crate) pbkdf_iterations: u32,
    /// Random P-256 secret key generated in `new`.
    /// NOTE(review): presumably the node's operational (CASE) key — confirm.
    pub(crate) operational_key: p256::SecretKey,
    /// Outgoing message counter; starts at a random value and is advanced by `next_counter`.
    pub(crate) message_counter: AtomicU32,
    /// In-progress PASE handshake state, if any (assumed from the type; see `pase.rs`).
    pub(crate) pase_state: Option<PaseState>,
    /// Established PASE session. `handle_packet` matches incoming session ids against its
    /// `my_session_id` after checking `case_sessions`.
    pub(crate) pase_session: Option<session::Session>,
    /// CASE handshake state. NOTE(review): presumably keyed by session id — confirm.
    pub(crate) case_states: HashMap<u16, CaseState>,
    /// Established CASE sessions, searched by `my_session_id` for incoming secured messages.
    pub(crate) case_sessions: Vec<session::Session>,
    /// Subscribe interactions in progress (assumed from the type; see `interaction.rs`).
    pub(crate) subscribe_states: Vec<SubscribeState>,
    /// Established subscriptions. `run` uses the smallest `max_interval_secs` among them as
    /// its keep-alive period.
    pub(crate) active_subscriptions: Vec<ActiveSubscription>,
    /// NOTE(review): presumably state for chunked responses still being sent — confirm.
    pub(crate) pending_chunks: Vec<PendingChunkState>,
    /// Fabric records (assumed: one per fabric the device has joined).
    pub(crate) fabrics: Vec<FabricInfo>,
    /// Initialised to 1 in `new`; presumably the index handed to the next added fabric.
    pub(crate) next_fabric_index: u8,
    /// NOTE(review): presumably a root certificate received before the fabric it belongs to
    /// is committed — confirm in `commissioning.rs`.
    pub(crate) pending_root_cert: Option<Vec<u8>>,
    /// Message counters already seen; `handle_packet` drops messages whose counter is present.
    /// NOTE(review): shared by every session and peer, and not pruned anywhere in this file.
    pub(crate) received_counters: HashSet<u32>,
    /// Attribute values as raw bytes.
    /// NOTE(review): the key appears to be `(endpoint, cluster, attribute id)` — confirm.
    pub(crate) attributes: HashMap<(u16, u32, u32), Vec<u8>>,
    /// Attributes marked as changed. `new` clears it after installing defaults. While it is
    /// non-empty, `run` arms a 1-second timer that triggers a subscription report.
    pub(crate) dirty_attributes: HashSet<(u16, u32, u32)>,
    /// Shared mDNS service; `new` registers the `_matterc._udp` service with it.
    pub(crate) mdns: Arc<crate::mdns2::MdnsService>,
    /// NOTE(review): presumably extra attribute keys to persist — confirm in `persist.rs`.
    pub(crate) extra_persisted: Vec<(u16, u32, u32)>,
}
80
81impl Device {
82 pub async fn new(config: DeviceConfig, mdns: Arc<crate::mdns2::MdnsService>) -> Result<Self> {
83 let socket = UdpSocket::bind(&config.listen_address).await?;
84 let mut salt = vec![0u8; 32];
85 rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut salt);
86 let operational_key = p256::SecretKey::random(&mut rand::thread_rng());
87 let mut device = Self {
88 config,
89 socket,
90 salt,
91 pbkdf_iterations: 1000,
92 operational_key,
93 message_counter: AtomicU32::new(rand::random()),
94 pase_state: None,
95 pase_session: None,
96 case_states: HashMap::new(),
97 case_sessions: Vec::new(),
98 subscribe_states: Vec::new(),
99 active_subscriptions: Vec::new(),
100 pending_chunks: Vec::new(),
101 fabrics: Vec::new(),
102 next_fabric_index: 1,
103 pending_root_cert: None,
104 received_counters: HashSet::new(),
105 attributes: HashMap::new(),
106 dirty_attributes: HashSet::new(),
107 mdns,
108 extra_persisted: Vec::new(),
109 };
110 device.setup_default_attributes()?;
111 device.dirty_attributes.clear();
112
113 let port: u16 = device
115 .config
116 .listen_address
117 .rsplit(':')
118 .next()
119 .and_then(|p| p.parse().ok())
120 .unwrap_or(5540);
121 let short_disc = device.config.discriminator >> 8;
122 let instance_name = format!("{:016X}", rand::random::<u64>());
123 let svc = crate::mdns2::ServiceRegistration {
124 instance_name,
125 service_type: "_matterc._udp.local".to_string(),
126 port,
127 txt_records: vec![
128 ("DN".to_string(), device.config.product_name.clone()),
129 ("D".to_string(), device.config.discriminator.to_string()),
130 (
131 "VP".to_string(),
132 format!("{}+{}", device.config.vendor_id, device.config.product_id),
133 ),
134 ("CM".to_string(), "1".to_string()),
135 ("PH".to_string(), "33".to_string()),
136 ("DT".to_string(), "256".to_string()),
137 ],
138 hostname: device.config.hostname.clone(),
139 ttl: 120,
140 subtypes: vec![format!("_S{}", short_disc)],
141 };
142 device.mdns.register_service(svc).await;
143
144 Ok(device)
145 }
146
147 pub(crate) fn next_counter(&self) -> u32 {
148 self.message_counter.fetch_add(1, Ordering::Relaxed)
149 }
150
151 pub async fn run(&mut self, handler: &mut dyn AppHandler) -> Result<()> {
152 let mut buf = [0u8; 4096];
153 log::info!(
154 "Device listening on {} (PIN: {})",
155 self.config.listen_address,
156 self.config.pin
157 );
158 loop {
159 let max_interval = self
160 .active_subscriptions
161 .iter()
162 .map(|sub| sub.max_interval_secs as u64)
163 .min()
164 .map(std::time::Duration::from_secs)
165 .unwrap_or_else(|| std::time::Duration::from_secs(3600));
166 let has_dirty = !self.dirty_attributes.is_empty();
167 tokio::select! {
168 result = self.socket.recv_from(&mut buf) => {
169 let (len, addr) = result?;
170 let data = buf[..len].to_vec();
171 if let Err(e) = self.handle_packet(&data, &addr, handler).await {
172 log::warn!("Error handling packet from {}: {:?}", addr, e);
173 }
174 }
175 _ = tokio::time::sleep(max_interval) => {
176 if let Err(e) = self.send_subscription_report().await {
177 log::warn!("Error sending subscription keepalive: {:?}", e);
178 }
179 }
180 _ = tokio::time::sleep(std::time::Duration::from_secs(1)), if has_dirty => {
181 if let Err(e) = self.send_subscription_report().await {
182 log::warn!("Error sending dirty subscription report: {:?}", e);
183 }
184 }
185 }
186 }
187 }
188
189 async fn handle_packet(&mut self, data: &[u8], addr: &std::net::SocketAddr, handler: &mut dyn AppHandler) -> Result<()> {
190 let (msg_header, rest) = messages::MessageHeader::decode(data)?;
191 log::debug!(
192 "Received message: session={} counter={} from {}",
193 msg_header.session_id,
194 msg_header.message_counter,
195 addr
196 );
197
198 if self.received_counters.contains(&msg_header.message_counter) {
200 log::debug!(
201 "Dropping duplicate message counter={}",
202 msg_header.message_counter
203 );
204 return Ok(());
205 }
206 self.received_counters.insert(msg_header.message_counter);
207
208 let payload = if msg_header.session_id != 0 {
210 let session = self
212 .case_sessions
213 .iter()
214 .find(|s| s.my_session_id == msg_header.session_id)
215 .or_else(|| {
216 self.pase_session
217 .as_ref()
218 .filter(|s| s.my_session_id == msg_header.session_id)
219 });
220 match session {
221 Some(ses) => {
222 let decrypted = ses.decode_message(data)?;
223 let (_, proto_rest) = messages::MessageHeader::decode(&decrypted)?;
224 proto_rest
225 }
226 None => {
227 log::debug!(
228 "No session for session_id={}, dropping",
229 msg_header.session_id
230 );
231 return Ok(());
232 }
233 }
234 } else {
235 rest
236 };
237
238 let (proto_header, proto_payload) = messages::ProtocolMessageHeader::decode(&payload)?;
239 log::debug!(
240 "Protocol: opcode=0x{:02x} protocol={} exchange={} flags=0x{:02x}",
241 proto_header.opcode,
242 proto_header.protocol_id,
243 proto_header.exchange_id,
244 proto_header.exchange_flags
245 );
246
247 if proto_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
249 && proto_header.opcode == messages::ProtocolMessageHeader::OPCODE_ACK
250 {
251 return Ok(());
252 }
253
254 match (proto_header.protocol_id, proto_header.opcode) {
255 (
256 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
257 messages::ProtocolMessageHeader::OPCODE_PBKDF_REQ,
258 ) => {
259 self.handle_pbkdf_req(addr, &msg_header, &proto_header, &proto_payload, &payload)
260 .await
261 }
262 (
263 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
264 messages::ProtocolMessageHeader::OPCODE_PASE_PAKE1,
265 ) => {
266 self.handle_pake1(addr, &msg_header, &proto_header, &proto_payload)
267 .await
268 }
269 (
270 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
271 messages::ProtocolMessageHeader::OPCODE_PASE_PAKE3,
272 ) => {
273 self.handle_pake3(addr, &msg_header, &proto_header, &proto_payload)
274 .await
275 }
276 (
277 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
278 messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA1,
279 ) => {
280 self.handle_sigma1(addr, &msg_header, &proto_header, &proto_payload)
281 .await
282 }
283 (
284 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
285 messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA3,
286 ) => {
287 self.handle_sigma3(addr, &msg_header, &proto_header, &proto_payload)
288 .await
289 }
290 (
291 messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
292 messages::ProtocolMessageHeader::OPCODE_STATUS,
293 ) => {
294 self.handle_status_report(&proto_payload).await
295 }
296 (
297 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
298 messages::ProtocolMessageHeader::INTERACTION_OPCODE_INVOKE_REQ,
299 ) => {
300 self.handle_invoke_request(addr, &msg_header, &proto_header, &proto_payload, handler)
301 .await
302 }
303 (
304 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
305 messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP,
306 ) => {
307 self.handle_status_response(addr, &msg_header, &proto_header)
308 .await
309 }
310 (
311 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
312 messages::ProtocolMessageHeader::INTERACTION_OPCODE_READ_REQ,
313 ) => {
314 log::debug!("Received IM read request");
315 self.handle_read_request(addr, &msg_header, &proto_header, &proto_payload)
316 .await
317 }
318 (
319 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
320 messages::ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_REQ,
321 ) => {
322 log::debug!("Received IM subscribe request");
323 self.handle_subscribe_request(addr, &msg_header, &proto_header, &proto_payload)
324 .await
325 }
326 (
327 messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
328 messages::ProtocolMessageHeader::INTERACTION_OPCODE_WRITE_REQ,
329 ) => {
330 log::debug!("Received IM write request");
331 self.handle_write_request(addr, &msg_header, &proto_header, &proto_payload)
332 .await
333 }
334
335 _ => {
336 log::warn!(
337 "Unhandled opcode: protocol={} opcode=0x{:02x}",
338 proto_header.protocol_id,
339 proto_header.opcode
340 );
341 Ok(())
342 }
343 }
344 }
345}