1use anyhow::{Context, Result};
6use std::{collections::HashMap, sync::Arc, time::Duration};
7use tokio::{net::UdpSocket, sync::Mutex};
8
9#[async_trait::async_trait]
13pub trait ConnectionTrait: Send + Sync {
14 async fn send(&self, data: &[u8]) -> Result<()>;
15 async fn receive(&self, timeout: Duration) -> Result<Vec<u8>>;
16 fn is_reliable(&self) -> bool { false }
19}
20
21#[derive(Debug, Clone)]
22struct ConnectionInfo {
23 sender: tokio::sync::mpsc::Sender<Vec<u8>>,
24}
25
26pub struct Transport {
32 socket: Arc<UdpSocket>,
33 connections: Mutex<HashMap<String, ConnectionInfo>>,
34 remove_channel_sender: tokio::sync::mpsc::UnboundedSender<String>,
35 stop_receive_token: tokio_util::sync::CancellationToken,
36}
37
38pub struct Connection {
41 transport: Arc<Transport>,
42 remote_address: String,
43 receiver: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
44}
45
46impl Transport {
47 async fn read_from_socket_loop(
48 socket: Arc<UdpSocket>,
49 stop_receive_token: tokio_util::sync::CancellationToken,
50 self_weak: std::sync::Weak<Transport>,
51 ) -> Result<()> {
52 loop {
53 let mut buf = vec![0u8; 1024];
54 let (n, addr) = {
55 tokio::select! {
56 recv_resp = socket.recv_from(&mut buf) => recv_resp,
57 _ = stop_receive_token.cancelled() => break
58 }
59 }?;
60 buf.resize(n, 0);
61 let self_strong = self_weak
62 .upgrade()
63 .context("weakpointer to self is gone - just stop")?;
64 let cons = self_strong.connections.lock().await;
65 if let Some(c) = cons.get(&addr.to_string()) {
66 _ = c.sender.send(buf).await;
67 }
68 }
69 Ok(())
70 }
71
72 async fn read_from_delete_queue_loop(
73 mut remove_channel_receiver: tokio::sync::mpsc::UnboundedReceiver<String>,
74 self_weak: std::sync::Weak<Transport>,
75 ) -> Result<()> {
76 loop {
77 let to_remove = remove_channel_receiver.recv().await;
78 match to_remove {
79 Some(to_remove) => {
80 if to_remove.is_empty() {
81 break;
83 }
84 let self_strong = self_weak
85 .upgrade()
86 .context("weak to self is gone - just stop")?;
87 let mut cons = self_strong.connections.lock().await;
88 _ = cons.remove(&to_remove);
89 }
90 None => break, }
92 }
93 Ok(())
94 }
95
96 pub async fn new(local: &str) -> Result<Arc<Self>> {
98 let socket = UdpSocket::bind(local).await?;
99 let (remove_channel_sender, remove_channel_receiver) =
100 tokio::sync::mpsc::unbounded_channel();
101 let stop_receive_token = tokio_util::sync::CancellationToken::new();
102 let stop_receive_token_child = stop_receive_token.child_token();
103 let o = Arc::new(Self {
104 socket: Arc::new(socket),
105 connections: Mutex::new(HashMap::new()),
106 remove_channel_sender,
107 stop_receive_token,
108 });
109 let self_weak = Arc::downgrade(&o.clone());
110 let socket = o.socket.clone();
111 tokio::spawn(async move {
112 _ = Self::read_from_socket_loop(socket, stop_receive_token_child, self_weak).await;
113 });
114 let self_weak = Arc::downgrade(&o.clone());
115 tokio::spawn(async move {
116 _ = Self::read_from_delete_queue_loop(remove_channel_receiver, self_weak).await;
117 });
118 Ok(o)
119 }
120
121 pub async fn create_connection(self: &Arc<Self>, remote: &str) -> Arc<dyn ConnectionTrait> {
123 let mut clock = self.connections.lock().await;
124 let (sender, receiver) = tokio::sync::mpsc::channel(32);
125 clock.insert(remote.to_owned(), ConnectionInfo { sender });
126 Arc::new(Connection {
127 transport: self.clone(),
128 remote_address: remote.to_owned(),
129 receiver: Mutex::new(receiver),
130 })
131 }
132}
133
134impl Connection {
135 pub async fn send(&self, data: &[u8]) -> Result<()> {
137 self.transport
138 .socket
139 .send_to(data, &self.remote_address)
140 .await?;
141 Ok(())
142 }
143 pub async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
145 let mut ch = self.receiver.lock().await;
146 let rec_future = ch.recv();
147 let with_timeout = tokio::time::timeout(timeout, rec_future);
148 with_timeout.await?.context("eof")
149 }
150}
151
152impl Drop for Transport {
153 fn drop(&mut self) {
154 _ = self.remove_channel_sender.send("".to_owned());
155 self.stop_receive_token.cancel();
156 }
157}
158
159#[async_trait::async_trait]
160impl ConnectionTrait for Connection {
161 async fn send(&self, data: &[u8]) -> Result<()> {
162 self.send(data).await
163 }
164 async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
165 self.receive(timeout).await
166 }
167}
168
169impl Drop for Connection {
170 fn drop(&mut self) {
171 _ = self
172 .transport
173 .remove_channel_sender
174 .send(self.remote_address.clone());
175 }
176}