1use anyhow::{Context, Result};
2use std::{collections::HashMap, sync::Arc, time::Duration};
3use tokio::{net::UdpSocket, sync::Mutex};
4
5#[derive(Debug, Clone)]
6struct ConnectionInfo {
7 sender: tokio::sync::mpsc::Sender<Vec<u8>>,
8}
9
10pub struct Transport {
11 socket: Arc<UdpSocket>,
12 connections: Mutex<HashMap<String, ConnectionInfo>>,
13 remove_channel_sender: tokio::sync::mpsc::UnboundedSender<String>,
14 stop_receive_token: tokio_util::sync::CancellationToken,
15}
16
17pub struct Connection {
18 transport: Arc<Transport>,
19 remote_address: String,
20 receiver: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
21}
22
23impl Transport {
24 async fn read_from_socket_loop(
25 socket: Arc<UdpSocket>,
26 stop_receive_token: tokio_util::sync::CancellationToken,
27 self_weak: std::sync::Weak<Transport>,
28 ) -> Result<()> {
29 loop {
30 let mut buf = vec![0u8; 1024];
31 let (n, addr) = {
32 tokio::select! {
33 recv_resp = socket.recv_from(&mut buf) => recv_resp,
34 _ = stop_receive_token.cancelled() => break
35 }
36 }?;
37 buf.resize(n, 0);
38 let self_strong = self_weak
39 .upgrade()
40 .context("weakpointer to self is gone - just stop")?;
41 let cons = self_strong.connections.lock().await;
42 if let Some(c) = cons.get(&addr.to_string()) {
43 _ = c.sender.send(buf).await;
44 }
45 }
46 Ok(())
47 }
48
49 async fn read_from_delete_queue_loop(
50 mut remove_channel_receiver: tokio::sync::mpsc::UnboundedReceiver<String>,
51 self_weak: std::sync::Weak<Transport>,
52 ) -> Result<()> {
53 loop {
54 let to_remove = remove_channel_receiver.recv().await;
55 match to_remove {
56 Some(to_remove) => {
57 if to_remove.is_empty() {
58 break;
59 }
60 let self_strong = self_weak
61 .upgrade()
62 .context("weak to self is gone - just stop")?;
63 let mut cons = self_strong.connections.lock().await;
64 _ = cons.remove(&to_remove);
65 }
66 None => break,
67 }
68 }
69 Ok(())
70 }
71
72 pub async fn new(local: &str) -> Result<Arc<Self>> {
73 let socket = UdpSocket::bind(local).await?;
74 let (remove_channel_sender, remove_channel_receiver) =
75 tokio::sync::mpsc::unbounded_channel();
76 let stop_receive_token = tokio_util::sync::CancellationToken::new();
77 let stop_receive_token_child = stop_receive_token.child_token();
78 let o = Arc::new(Self {
79 socket: Arc::new(socket),
80 connections: Mutex::new(HashMap::new()),
81 remove_channel_sender,
82 stop_receive_token,
83 });
84 let self_weak = Arc::downgrade(&o.clone());
85 let socket = o.socket.clone();
86 tokio::spawn(async move {
87 _ = Self::read_from_socket_loop(socket, stop_receive_token_child, self_weak).await;
88 });
89 let self_weak = Arc::downgrade(&o.clone());
90 tokio::spawn(async move {
91 _ = Self::read_from_delete_queue_loop(remove_channel_receiver, self_weak).await;
92 });
93 Ok(o)
94 }
95
96 pub async fn create_connection(self: &Arc<Self>, remote: &str) -> Arc<Connection> {
97 let mut clock = self.connections.lock().await;
98 let (sender, receiver) = tokio::sync::mpsc::channel(32);
99 clock.insert(remote.to_owned(), ConnectionInfo { sender });
100 Arc::new(Connection {
101 transport: self.clone(),
102 remote_address: remote.to_owned(),
103 receiver: Mutex::new(receiver),
104 })
105 }
106}
107
108impl Connection {
109 pub async fn send(&self, data: &[u8]) -> Result<()> {
110 self.transport
111 .socket
112 .send_to(data, &self.remote_address)
113 .await?;
114 Ok(())
115 }
116 pub async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
117 let mut ch = self.receiver.lock().await;
118 let rec_future = ch.recv();
119 let with_timeout = tokio::time::timeout(timeout, rec_future);
120 with_timeout.await?.context("eof")
121 }
122}
123
124impl Drop for Transport {
125 fn drop(&mut self) {
126 _ = self.remove_channel_sender.send("".to_owned());
127 self.stop_receive_token.cancel();
128 }
129}
130
131impl Drop for Connection {
132 fn drop(&mut self) {
133 _ = self
134 .transport
135 .remove_channel_sender
136 .send(self.remote_address.clone());
137 }
138}