matc/device/
mod.rs

1//! very experimental device implementation with many things hardcoded for testing and development purposes.
2
3mod attributes;
4mod case_handler;
5mod commissioning;
6mod crypto;
7mod interaction;
8mod pase;
9mod persist;
10mod send;
11mod types;
12
13pub use types::DeviceConfig;
14pub use attributes::AttrContext;
15use types::{ActiveSubscription, CaseState, FabricInfo, PaseState, PendingChunkState, SubscribeState};
16
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU32, Ordering};
21
22use anyhow::{Ok, Result};
23use tokio::net::UdpSocket;
24
25use crate::{messages, session, tlv};
26
/// Result returned by [`AppHandler::handle_command`].
///
/// The type is a small `Copy` value, so handlers can return it cheaply. Callers can
/// compare it and debug-print it, for example when logging the outcome of an invoke.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandResult {
    /// Command succeeded; library will send an IM status response with status 0.
    Success,
    /// Command failed with the given status code.
    Error(u16),
    /// Command is not handled by this application handler.
    Unhandled,
}
36
/// Application hook for invoke (command) requests.
///
/// An implementation is passed to [`Device::run`] and forwarded to the device's
/// invoke-request handling. Implementors must be `Send`.
pub trait AppHandler: Send {
    /// Handles one invoked command.
    ///
    /// * `endpoint` / `cluster` / `command` identify the command being invoked.
    /// * `payload` is the command's decoded TLV fields.
    /// * `attrs` gives mutable access to an [`AttrContext`]. NOTE(review): presumably
    ///   backed by the device attribute store; confirm in the `attributes` module.
    ///
    /// Return [`CommandResult::Unhandled`] for commands this handler does not
    /// implement.
    fn handle_command(
        &mut self,
        endpoint: u16,
        cluster: u32,
        command: u32,
        payload: &tlv::TlvItem,
        attrs: &mut AttrContext,
    ) -> CommandResult;
}
47
/// Experimental device: owns the UDP socket, commissioning/session state, the
/// fabric table, the attribute store and subscription bookkeeping.
pub struct Device {
    /// Configuration supplied to [`Device::new`].
    pub(crate) config: DeviceConfig,
    /// Socket bound to `config.listen_address`; `run` receives every packet on it.
    pub(crate) socket: UdpSocket,
    /// 32 random bytes generated in `new`; presumably the PBKDF salt sent during PASE.
    pub(crate) salt: Vec<u8>,
    /// PBKDF iteration count (initialised to 1000 in `new`).
    pub(crate) pbkdf_iterations: u32,
    /// P-256 secret key generated randomly in `new`.
    pub(crate) operational_key: p256::SecretKey,
    /// Source of outgoing message counters, advanced by `next_counter`.
    pub(crate) message_counter: AtomicU32,
    // Commissioning state
    /// PASE exchange state, if one is in progress (see the `pase` module).
    pub(crate) pase_state: Option<PaseState>,
    /// PASE session; `handle_packet` falls back to it when no CASE session
    /// matches the incoming session id.
    pub(crate) pase_session: Option<session::Session>,
    /// CASE handshake state. NOTE(review): the `u16` key is presumably a session
    /// or exchange id; confirm in `case_handler`.
    pub(crate) case_states: HashMap<u16, CaseState>,
    /// CASE sessions, searched first (by `my_session_id`) when decrypting.
    pub(crate) case_sessions: Vec<session::Session>,
    /// Subscribe interactions still being set up. NOTE(review): inferred from the
    /// type name; confirm in `interaction`.
    pub(crate) subscribe_states: Vec<SubscribeState>,
    /// Established subscriptions; `run` uses their smallest `max_interval_secs`
    /// as the keepalive period.
    pub(crate) active_subscriptions: Vec<ActiveSubscription>,
    /// Chunked-report state. NOTE(review): inferred from the type name; confirm
    /// in `send`/`interaction`.
    pub(crate) pending_chunks: Vec<PendingChunkState>,
    // Commissioned fabric table (supports multiple fabrics)
    pub(crate) fabrics: Vec<FabricInfo>,
    /// Next fabric index to assign (1-based, monotonically increasing).
    pub(crate) next_fabric_index: u8,
    /// Temporary root cert from AddTrustedRootCertificate, consumed by AddNOC.
    pub(crate) pending_root_cert: Option<Vec<u8>>,
    // Duplicate detection
    /// Message counters already seen by `handle_packet`.
    /// NOTE(review): one set is shared by every session and peer, and nothing in
    /// this file prunes it. It grows without bound, and counters from unrelated
    /// sessions can collide and be treated as duplicates. A per-session window
    /// would fix both.
    pub(crate) received_counters: HashSet<u32>,
    /// Registered application endpoints (EP0 is always present).
    pub(crate) endpoints: Vec<u16>,
    /// Attribute store: (endpoint, cluster, attribute) -> pre-tagged TLV at context tag 2.
    pub(crate) attributes: HashMap<(u16, u32, u32), Vec<u8>>,
    /// Attributes mutated since last subscription report was sent.
    /// While it is non-empty, `run` schedules a subscription report.
    pub(crate) dirty_attributes: HashSet<(u16, u32, u32)>,
    /// Shared mDNS service; `new` registers the commissionable advertisement with it.
    pub(crate) mdns: Arc<crate::mdns2::MdnsService>,
    /// Extra attributes to include in persistence (registered by user code).
    pub(crate) extra_persisted: Vec<(u16, u32, u32)>,
}
81
82impl Device {
83    pub async fn new(config: DeviceConfig, mdns: Arc<crate::mdns2::MdnsService>) -> Result<Self> {
84        let socket = UdpSocket::bind(&config.listen_address).await?;
85        let mut salt = vec![0u8; 32];
86        rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut salt);
87        let operational_key = p256::SecretKey::random(&mut rand::thread_rng());
88        let mut device = Self {
89            config,
90            socket,
91            salt,
92            pbkdf_iterations: 1000,
93            operational_key,
94            message_counter: AtomicU32::new(rand::random()),
95            pase_state: None,
96            pase_session: None,
97            case_states: HashMap::new(),
98            case_sessions: Vec::new(),
99            subscribe_states: Vec::new(),
100            active_subscriptions: Vec::new(),
101            pending_chunks: Vec::new(),
102            fabrics: Vec::new(),
103            next_fabric_index: 1,
104            pending_root_cert: None,
105            received_counters: HashSet::new(),
106            endpoints: vec![0],
107            attributes: HashMap::new(),
108            dirty_attributes: HashSet::new(),
109            mdns,
110            extra_persisted: Vec::new(),
111        };
112        device.setup_default_attributes()?;
113        device.dirty_attributes.clear();
114
115        // Register mDNS commissionable service advertisement
116        let port: u16 = device
117            .config
118            .listen_address
119            .rsplit(':')
120            .next()
121            .and_then(|p| p.parse().ok())
122            .unwrap_or(5540);
123        let short_disc = device.config.discriminator >> 8;
124        let instance_name = format!("{:016X}", rand::random::<u64>());
125        let (adv_v4, adv_v6) = device.config.split_advertise_ips();
126        let svc = crate::mdns2::ServiceRegistration {
127            instance_name,
128            service_type: "_matterc._udp.local".to_string(),
129            port,
130            txt_records: vec![
131                ("DN".to_string(), device.config.product_name.clone()),
132                ("D".to_string(), device.config.discriminator.to_string()),
133                (
134                    "VP".to_string(),
135                    format!("{}+{}", device.config.vendor_id, device.config.product_id),
136                ),
137                ("CM".to_string(), "1".to_string()),
138                ("PH".to_string(), "33".to_string()),
139                ("DT".to_string(), "256".to_string()),
140            ],
141            hostname: device.config.hostname.clone(),
142            ttl: 120,
143            subtypes: vec![format!("_S{}", short_disc)],
144            ips_v4: adv_v4,
145            ips_v6: adv_v6,
146        };
147        device.mdns.register_service(svc).await;
148
149        Ok(device)
150    }
151
152    pub(crate) fn next_counter(&self) -> u32 {
153        self.message_counter.fetch_add(1, Ordering::Relaxed)
154    }
155
156    pub async fn run(&mut self, handler: &mut dyn AppHandler) -> Result<()> {
157        let mut buf = [0u8; 4096];
158        log::info!(
159            "Device listening on {} (PIN: {})",
160            self.config.listen_address,
161            self.config.pin
162        );
163        loop {
164            let max_interval = self
165                .active_subscriptions
166                .iter()
167                .map(|sub| sub.max_interval_secs as u64)
168                .min()
169                .map(std::time::Duration::from_secs)
170                .unwrap_or_else(|| std::time::Duration::from_secs(3600));
171            let has_dirty = !self.dirty_attributes.is_empty();
172            tokio::select! {
173                result = self.socket.recv_from(&mut buf) => {
174                    let (len, addr) = result?;
175                    let data = buf[..len].to_vec();
176                    if let Err(e) = self.handle_packet(&data, &addr, handler).await {
177                        log::warn!("Error handling packet from {}: {:?}", addr, e);
178                    }
179                }
180                _ = tokio::time::sleep(max_interval) => {
181                    if let Err(e) = self.send_subscription_report().await {
182                        log::warn!("Error sending subscription keepalive: {:?}", e);
183                    }
184                }
185                _ = tokio::time::sleep(std::time::Duration::from_secs(1)), if has_dirty => {
186                    if let Err(e) = self.send_subscription_report().await {
187                        log::warn!("Error sending dirty subscription report: {:?}", e);
188                    }
189                }
190            }
191        }
192    }
193
194    async fn handle_packet(&mut self, data: &[u8], addr: &std::net::SocketAddr, handler: &mut dyn AppHandler) -> Result<()> {
195        let (msg_header, rest) = messages::MessageHeader::decode(data)?;
196        log::debug!(
197            "Received message: session={} counter={} from {}",
198            msg_header.session_id,
199            msg_header.message_counter,
200            addr
201        );
202
203        // Duplicate detection
204        if self.received_counters.contains(&msg_header.message_counter) {
205            log::debug!(
206                "Dropping duplicate message counter={}",
207                msg_header.message_counter
208            );
209            return Ok(());
210        }
211        self.received_counters.insert(msg_header.message_counter);
212
213        // Try to decrypt if we have a session
214        let payload = if msg_header.session_id != 0 {
215            // Encrypted message - search CASE sessions, then PASE
216            let session = self
217                .case_sessions
218                .iter()
219                .find(|s| s.my_session_id == msg_header.session_id)
220                .or_else(|| {
221                    self.pase_session
222                        .as_ref()
223                        .filter(|s| s.my_session_id == msg_header.session_id)
224                });
225            match session {
226                Some(ses) => {
227                    let decrypted = ses.decode_message(data)?;
228                    let (_, proto_rest) = messages::MessageHeader::decode(&decrypted)?;
229                    proto_rest
230                }
231                None => {
232                    log::debug!(
233                        "No session for session_id={}, dropping",
234                        msg_header.session_id
235                    );
236                    return Ok(());
237                }
238            }
239        } else {
240            rest
241        };
242
243        let (proto_header, proto_payload) = messages::ProtocolMessageHeader::decode(&payload)?;
244        log::debug!(
245            "Protocol: opcode=0x{:02x} protocol={} exchange={} flags=0x{:02x}",
246            proto_header.opcode,
247            proto_header.protocol_id,
248            proto_header.exchange_id,
249            proto_header.exchange_flags
250        );
251
252        // Handle ACKs - just ignore
253        if proto_header.protocol_id == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
254            && proto_header.opcode == messages::ProtocolMessageHeader::OPCODE_ACK
255        {
256            return Ok(());
257        }
258
259        match (proto_header.protocol_id, proto_header.opcode) {
260            (
261                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
262                messages::ProtocolMessageHeader::OPCODE_PBKDF_REQ,
263            ) => {
264                self.handle_pbkdf_req(addr, &msg_header, &proto_header, &proto_payload, &payload)
265                    .await
266            }
267            (
268                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
269                messages::ProtocolMessageHeader::OPCODE_PASE_PAKE1,
270            ) => {
271                self.handle_pake1(addr, &msg_header, &proto_header, &proto_payload)
272                    .await
273            }
274            (
275                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
276                messages::ProtocolMessageHeader::OPCODE_PASE_PAKE3,
277            ) => {
278                self.handle_pake3(addr, &msg_header, &proto_header, &proto_payload)
279                    .await
280            }
281            (
282                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
283                messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA1,
284            ) => {
285                self.handle_sigma1(addr, &msg_header, &proto_header, &proto_payload)
286                    .await
287            }
288            (
289                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
290                messages::ProtocolMessageHeader::OPCODE_CASE_SIGMA3,
291            ) => {
292                self.handle_sigma3(addr, &msg_header, &proto_header, &proto_payload)
293                    .await
294            }
295            (
296                messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL,
297                messages::ProtocolMessageHeader::OPCODE_STATUS,
298            ) => {
299                self.handle_status_report(&proto_payload).await
300            }
301            (
302                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
303                messages::ProtocolMessageHeader::INTERACTION_OPCODE_INVOKE_REQ,
304            ) => {
305                self.handle_invoke_request(addr, &msg_header, &proto_header, &proto_payload, handler)
306                    .await
307            }
308            (
309                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
310                messages::ProtocolMessageHeader::INTERACTION_OPCODE_STATUS_RESP,
311            ) => {
312                self.handle_status_response(addr, &msg_header, &proto_header)
313                    .await
314            }
315            (
316                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
317                messages::ProtocolMessageHeader::INTERACTION_OPCODE_READ_REQ,
318            ) => {
319                log::debug!("Received IM read request");
320                self.handle_read_request(addr, &msg_header, &proto_header, &proto_payload)
321                    .await
322            }
323            (
324                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
325                messages::ProtocolMessageHeader::INTERACTION_OPCODE_SUBSCRIBE_REQ,
326            ) => {
327                log::debug!("Received IM subscribe request");
328                self.handle_subscribe_request(addr, &msg_header, &proto_header, &proto_payload)
329                    .await
330            }
331            (
332                messages::ProtocolMessageHeader::PROTOCOL_ID_INTERACTION,
333                messages::ProtocolMessageHeader::INTERACTION_OPCODE_WRITE_REQ,
334            ) => {
335                log::debug!("Received IM write request");
336                self.handle_write_request(addr, &msg_header, &proto_header, &proto_payload)
337                    .await
338            }
339
340            _ => {
341                log::warn!(
342                    "Unhandled opcode: protocol={} opcode=0x{:02x}",
343                    proto_header.protocol_id,
344                    proto_header.opcode
345                );
346                Ok(())
347            }
348        }
349    }
350}