|
1 |
| -use anyhow::{Context, Result}; |
| 1 | +use anyhow::{Context as _, Result}; |
2 | 2 | use async_shutdown::{ShutdownManager, ShutdownSignal};
|
3 |
| -use std::net::SocketAddr; |
| 3 | +use std::task::{Context, Poll}; |
| 4 | +use std::{net::SocketAddr, pin::Pin}; |
4 | 5 | use tokio_stream::{wrappers::ReceiverStream, StreamExt};
|
5 | 6 | use tonic::{transport::Channel, Response, Status, Streaming};
|
6 | 7 | use tracing::{debug, error, info, instrument, span};
|
@@ -375,7 +376,7 @@ async fn handle_work_traffic(
|
375 | 376 | });
|
376 | 377 |
|
377 | 378 | let wrapper = TrafficToServerWrapper::new(connection_id.clone());
|
378 |
| - let mut writer = StreamingWriter::new(streaming_tx.clone(), wrapper); |
| 379 | + let writer = StreamingWriter::new(streaming_tx.clone(), wrapper); |
379 | 380 |
|
380 | 381 | if is_udp {
|
381 | 382 | tokio::spawn(async move {
|
@@ -410,30 +411,16 @@ async fn handle_work_traffic(
|
410 | 411 |
|
411 | 412 | local_conn_established_tx.send(()).await.unwrap();
|
412 | 413 |
|
413 |
| - let read_transfer_send_to_local = async { |
414 |
| - while let Some(buf) = transfer_rx.recv().await { |
415 |
| - socket.send(&buf.data).await.unwrap(); |
416 |
| - } |
417 |
| - }; |
418 |
| - |
419 |
| - let read_local_send_to_server = async { |
420 |
| - loop { |
421 |
| - let mut buf = vec![0u8; 65507]; |
422 |
| - let result = socket.recv(&mut buf).await; |
423 |
| - match result { |
424 |
| - Ok(n) => { |
425 |
| - writer.write_all(&buf[..n]).await.unwrap(); |
426 |
| - } |
427 |
| - Err(err) => { |
428 |
| - error!(err = ?err, "failed to read from local endpoint"); |
429 |
| - break; |
430 |
| - } |
431 |
| - } |
432 |
| - } |
433 |
| - writer.shutdown().await.unwrap(); |
434 |
| - }; |
435 |
| - |
436 |
| - tokio::join!(read_transfer_send_to_local, read_local_send_to_server); |
| 414 | + if let Err(err) = forward_traffic_to_local( |
| 415 | + AsyncUdpSocket::new(&socket), |
| 416 | + AsyncUdpSocket::new(&socket), |
| 417 | + StreamingReader::new(transfer_rx), |
| 418 | + writer, |
| 419 | + ) |
| 420 | + .await |
| 421 | + { |
| 422 | + debug!("failed to forward traffic to local: {:?}", err); |
| 423 | + } |
437 | 424 | });
|
438 | 425 | } else {
|
439 | 426 | tokio::spawn(async move {
|
@@ -485,8 +472,8 @@ async fn handle_work_traffic(
|
485 | 472 | async fn forward_traffic_to_local(
|
486 | 473 | local_r: impl AsyncRead + Unpin,
|
487 | 474 | mut local_w: impl AsyncWrite + Unpin,
|
488 |
| - remote_r: StreamingReader<TrafficToClient>, |
489 |
| - mut remote_w: StreamingWriter<TrafficToServer>, |
| 475 | + remote_r: impl AsyncRead + Unpin, |
| 476 | + mut remote_w: impl AsyncWrite + Unpin, |
490 | 477 | ) -> Result<()> {
|
491 | 478 | let remote_to_me_to_local = async {
|
492 | 479 | // read from remote, write to local
|
@@ -528,3 +515,51 @@ async fn forward_traffic_to_local(
|
528 | 515 |
|
529 | 516 | Ok(())
|
530 | 517 | }
|
| 518 | + |
| 519 | +struct AsyncUdpSocket<'a> { |
| 520 | + socket: &'a UdpSocket, |
| 521 | +} |
| 522 | + |
| 523 | +impl<'a> AsyncUdpSocket<'a> { |
| 524 | + fn new(socket: &'a UdpSocket) -> Self { |
| 525 | + Self { socket } |
| 526 | + } |
| 527 | +} |
| 528 | + |
| 529 | +impl<'a> AsyncRead for AsyncUdpSocket<'a> { |
| 530 | + fn poll_read( |
| 531 | + self: Pin<&mut Self>, |
| 532 | + cx: &mut Context<'_>, |
| 533 | + buf: &mut io::ReadBuf<'_>, |
| 534 | + ) -> Poll<io::Result<()>> { |
| 535 | + match self.get_mut().socket.poll_recv_from(cx, buf) { |
| 536 | + Poll::Ready(Ok(_addr)) => Poll::Ready(Ok(())), |
| 537 | + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
| 538 | + Poll::Pending => Poll::Pending, |
| 539 | + } |
| 540 | + } |
| 541 | +} |
| 542 | + |
| 543 | +impl<'a> AsyncWrite for AsyncUdpSocket<'a> { |
| 544 | + fn poll_write( |
| 545 | + self: Pin<&mut Self>, |
| 546 | + cx: &mut Context<'_>, |
| 547 | + buf: &[u8], |
| 548 | + ) -> Poll<std::result::Result<usize, std::io::Error>> { |
| 549 | + self.get_mut().socket.poll_send(cx, buf) |
| 550 | + } |
| 551 | + |
| 552 | + fn poll_flush( |
| 553 | + self: Pin<&mut Self>, |
| 554 | + _cx: &mut Context<'_>, |
| 555 | + ) -> Poll<std::result::Result<(), std::io::Error>> { |
| 556 | + Poll::Ready(Ok(())) // No-op for UDP |
| 557 | + } |
| 558 | + |
| 559 | + fn poll_shutdown( |
| 560 | + self: Pin<&mut Self>, |
| 561 | + _cx: &mut Context<'_>, |
| 562 | + ) -> Poll<std::result::Result<(), std::io::Error>> { |
| 563 | + Poll::Ready(Ok(())) // No-op for UDP |
| 564 | + } |
| 565 | +} |
0 commit comments