matc/device/mod.rs

//! Very experimental device implementation with many things hardcoded for testing and development purposes.
2
3mod attributes;
4mod case_handler;
5mod commissioning;
6mod crypto;
7mod interaction;
8mod pase;
9mod persist;
10mod send;
11mod types;
12
13pub use types::DeviceConfig;
14pub use attributes::{attr_get_bool, attr_set_bool};
15use types::{ActiveSubscription, CaseState, FabricInfo, PaseState, PendingChunkState, SubscribeState};
16
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU32, Ordering};
21
22use anyhow::{Ok, Result};
23use tokio::net::UdpSocket;
24
25use crate::{messages, session, tlv};
26
/// Result returned by [`AppHandler::handle_command`].
///
/// Derives the common value traits so handlers and callers can log, compare and copy
/// results freely (the enum only holds plain data).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandResult {
    /// Command succeeded; library will send an IM status response with status 0.
    Success,
    /// Command failed with the given status code.
    Error(u16),
    /// Command is not handled by this application handler.
    Unhandled,
}
36
37pub trait AppHandler: Send {
38    fn handle_command(
39        &mut self,
40        endpoint: u16,
41        cluster: u32,
42        command: u32,
43        payload: &tlv::TlvItem,
44        attributes: &mut HashMap<(u16, u32, u32), Vec<u8>>,
45        dirty_attributes: &mut HashSet<(u16, u32, u32)>,
46    ) -> CommandResult;
47}
48
49pub struct Device {
50    pub(crate) config: DeviceConfig,
51    pub(crate) socket: UdpSocket,
52    pub(crate) salt: Vec<u8>,
53    pub(crate) pbkdf_iterations: u32,
54    pub(crate) operational_key: p256::SecretKey,
55    pub(crate) message_counter: AtomicU32,
56    // Commissioning state
57    pub(crate) pase_state: Option<PaseState>,
58    pub(crate) pase_session: Option<session::Session>,
59    pub(crate) case_states: HashMap<u16, CaseState>,
60    pub(crate) case_sessions: Vec<session::Session>,
61    pub(crate) subscribe_states: Vec<SubscribeState>,
62    pub(crate) active_subscriptions: Vec<ActiveSubscription>,
63    pub(crate) pending_chunks: Vec<PendingChunkState>,
64    // Commissioned fabric table (supports multiple fabrics)
65    pub(crate) fabrics: Vec<FabricInfo>,
66    /// Next fabric index to assign (1-based, monotonically increasing).
67    pub(crate) next_fabric_index: u8,
68    /// Temporary root cert from AddTrustedRootCertificate, consumed by AddNOC.
69    pub(crate) pending_root_cert: Option<Vec<u8>>,
70    // Duplicate detection
71    pub(crate) received_counters: HashSet<u32>,
72    // Attribute store: (endpoint, cluster, attribute) -> pre-tagged TLV at context tag 2
73    pub(crate) attributes: HashMap<(u16, u32, u32), Vec<u8>>,
74    /// Attributes mutated since last subscription report was sent.
75    pub(crate) dirty_attributes: HashSet<(u16, u32, u32)>,
76    pub(crate) mdns: Arc<crate::mdns2::MdnsService>,
77    /// Extra attributes to include in persistence (registered by user code).
78    pub(crate) extra_persisted: Vec<(u16, u32, u32)>,
79}
80
81impl Device {
82    pub async fn new(config: DeviceConfig, mdns: Arc<crate::mdns2::MdnsService>) -> Result<Self> {
83        let socket = UdpSocket::bind(&config.listen_address).await?;
84        let mut salt = vec![0u8; 32];
85        rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut salt);
86        let operational_key = p256::SecretKey::random(&mut rand::thread_rng());
87        let mut device = Self {
88            config,
89            socket,
90            salt,
91            pbkdf_iterations: 1000,
92            operational_key,
93            message_counter: AtomicU32::new(rand::random()),
94            pase_state: None,
95            pase_session: None,
96            case_states: HashMap::new(),
97            case_sessions: Vec::new(),
98            subscribe_states: Vec::new(),
99            active_subscriptions: Vec::new(),
100            pending_chunks: Vec::new(),
101            fabrics: Vec::new(),
102            next_fabric_index: 1,
103            pending_root_cert: None,
104            received_counters: HashSet::new(),
105            attributes: HashMap::new(),
106            dirty_attributes: HashSet::new(),
107            mdns,
108            extra_persisted: Vec::new(),
109        };
110        device.setup_default_attributes()?;
111        device.dirty_attributes.clear();
112
113        // Register mDNS commissionable service advertisement
114        let port: u16 = device
115            .config
116            .listen_address
117            .rsplit(':')
118            .next()
119            .and_then(|p| p.parse().ok())
120            .unwrap_or(5540);
121        let short_disc = device.config.discriminator >> 8;
122        let instance_name = format!("{:016X}", rand::random::<u64>());
123        let svc = crate::mdns2::ServiceRegistration {
124            instance_name,
125            service_type: "_matterc._udp.local".to_string(),
126            port,
127            txt_records: vec![
128                ("DN".to_string(), device.config.product_name.clone()),
129                ("D".to_string(), device.config.discriminator.to_string()),
130                (
131                    "VP".to_string(),
132                    format!("{}+{}", device.config.vendor_id, device.config.product_id),
133                ),
134                ("CM".to_string(), "1".to_string()),
135                ("PH".to_string(), "33".to_string()),
136                ("DT".to_string(), "256".to_string()),
137            ],
138            hostname: device.config.hostname.clone(),
139            ttl: 120,
140            subtypes: vec![format!("_S{}", short_disc)],
141        };
142        device.mdns.register_service(svc).await;
143
144        Ok(device)
145    }
146
147    pub(crate) fn next_counter(&self) -> u32 {
148        self.message_counter.fetch_add(1, Ordering::Relaxed)
149    }
150
151    pub async fn run(&mut self, handler: &mut dyn AppHandler) -> Result<()> {
152        let mut buf = [0u8; 4096];
153        log::info!(
154            "Device listening on {} (PIN: {})",
155            self.config.listen_address,
156            self.config.pin
157        );
158        loop {
159            let max_interval = self
160                .active_subscriptions
161                .iter()
162                .map(|sub| sub.max_interval_secs as u64)
163                .min()
164                .map(std::time::Duration::from_secs)
165                .unwrap_or_else(|| std::time::Duration::from_secs(3600));
166            let has_dirty = !self.dirty_attributes.is_empty();
167            tokio::select! {
168                result = self.socket.recv_from(&mut buf) => {
169                    let (len, addr) = result?;
170                    let data = buf[..len].to_vec();
171                    if let Err(e) = self.handle_packet(&data, &addr, handler).await {
172                        log::warn!("Error handling packet from {}: {:?}", addr, e);
173                    }
174                }
175                _ = tokio::time::sleep(max_interval) => {
176                    if let Err(e) = self.send_subscription_report().await {
177                        log::warn!("Error sending subscription keepalive: {:?}", e);
178                    }
179                }
180                _ = tokio::time::sleep(std::time::Duration::from_secs(1)), if has_dirty => {
181                    if let Err(e) = self.send_subscription_report().await {
182                        log::warn!("Error sending dirty subscription report: {:?}", e);
183                    }
184                }
185            }
186        }
187    }
188
189    async fn handle_packet(&mut self, data: &[u8], addr: &std::net::SocketAddr, handler: &mut dyn AppHandler) -> Result<()> {
190        let (msg_header, rest) = messages::MessageHeader::decode(data)?;
191        log::debug!(
192            "Received message: session={} counter={} from {}",
193            msg_header.session_id,
194            msg_header.message_counter,
195            addr
196        );
197
198        // Duplicate detection
199        if self.received_counters.contains(&msg_header.message_counter) {
200            log::debug!(
201                "Dropping duplicate message counter={}",
202                msg_header.message_counter
203            );
204            return Ok(());
205        }
206        self.received_counters.insert(msg_header.message_counter);
207
208        // Try to decrypt if we have a session
209        let payload = if msg_header.session_id != 0 {
210            // Encrypted message - search CASE sessions, then PASE
211            let session = self
212                .case_sessions
213                .iter()
214                .find(|s| s.my_session_id == msg_header.session_id)
215                .or_else(|| {
216                    self.pase_session
217                        .as_ref()
218                        .filter(|s| s.my_session_id == msg_header.session_id)
219                });
220            match session {
221                Some(ses) => {
222                    let decrypted = ses.decode_message(data)?;
223                    let (_, proto_rest) = messages::MessageHeader::decode(&decrypted)?;
224                    proto_rest
225                }
226                None => {
227                    log::debug!(
228                        "No session for session_id={}, dropping",
229                        msg_header.session_id
230                    );
231                    return Ok(());
232                }
233            }
234        } else {
235            rest
236        };
237
238        let (proto_header, proto_payload) = messages::ProtocolMessageHeader::decode(&payload)?;
239        log::debug!(
240            "Protocol: opcode=0x{:02x} protocol={} exchange={} flags=0x{:02x}",
241            proto_header.opcode,
242            proto_header.protocol_id,
243            proto_header.exchange_id,
244            proto_header.exchange_flags
245        );
246
247        // Handle ACKs - just ignore
248        if proto_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
249            && proto_header.opcode == messages::ProtocolMessageHeader::OPCODE_ACK
250        {
251            return Ok(());
252        }
253
254        match (proto_header.protocol_id, proto_header.opcode) {
255            (
256                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
257                messages::ProtocolMessageHeader::OPCODE_PBKDF_REQ,
258            ) => {
259                self.handle_pbkdf_req(addr, &msg_header, &proto_header, &proto_payload, &payload)
260                    .await
261            }
262            (
263                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
264                messages::ProtocolMessageHeader::OPCODE_PASE_PAKE1,
265            ) => {
266                self.handle_pake1(addr, &msg_header, &proto_header, &proto_payload)
267                    .await
268            }
269            (
270                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
271                messages::ProtocolMessageHeader::OPCODE_PASE_PAKE3,
272            ) => {
273                self.handle_pake3(addr, &msg_header, &proto_header, &proto_payload)
274                    .await
275            }
276            (
277                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
278                messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA1,
279            ) => {
280                self.handle_sigma1(addr, &msg_header, &proto_header, &proto_payload)
281                    .await
282            }
283            (
284                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
285                messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA3,
286            ) => {
287                self.handle_sigma3(addr, &msg_header, &proto_header, &proto_payload)
288                    .await
289            }
290            (
291                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
292                messages::ProtocolMessageHeader::OPCODE_STATUS,
293            ) => {
294                self.handle_status_report(&proto_payload).await
295            }
296            (
297                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
298                messages::ProtocolMessageHeader::INTERACTION_OPCODE_INVOKE_REQ,
299            ) => {
300                self.handle_invoke_request(addr, &msg_header, &proto_header, &proto_payload, handler)
301                    .await
302            }
303            (
304                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
305                messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP,
306            ) => {
307                self.handle_status_response(addr, &msg_header, &proto_header)
308                    .await
309            }
310            (
311                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
312                messages::ProtocolMessageHeader::INTERACTION_OPCODE_READ_REQ,
313            ) => {
314                log::debug!("Received IM read request");
315                self.handle_read_request(addr, &msg_header, &proto_header, &proto_payload)
316                    .await
317            }
318            (
319                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
320                messages::ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_REQ,
321            ) => {
322                log::debug!("Received IM subscribe request");
323                self.handle_subscribe_request(addr, &msg_header, &proto_header, &proto_payload)
324                    .await
325            }
326            (
327                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
328                messages::ProtocolMessageHeader::INTERACTION_OPCODE_WRITE_REQ,
329            ) => {
330                log::debug!("Received IM write request");
331                self.handle_write_request(addr, &msg_header, &proto_header, &proto_payload)
332                    .await
333            }
334
335            _ => {
336                log::warn!(
337                    "Unhandled opcode: protocol={} opcode=0x{:02x}",
338                    proto_header.protocol_id,
339                    proto_header.opcode
340                );
341                Ok(())
342            }
343        }
344    }
345}