Skip to content

Commit 915bef4

Browse files
committed
refactor rust udp
1 parent dd62288 commit 915bef4

File tree

2 files changed

+67
-30
lines changed

2 files changed

+67
-30
lines changed

src/client/client.rs

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
use anyhow::{Context, Result};
1+
use anyhow::{Context as _, Result};
22
use async_shutdown::{ShutdownManager, ShutdownSignal};
3-
use std::net::SocketAddr;
3+
use std::task::{Context, Poll};
4+
use std::{net::SocketAddr, pin::Pin};
45
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
56
use tonic::{transport::Channel, Response, Status, Streaming};
67
use tracing::{debug, error, info, instrument, span};
@@ -12,6 +13,7 @@ use tokio::{
1213
sync::{mpsc, oneshot},
1314
};
1415

16+
use crate::io::AsyncUdpSocket;
1517
use crate::{
1618
constant,
1719
io::{StreamingReader, StreamingWriter, TrafficToServerWrapper},
@@ -331,7 +333,7 @@ async fn handle_work_traffic(
331333

332334
// write the data streaming response to transfer_tx,
333335
// then forward_traffic_to_local can read the data from transfer_rx
334-
let (transfer_tx, mut transfer_rx) = mpsc::channel::<TrafficToClient>(64);
336+
let (transfer_tx, transfer_rx) = mpsc::channel::<TrafficToClient>(64);
335337

336338
let (local_conn_established_tx, local_conn_established_rx) = mpsc::channel::<()>(1);
337339
let mut local_conn_established_rx = Some(local_conn_established_rx);
@@ -375,7 +377,7 @@ async fn handle_work_traffic(
375377
});
376378

377379
let wrapper = TrafficToServerWrapper::new(connection_id.clone());
378-
let mut writer = StreamingWriter::new(streaming_tx.clone(), wrapper);
380+
let writer = StreamingWriter::new(streaming_tx.clone(), wrapper);
379381

380382
if is_udp {
381383
tokio::spawn(async move {
@@ -410,30 +412,16 @@ async fn handle_work_traffic(
410412

411413
local_conn_established_tx.send(()).await.unwrap();
412414

413-
let read_transfer_send_to_local = async {
414-
while let Some(buf) = transfer_rx.recv().await {
415-
socket.send(&buf.data).await.unwrap();
416-
}
417-
};
418-
419-
let read_local_send_to_server = async {
420-
loop {
421-
let mut buf = vec![0u8; 65507];
422-
let result = socket.recv(&mut buf).await;
423-
match result {
424-
Ok(n) => {
425-
writer.write_all(&buf[..n]).await.unwrap();
426-
}
427-
Err(err) => {
428-
error!(err = ?err, "failed to read from local endpoint");
429-
break;
430-
}
431-
}
432-
}
433-
writer.shutdown().await.unwrap();
434-
};
435-
436-
tokio::join!(read_transfer_send_to_local, read_local_send_to_server);
415+
if let Err(err) = forward_traffic_to_local(
416+
AsyncUdpSocket::new(&socket),
417+
AsyncUdpSocket::new(&socket),
418+
StreamingReader::new(transfer_rx),
419+
writer,
420+
)
421+
.await
422+
{
423+
debug!("failed to forward traffic to local: {:?}", err);
424+
}
437425
});
438426
} else {
439427
tokio::spawn(async move {
@@ -485,8 +473,8 @@ async fn handle_work_traffic(
485473
async fn forward_traffic_to_local(
486474
local_r: impl AsyncRead + Unpin,
487475
mut local_w: impl AsyncWrite + Unpin,
488-
remote_r: StreamingReader<TrafficToClient>,
489-
mut remote_w: StreamingWriter<TrafficToServer>,
476+
remote_r: impl AsyncRead + Unpin,
477+
mut remote_w: impl AsyncWrite + Unpin,
490478
) -> Result<()> {
491479
let remote_to_me_to_local = async {
492480
// read from remote, write to local

src/io.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::fmt::Debug;
88
use std::task::{Context, Poll};
99
use tokio::io::AsyncRead;
1010
use tokio::io::AsyncWrite;
11+
use tokio::net::UdpSocket;
1112
use tokio::sync::mpsc;
1213
use tokio::{io, sync::mpsc::Sender};
1314
use tokio_util::sync::CancellationToken;
@@ -263,3 +264,51 @@ macro_rules! generate_async_write_impl {
263264

264265
generate_async_write_impl!(TrafficToServer);
265266
generate_async_write_impl!(Vec<u8>);
267+
268+
pub(crate) struct AsyncUdpSocket<'a> {
269+
socket: &'a UdpSocket,
270+
}
271+
272+
impl<'a> AsyncUdpSocket<'a> {
273+
pub(crate) fn new(socket: &'a UdpSocket) -> Self {
274+
Self { socket }
275+
}
276+
}
277+
278+
impl<'a> AsyncRead for AsyncUdpSocket<'a> {
279+
fn poll_read(
280+
self: std::pin::Pin<&mut Self>,
281+
cx: &mut Context<'_>,
282+
buf: &mut io::ReadBuf<'_>,
283+
) -> Poll<io::Result<()>> {
284+
match self.get_mut().socket.poll_recv_from(cx, buf) {
285+
Poll::Ready(Ok(_addr)) => Poll::Ready(Ok(())),
286+
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
287+
Poll::Pending => Poll::Pending,
288+
}
289+
}
290+
}
291+
292+
impl<'a> AsyncWrite for AsyncUdpSocket<'a> {
293+
fn poll_write(
294+
self: std::pin::Pin<&mut Self>,
295+
cx: &mut Context<'_>,
296+
buf: &[u8],
297+
) -> Poll<std::result::Result<usize, std::io::Error>> {
298+
self.get_mut().socket.poll_send(cx, buf)
299+
}
300+
301+
fn poll_flush(
302+
self: std::pin::Pin<&mut Self>,
303+
_cx: &mut Context<'_>,
304+
) -> Poll<std::result::Result<(), std::io::Error>> {
305+
Poll::Ready(Ok(())) // No-op for UDP
306+
}
307+
308+
fn poll_shutdown(
309+
self: std::pin::Pin<&mut Self>,
310+
_cx: &mut Context<'_>,
311+
) -> Poll<std::result::Result<(), std::io::Error>> {
312+
Poll::Ready(Ok(())) // No-op for UDP
313+
}
314+
}

0 commit comments

Comments
 (0)