1use std::collections::HashMap;
4use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use anyhow::Result;
9use byteorder::{BigEndian, WriteBytesExt};
10use socket2::{Domain, Protocol, Type};
11use tokio::net::UdpSocket;
12use tokio::sync::mpsc::UnboundedReceiver;
13use tokio_util::sync::CancellationToken;
14
15use crate::mdns;
16
/// mDNS IPv4 multicast destination (group `224.0.0.251`, port 5353) as a socket-address string.
// NOTE(review): presumably used as `McastSocket::multicast_addr` for the IPv4 socket — confirm.
pub(super) const MDNS_ADDR_V4: &str = "224.0.0.251:5353";
/// mDNS IPv6 link-local multicast destination (group `ff02::fb`, port 5353) as a socket-address string.
// NOTE(review): presumably used as `McastSocket::multicast_addr` for the IPv6 socket — confirm.
pub(super) const MDNS_ADDR_V6: &str = "[ff02::fb]:5353";
19
20#[derive(Debug, Clone)]
21pub struct CachedRecord {
22 pub rr: mdns::RR,
23 pub received_at: Instant,
24 pub ttl: Duration,
25}
26
27impl CachedRecord {
28 fn is_expired(&self) -> bool {
29 self.received_at.elapsed() > self.ttl
30 }
31}
32
33pub struct RecordCache {
35 pub(super) entries: HashMap<(String, u16), Vec<CachedRecord>>,
36}
37
38impl RecordCache {
39 pub fn new() -> Self {
40 Self {
41 entries: HashMap::new(),
42 }
43 }
44
45 pub fn ingest(&mut self, rr: &mdns::RR) -> bool {
48 let key = (rr.name.to_lowercase(), rr.typ);
49 if rr.ttl == 0 {
50 if let Some(vec) = self.entries.get_mut(&key) {
51 vec.retain(|c| c.rr.rdata != rr.rdata);
52 if vec.is_empty() {
53 self.entries.remove(&key);
54 }
55 }
56 return false;
57 }
58 let cached = CachedRecord {
59 rr: rr.clone(),
60 received_at: Instant::now(),
61 ttl: Duration::from_secs(rr.ttl as u64),
62 };
63 let vec = self.entries.entry(key).or_default();
64 if let Some(existing) = vec.iter_mut().find(|c| c.rr.rdata == rr.rdata) {
66 *existing = cached;
67 false
68 } else {
69 vec.push(cached);
70 true
71 }
72 }
73
74 pub fn evict_expired(&mut self) -> Vec<(String, u16)> {
76 let mut expired_keys = Vec::new();
77 self.entries.retain(|key, records| {
78 records.retain(|c| !c.is_expired());
79 if records.is_empty() {
80 expired_keys.push(key.clone());
81 false
82 } else {
83 true
84 }
85 });
86 expired_keys
87 }
88
89 pub fn lookup(&self, name: &str, qtype: u16) -> Vec<mdns::RR> {
91 let key = (name.to_lowercase(), qtype);
92 self.entries
93 .get(&key)
94 .map(|v| {
95 v.iter()
96 .filter(|c| !c.is_expired())
97 .map(|c| c.rr.clone())
98 .collect()
99 })
100 .unwrap_or_default()
101 }
102
103 pub fn lookup_name(&self, name: &str) -> Vec<mdns::RR> {
105 let lower = name.to_lowercase();
106 self.entries
107 .iter()
108 .filter(|((n, _), _)| *n == lower)
109 .flat_map(|(_, v)| v.iter().filter(|c| !c.is_expired()).map(|c| c.rr.clone()))
110 .collect()
111 }
112}
113
114impl Default for RecordCache {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
/// A datagram queued for transmission by `send_loop`.
// NOTE(review): the reason for allowing `dead_code` is not visible here —
// presumably one variant is not constructed anywhere yet; confirm and document.
#[allow(dead_code)]
pub(super) enum SendCommand {
    /// Payload that `send_loop` sends to the multicast address of every socket.
    Multicast(Vec<u8>),
    /// Payload plus destination; `send_loop` tries each socket in order and
    /// stops after the first successful send.
    Unicast(Vec<u8>, std::net::SocketAddr),
}
127
128pub(super) fn encode_rr(rr: &mdns::RR, out: &mut Vec<u8>) -> Result<()> {
130 mdns::encode_label(&rr.name, out)?;
131 out.write_u16::<BigEndian>(rr.typ)?;
132 out.write_u16::<BigEndian>(rr.class)?;
133 out.write_u32::<BigEndian>(rr.ttl)?;
134
135 if rr.typ == mdns::TYPE_SRV {
136 let mut rdata = Vec::new();
138 if rr.rdata.len() >= 6 {
140 rdata.extend_from_slice(&rr.rdata[..6]);
141 } else {
142 rdata.write_u16::<BigEndian>(0)?; rdata.write_u16::<BigEndian>(0)?; rdata.write_u16::<BigEndian>(0)?; }
146 if let Some(ref target) = rr.target {
148 let mut srv_rdata = Vec::new();
150 srv_rdata.extend_from_slice(&rdata[..6]);
151 mdns::encode_label(target.trim_end_matches('.'), &mut srv_rdata)?;
152 out.write_u16::<BigEndian>(srv_rdata.len() as u16)?;
153 out.extend_from_slice(&srv_rdata);
154 } else {
155 out.write_u16::<BigEndian>(rr.rdata.len() as u16)?;
156 out.extend_from_slice(&rr.rdata);
157 }
158 } else {
159 out.write_u16::<BigEndian>(rr.rdata.len() as u16)?;
160 out.extend_from_slice(&rr.rdata);
161 }
162 Ok(())
163}
164
165pub(super) fn build_response(answers: &[mdns::RR], additional: &[mdns::RR]) -> Result<Vec<u8>> {
167 let mut out = Vec::with_capacity(512);
168 out.write_u16::<BigEndian>(0)?; out.write_u16::<BigEndian>(0x8400)?; out.write_u16::<BigEndian>(0)?; out.write_u16::<BigEndian>(answers.len() as u16)?;
172 out.write_u16::<BigEndian>(0)?; out.write_u16::<BigEndian>(additional.len() as u16)?;
174
175 for rr in answers {
176 encode_rr(rr, &mut out)?;
177 }
178 for rr in additional {
179 encode_rr(rr, &mut out)?;
180 }
181 Ok(out)
182}
183
184pub(super) fn create_multicast_socket_v4() -> Result<std::net::UdpSocket> {
185 let sock = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
186 sock.set_reuse_address(true)?;
187 #[cfg(not(target_os = "windows"))]
188 sock.set_reuse_port(true)?;
189 let addr: SocketAddrV4 = "0.0.0.0:5353".parse()?;
190 sock.bind(&socket2::SockAddr::from(addr))?;
191 let maddr: Ipv4Addr = "224.0.0.251".parse()?;
192 sock.join_multicast_v4(&maddr, &Ipv4Addr::UNSPECIFIED)?;
193 sock.set_nonblocking(true)?;
194 Ok(sock.into())
195}
196
197pub(super) fn create_multicast_socket_v6(interface: u32) -> Result<std::net::UdpSocket> {
198 let sock = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
199 sock.set_reuse_address(true)?;
200 #[cfg(not(target_os = "windows"))]
201 sock.set_reuse_port(true)?;
202 let addr: SocketAddrV6 = "[::]:5353".parse()?;
203 sock.bind(&socket2::SockAddr::from(addr))?;
204 let maddr: Ipv6Addr = "ff02::fb".parse()?;
205 sock.join_multicast_v6(&maddr, interface)?;
206 sock.set_multicast_if_v6(interface)?;
207 sock.set_nonblocking(true)?;
208 Ok(sock.into())
209}
210
211pub(super) fn get_local_ips() -> (Vec<Ipv4Addr>, Vec<Ipv6Addr>) {
213 let mut v4 = Vec::new();
214 let v6 = Vec::new();
215 if let Ok(ifaces) = if_addrs::get_if_addrs() {
216 for iface in ifaces {
217 match iface.ip() {
218 std::net::IpAddr::V4(ip) if !ip.is_loopback() => v4.push(ip),
219 _ => {}
221 }
222 }
223 }
224 (v4, v6)
225}
226
/// An async UDP socket paired with the multicast destination it sends to.
pub(super) struct McastSocket {
    /// Socket used by `send_loop` for both multicast and unicast sends.
    pub sock: Arc<UdpSocket>,
    /// Destination passed to `send_to` for multicast sends.
    // NOTE(review): presumably `MDNS_ADDR_V4` or `MDNS_ADDR_V6` — confirm at construction sites.
    pub multicast_addr: &'static str,
}
231
232pub(super) async fn send_loop(
233 sockets: Vec<McastSocket>,
234 mut rx: UnboundedReceiver<SendCommand>,
235 cancel: CancellationToken,
236) {
237 loop {
238 let cmd = tokio::select! {
239 cmd = rx.recv() => {
240 match cmd {
241 Some(c) => c,
242 None => return,
243 }
244 }
245 _ = cancel.cancelled() => return,
246 };
247
248 match cmd {
249 SendCommand::Multicast(data) => {
250 for ms in &sockets {
251 let _ = ms.sock.send_to(&data, ms.multicast_addr).await;
252 }
253 }
254 SendCommand::Unicast(data, addr) => {
255 for ms in &sockets {
257 if ms.sock.send_to(&data, addr).await.is_ok() {
258 break;
259 }
260 }
261 }
262 }
263 }
264}
265