Skip to content

Commit 3bc294c

Browse files
committed
refactor rust udp
1 parent dd62288 commit 3bc294c

File tree

1 file changed

+64
-29
lines changed

1 file changed

+64
-29
lines changed

src/client/client.rs

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
use anyhow::{Context, Result};
1+
use anyhow::{Context as _, Result};
22
use async_shutdown::{ShutdownManager, ShutdownSignal};
3-
use std::net::SocketAddr;
3+
use std::task::{Context, Poll};
4+
use std::{net::SocketAddr, pin::Pin};
45
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
56
use tonic::{transport::Channel, Response, Status, Streaming};
67
use tracing::{debug, error, info, instrument, span};
@@ -375,7 +376,7 @@ async fn handle_work_traffic(
375376
});
376377

377378
let wrapper = TrafficToServerWrapper::new(connection_id.clone());
378-
let mut writer = StreamingWriter::new(streaming_tx.clone(), wrapper);
379+
let writer = StreamingWriter::new(streaming_tx.clone(), wrapper);
379380

380381
if is_udp {
381382
tokio::spawn(async move {
@@ -410,30 +411,16 @@ async fn handle_work_traffic(
410411

411412
local_conn_established_tx.send(()).await.unwrap();
412413

413-
let read_transfer_send_to_local = async {
414-
while let Some(buf) = transfer_rx.recv().await {
415-
socket.send(&buf.data).await.unwrap();
416-
}
417-
};
418-
419-
let read_local_send_to_server = async {
420-
loop {
421-
let mut buf = vec![0u8; 65507];
422-
let result = socket.recv(&mut buf).await;
423-
match result {
424-
Ok(n) => {
425-
writer.write_all(&buf[..n]).await.unwrap();
426-
}
427-
Err(err) => {
428-
error!(err = ?err, "failed to read from local endpoint");
429-
break;
430-
}
431-
}
432-
}
433-
writer.shutdown().await.unwrap();
434-
};
435-
436-
tokio::join!(read_transfer_send_to_local, read_local_send_to_server);
414+
if let Err(err) = forward_traffic_to_local(
415+
AsyncUdpSocket::new(&socket),
416+
AsyncUdpSocket::new(&socket),
417+
StreamingReader::new(transfer_rx),
418+
writer,
419+
)
420+
.await
421+
{
422+
debug!("failed to forward traffic to local: {:?}", err);
423+
}
437424
});
438425
} else {
439426
tokio::spawn(async move {
@@ -485,8 +472,8 @@ async fn handle_work_traffic(
485472
async fn forward_traffic_to_local(
486473
local_r: impl AsyncRead + Unpin,
487474
mut local_w: impl AsyncWrite + Unpin,
488-
remote_r: StreamingReader<TrafficToClient>,
489-
mut remote_w: StreamingWriter<TrafficToServer>,
475+
remote_r: impl AsyncRead + Unpin,
476+
mut remote_w: impl AsyncWrite + Unpin,
490477
) -> Result<()> {
491478
let remote_to_me_to_local = async {
492479
// read from remote, write to local
@@ -528,3 +515,51 @@ async fn forward_traffic_to_local(
528515

529516
Ok(())
530517
}
518+
519+
struct AsyncUdpSocket<'a> {
520+
socket: &'a UdpSocket,
521+
}
522+
523+
impl<'a> AsyncUdpSocket<'a> {
524+
fn new(socket: &'a UdpSocket) -> Self {
525+
Self { socket }
526+
}
527+
}
528+
529+
impl<'a> AsyncRead for AsyncUdpSocket<'a> {
530+
fn poll_read(
531+
self: Pin<&mut Self>,
532+
cx: &mut Context<'_>,
533+
buf: &mut io::ReadBuf<'_>,
534+
) -> Poll<io::Result<()>> {
535+
match self.get_mut().socket.poll_recv_from(cx, buf) {
536+
Poll::Ready(Ok(_addr)) => Poll::Ready(Ok(())),
537+
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
538+
Poll::Pending => Poll::Pending,
539+
}
540+
}
541+
}
542+
543+
impl<'a> AsyncWrite for AsyncUdpSocket<'a> {
544+
fn poll_write(
545+
self: Pin<&mut Self>,
546+
cx: &mut Context<'_>,
547+
buf: &[u8],
548+
) -> Poll<std::result::Result<usize, std::io::Error>> {
549+
self.get_mut().socket.poll_send(cx, buf)
550+
}
551+
552+
fn poll_flush(
553+
self: Pin<&mut Self>,
554+
_cx: &mut Context<'_>,
555+
) -> Poll<std::result::Result<(), std::io::Error>> {
556+
Poll::Ready(Ok(())) // No-op for UDP
557+
}
558+
559+
fn poll_shutdown(
560+
self: Pin<&mut Self>,
561+
_cx: &mut Context<'_>,
562+
) -> Poll<std::result::Result<(), std::io::Error>> {
563+
Poll::Ready(Ok(())) // No-op for UDP
564+
}
565+
}

0 commit comments

Comments
 (0)