Skip to main content

matc/device/
mod.rs

//! Very experimental device implementation; many values are hard-coded for testing and development.
2
3mod attributes;
4mod case_handler;
5mod commissioning;
6mod crypto;
7mod interaction;
8mod pase;
9mod persist;
10mod send;
11mod types;
12
13pub use types::DeviceConfig;
14pub use attributes::AttrContext;
15use types::{ActiveSubscription, CaseState, FabricInfo, PaseState, PendingChunkState, SubscribeState};
16
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU32, Ordering};
21
22use anyhow::{Ok, Result};
23use tokio::net::UdpSocket;
24
25use crate::{messages, session, tlv};
26
/// Result returned by [`AppHandler::handle_command`].
///
/// Plain value type: it derives `Debug`/`Clone`/`Copy`/`PartialEq`/`Eq` so
/// handlers and callers can log, copy and compare outcomes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandResult {
    /// Command succeeded; library will send an IM status response with status 0.
    Success,
    /// Command failed with the given status code.
    Error(u16),
    /// Command is not handled by this application handler.
    Unhandled,
}
36
37pub trait AppHandler: Send {
38    fn handle_command(
39        &mut self,
40        endpoint: u16,
41        cluster: u32,
42        command: u32,
43        payload: &tlv::TlvItem,
44        attrs: &mut AttrContext,
45    ) -> CommandResult;
46}
47
48pub struct Device {
49    pub(crate) config: DeviceConfig,
50    pub(crate) socket: UdpSocket,
51    pub(crate) salt: Vec<u8>,
52    pub(crate) pbkdf_iterations: u32,
53    pub(crate) operational_key: p256::SecretKey,
54    pub(crate) message_counter: AtomicU32,
55    // Commissioning state
56    pub(crate) pase_state: Option<PaseState>,
57    pub(crate) pase_session: Option<session::Session>,
58    pub(crate) case_states: HashMap<u16, CaseState>,
59    pub(crate) case_sessions: Vec<session::Session>,
60    pub(crate) subscribe_states: Vec<SubscribeState>,
61    pub(crate) active_subscriptions: Vec<ActiveSubscription>,
62    pub(crate) pending_chunks: Vec<PendingChunkState>,
63    // Commissioned fabric table (supports multiple fabrics)
64    pub(crate) fabrics: Vec<FabricInfo>,
65    /// Next fabric index to assign (1-based, monotonically increasing).
66    pub(crate) next_fabric_index: u8,
67    /// Temporary root cert from AddTrustedRootCertificate, consumed by AddNOC.
68    pub(crate) pending_root_cert: Option<Vec<u8>>,
69    /// Replay detection for unencrypted (handshake) messages, keyed by
70    /// source node id. Secure sessions carry their own reception state.
71    pub(crate) unencrypted_reception: HashMap<u64, session::MessageReceptionState>,
72    /// Registered application endpoints (EP0 is always present).
73    pub(crate) endpoints: Vec<u16>,
74    // Attribute store: (endpoint, cluster, attribute) -> pre-tagged TLV at context tag 2
75    pub(crate) attributes: HashMap<(u16, u32, u32), Vec<u8>>,
76    /// Attributes mutated since last subscription report was sent.
77    pub(crate) dirty_attributes: HashSet<(u16, u32, u32)>,
78    pub(crate) mdns: Arc<crate::mdns2::MdnsService>,
79    /// Extra attributes to include in persistence (registered by user code).
80    pub(crate) extra_persisted: Vec<(u16, u32, u32)>,
81}
82
83impl Device {
84    pub async fn new(config: DeviceConfig, mdns: Arc<crate::mdns2::MdnsService>) -> Result<Self> {
85        let socket = UdpSocket::bind(&config.listen_address).await?;
86        let mut salt = vec![0u8; 32];
87        rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut salt);
88        let operational_key = p256::SecretKey::random(&mut rand::thread_rng());
89        let mut device = Self {
90            config,
91            socket,
92            salt,
93            pbkdf_iterations: 1000,
94            operational_key,
95            message_counter: AtomicU32::new(crate::util::cryptoutil::initial_message_counter()),
96            pase_state: None,
97            pase_session: None,
98            case_states: HashMap::new(),
99            case_sessions: Vec::new(),
100            subscribe_states: Vec::new(),
101            active_subscriptions: Vec::new(),
102            pending_chunks: Vec::new(),
103            fabrics: Vec::new(),
104            next_fabric_index: 1,
105            pending_root_cert: None,
106            unencrypted_reception: HashMap::new(),
107            endpoints: vec![0],
108            attributes: HashMap::new(),
109            dirty_attributes: HashSet::new(),
110            mdns,
111            extra_persisted: Vec::new(),
112        };
113        device.setup_default_attributes()?;
114        device.dirty_attributes.clear();
115
116        // Register mDNS commissionable service advertisement
117        let port: u16 = device
118            .config
119            .listen_address
120            .rsplit(':')
121            .next()
122            .and_then(|p| p.parse().ok())
123            .unwrap_or(5540);
124        let short_disc = device.config.discriminator >> 8;
125        let instance_name = format!("{:016X}", rand::random::<u64>());
126        let (adv_v4, adv_v6) = device.config.split_advertise_ips();
127        let svc = crate::mdns2::ServiceRegistration {
128            instance_name,
129            service_type: "_matterc._udp.local".to_string(),
130            port,
131            txt_records: vec![
132                ("DN".to_string(), device.config.product_name.clone()),
133                ("D".to_string(), device.config.discriminator.to_string()),
134                (
135                    "VP".to_string(),
136                    format!("{}+{}", device.config.vendor_id, device.config.product_id),
137                ),
138                ("CM".to_string(), "1".to_string()),
139                ("PH".to_string(), "33".to_string()),
140                ("DT".to_string(), "256".to_string()),
141                ("SII".to_string(), "500".to_string()),
142                ("SAI".to_string(), "300".to_string()),
143            ],
144            hostname: device.config.hostname.clone(),
145            ttl: 120,
146            subtypes: vec![format!("_S{}", short_disc)],
147            ips_v4: adv_v4,
148            ips_v6: adv_v6,
149        };
150        device.mdns.register_service(svc).await;
151
152        Ok(device)
153    }
154
155    pub(crate) fn next_counter(&self) -> u32 {
156        self.message_counter.fetch_add(1, Ordering::Relaxed)
157    }
158
159    pub async fn run(&mut self, handler: &mut dyn AppHandler) -> Result<()> {
160        let mut buf = [0u8; 4096];
161        log::info!(
162            "Device listening on {} (PIN: {})",
163            self.config.listen_address,
164            self.config.pin
165        );
166        loop {
167            let max_interval = self
168                .active_subscriptions
169                .iter()
170                .map(|sub| sub.max_interval_secs as u64)
171                .min()
172                .map(std::time::Duration::from_secs)
173                .unwrap_or_else(|| std::time::Duration::from_secs(3600));
174            let has_dirty = !self.dirty_attributes.is_empty();
175            tokio::select! {
176                result = self.socket.recv_from(&mut buf) => {
177                    let (len, addr) = result?;
178                    let data = buf[..len].to_vec();
179                    if let Err(e) = self.handle_packet(&data, &addr, handler).await {
180                        log::warn!("Error handling packet from {}: {:?}", addr, e);
181                    }
182                }
183                _ = tokio::time::sleep(max_interval) => {
184                    if let Err(e) = self.send_subscription_report().await {
185                        log::warn!("Error sending subscription keepalive: {:?}", e);
186                    }
187                }
188                _ = tokio::time::sleep(std::time::Duration::from_secs(1)), if has_dirty => {
189                    if let Err(e) = self.send_subscription_report().await {
190                        log::warn!("Error sending dirty subscription report: {:?}", e);
191                    }
192                }
193            }
194        }
195    }
196
197    async fn handle_packet(&mut self, data: &[u8], addr: &std::net::SocketAddr, handler: &mut dyn AppHandler) -> Result<()> {
198        let (msg_header, rest) = messages::MessageHeader::decode(data)?;
199        log::debug!(
200            "Received message: session={} counter={} from {}",
201            msg_header.session_id,
202            msg_header.message_counter,
203            addr
204        );
205
206        // Try to decrypt if we have a session. Replay/duplicate detection is
207        // per-session and runs only after the message authenticated.
208        let payload = if msg_header.session_id != 0 {
209            // Encrypted message - search CASE sessions, then PASE
210            let session = self
211                .case_sessions
212                .iter()
213                .find(|s| s.my_session_id == msg_header.session_id)
214                .or_else(|| {
215                    self.pase_session
216                        .as_ref()
217                        .filter(|s| s.my_session_id == msg_header.session_id)
218                });
219            match session {
220                Some(ses) => {
221                    let decrypted = ses.decode_message(data)?;
222                    if !ses.counter_is_new(msg_header.message_counter) {
223                        log::debug!(
224                            "Dropping duplicate message counter={}",
225                            msg_header.message_counter
226                        );
227                        return Ok(());
228                    }
229                    let (_, proto_rest) = messages::MessageHeader::decode(&decrypted)?;
230                    proto_rest
231                }
232                None => {
233                    log::debug!(
234                        "No session for session_id={}, dropping",
235                        msg_header.session_id
236                    );
237                    return Ok(());
238                }
239            }
240        } else {
241            let peer = msg_header
242                .source_node_id
243                .as_deref()
244                .and_then(|b| b.try_into().ok())
245                .map(u64::from_le_bytes)
246                .unwrap_or(0);
247            if self.unencrypted_reception.len() > 16 {
248                self.unencrypted_reception.clear();
249            }
250            let state = self.unencrypted_reception.entry(peer).or_default();
251            if !state.counter_is_new(msg_header.message_counter) {
252                log::debug!(
253                    "Dropping duplicate unencrypted message counter={}",
254                    msg_header.message_counter
255                );
256                return Ok(());
257            }
258            rest
259        };
260
261        let (proto_header, proto_payload) = messages::ProtocolMessageHeader::decode(&payload)?;
262        log::debug!(
263            "Protocol: opcode=0x{:02x} protocol={} exchange={} flags=0x{:02x}",
264            proto_header.opcode,
265            proto_header.protocol_id,
266            proto_header.exchange_id,
267            proto_header.exchange_flags
268        );
269
270        // Handle ACKs - just ignore
271        if proto_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
272            && proto_header.opcode == messages::ProtocolMessageHeader::OPCODE_ACK
273        {
274            return Ok(());
275        }
276
277        match (proto_header.protocol_id, proto_header.opcode) {
278            (
279                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
280                messages::ProtocolMessageHeader::OPCODE_PBKDF_REQ,
281            ) => {
282                self.handle_pbkdf_req(addr, &msg_header, &proto_header, &proto_payload, &payload)
283                    .await
284            }
285            (
286                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
287                messages::ProtocolMessageHeader::OPCODE_PASE_PAKE1,
288            ) => {
289                self.handle_pake1(addr, &msg_header, &proto_header, &proto_payload)
290                    .await
291            }
292            (
293                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
294                messages::ProtocolMessageHeader::OPCODE_PASE_PAKE3,
295            ) => {
296                self.handle_pake3(addr, &msg_header, &proto_header, &proto_payload)
297                    .await
298            }
299            (
300                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
301                messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA1,
302            ) => {
303                self.handle_sigma1(addr, &msg_header, &proto_header, &proto_payload)
304                    .await
305            }
306            (
307                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
308                messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA3,
309            ) => {
310                self.handle_sigma3(addr, &msg_header, &proto_header, &proto_payload)
311                    .await
312            }
313            (
314                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
315                messages::ProtocolMessageHeader::OPCODE_STATUS,
316            ) => {
317                self.handle_status_report(&proto_payload).await
318            }
319            (
320                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
321                messages::ProtocolMessageHeader::INTERACTION_OPCODE_INVOKE_REQ,
322            ) => {
323                self.handle_invoke_request(addr, &msg_header, &proto_header, &proto_payload, handler)
324                    .await
325            }
326            (
327                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
328                messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP,
329            ) => {
330                self.handle_status_response(addr, &msg_header, &proto_header)
331                    .await
332            }
333            (
334                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
335                messages::ProtocolMessageHeader::INTERACTION_OPCODE_READ_REQ,
336            ) => {
337                log::debug!("Received IM read request");
338                self.handle_read_request(addr, &msg_header, &proto_header, &proto_payload)
339                    .await
340            }
341            (
342                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
343                messages::ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_REQ,
344            ) => {
345                log::debug!("Received IM subscribe request");
346                self.handle_subscribe_request(addr, &msg_header, &proto_header, &proto_payload)
347                    .await
348            }
349            (
350                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
351                messages::ProtocolMessageHeader::INTERACTION_OPCODE_WRITE_REQ,
352            ) => {
353                log::debug!("Received IM write request");
354                self.handle_write_request(addr, &msg_header, &proto_header, &proto_payload)
355                    .await
356            }
357
358            _ => {
359                log::warn!(
360                    "Unhandled opcode: protocol={} opcode=0x{:02x}",
361                    proto_header.protocol_id,
362                    proto_header.opcode
363                );
364                Ok(())
365            }
366        }
367    }
368}