Skip to content

Commit 578b685

Browse files
authored
refactor rust udp (#101)
* refactor rust udp * refactor a socket module to dial tcp and udp * refactor(server): add a dialer to tunnel to create local connection * add some docs * add test case
1 parent dd62288 commit 578b685

File tree

10 files changed

+249
-147
lines changed

10 files changed

+249
-147
lines changed

src/client/client.rs

Lines changed: 38 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
use anyhow::{Context, Result};
1+
use anyhow::{Context as _, Result};
22
use async_shutdown::{ShutdownManager, ShutdownSignal};
3-
use std::net::SocketAddr;
3+
use std::{net::SocketAddr, sync::Arc};
44
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
55
use tonic::{transport::Channel, Response, Status, Streaming};
66
use tracing::{debug, error, info, instrument, span};
77

88
use tokio::{
99
io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
10-
net::{TcpStream, UdpSocket},
1110
select,
1211
sync::{mpsc, oneshot},
1312
};
1413

14+
use crate::socket::Dialer;
1515
use crate::{
1616
constant,
1717
io::{StreamingReader, StreamingWriter, TrafficToServerWrapper},
1818
pb::{
19-
self, control_command::Payload, traffic_to_server, tunnel::Config,
19+
self, control_command::Payload, traffic_to_server,
2020
tunnel_service_client::TunnelServiceClient, ControlCommand, RegisterReq, TrafficToClient,
2121
TrafficToServer,
2222
},
@@ -83,13 +83,13 @@ impl Client {
8383
) -> Result<Vec<String>> {
8484
let (entrypoint_tx, entrypoint_rx) = oneshot::channel();
8585
let pb_tunnel = tunnel.config.to_pb_tunnel(tunnel.name);
86-
let local_endpoint = tunnel.local_endpoint;
86+
let dialer = tunnel.dialer;
8787

8888
tokio::spawn(async move {
8989
let run_tunnel = self.handle_tunnel(
9090
shutdown.wait_shutdown_triggered(),
9191
pb_tunnel,
92-
local_endpoint,
92+
dialer,
9393
Some(move |entrypoint| {
9494
let _ = entrypoint_tx.send(entrypoint);
9595
}),
@@ -184,11 +184,10 @@ impl Client {
184184
&self,
185185
shutdown: ShutdownSignal<i8>,
186186
tunnel: pb::Tunnel,
187-
local_endpoint: SocketAddr,
187+
dial: Dialer,
188188
hook: Option<impl FnOnce(Vec<String>) + Send + 'static>,
189189
) -> Result<()> {
190190
let mut rpc_client = self.grpc_client.clone();
191-
let is_udp = matches!(tunnel.config, Some(Config::Udp(_)));
192191
let register = self.register_tunnel(&mut rpc_client, tunnel);
193192

194193
tokio::select! {
@@ -201,8 +200,7 @@ impl Client {
201200
shutdown.clone(),
202201
rpc_client,
203202
register_resp,
204-
local_endpoint,
205-
is_udp,
203+
dial,
206204
hook,
207205
).await
208206
}
@@ -216,8 +214,7 @@ impl Client {
216214
shutdown: ShutdownSignal<i8>,
217215
rpc_client: TunnelServiceClient<Channel>,
218216
register_resp: tonic::Response<Streaming<ControlCommand>>,
219-
local_endpoint: SocketAddr,
220-
is_udp: bool,
217+
dialer: Dialer,
221218
mut hook: Option<impl FnOnce(Vec<String>) + Send + 'static>,
222219
) -> Result<()> {
223220
let mut control_stream = register_resp.into_inner();
@@ -229,24 +226,18 @@ impl Client {
229226
hook(entrypoint);
230227
}
231228

232-
self.start_streaming(
233-
shutdown,
234-
&mut control_stream,
235-
rpc_client,
236-
local_endpoint,
237-
is_udp,
238-
)
239-
.await
229+
self.start_streaming(shutdown, &mut control_stream, rpc_client, dialer)
230+
.await
240231
}
241232

242233
async fn start_streaming(
243234
&self,
244235
shutdown: ShutdownSignal<i8>,
245236
control_stream: &mut Streaming<ControlCommand>,
246237
rpc_client: TunnelServiceClient<Channel>,
247-
local_endpoint: SocketAddr,
248-
is_udp: bool,
238+
dialer: Dialer,
249239
) -> Result<()> {
240+
let dialer = Arc::new(dialer);
250241
loop {
251242
tokio::select! {
252243
result = control_stream.next() => {
@@ -268,8 +259,7 @@ impl Client {
268259
if let Err(err) = handle_work_traffic(
269260
rpc_client.clone() /* cheap clone operation */,
270261
&work.connection_id,
271-
local_endpoint,
272-
is_udp,
262+
dialer.clone(),
273263
).await {
274264
error!(err = ?err, "failed to handle work traffic");
275265
} else {
@@ -307,8 +297,7 @@ async fn new_rpc_client(control_addr: SocketAddr) -> Result<TunnelServiceClient<
307297
async fn handle_work_traffic(
308298
mut rpc_client: TunnelServiceClient<Channel>,
309299
connection_id: &str,
310-
local_endpoint: SocketAddr,
311-
is_udp: bool,
300+
dialer: Arc<Dialer>,
312301
) -> Result<()> {
313302
// write response to the streaming_tx
314303
// rpc_client sends the data from reading the streaming_rx
@@ -331,7 +320,7 @@ async fn handle_work_traffic(
331320

332321
// write the data streaming response to transfer_tx,
333322
// then forward_traffic_to_local can read the data from transfer_rx
334-
let (transfer_tx, mut transfer_rx) = mpsc::channel::<TrafficToClient>(64);
323+
let (transfer_tx, transfer_rx) = mpsc::channel::<TrafficToClient>(64);
335324

336325
let (local_conn_established_tx, local_conn_established_rx) = mpsc::channel::<()>(1);
337326
let mut local_conn_established_rx = Some(local_conn_established_rx);
@@ -375,72 +364,24 @@ async fn handle_work_traffic(
375364
});
376365

377366
let wrapper = TrafficToServerWrapper::new(connection_id.clone());
378-
let mut writer = StreamingWriter::new(streaming_tx.clone(), wrapper);
379-
380-
if is_udp {
381-
tokio::spawn(async move {
382-
let local_addr: SocketAddr = if local_endpoint.is_ipv4() {
383-
"0.0.0.0:0"
384-
} else {
385-
"[::]:0"
386-
}
387-
.parse()
388-
.unwrap();
389-
let socket = UdpSocket::bind(local_addr).await;
390-
if socket.is_err() {
391-
error!(err = ?socket.err(), "failed to init udp socket, so let's notify the server to close the user connection");
392-
393-
streaming_tx
394-
.send(TrafficToServer {
395-
connection_id: connection_id.to_string(),
396-
action: traffic_to_server::Action::Close as i32,
397-
..Default::default()
398-
})
399-
.await
400-
.context("terrible, the server may be crashed")
401-
.unwrap();
402-
return;
403-
}
404-
405-
let socket = socket.unwrap();
406-
let result = socket.connect(local_endpoint).await;
407-
if let Err(err) = result {
408-
error!(err = ?err, "failed to connect to local endpoint, so let's notify the server to close the user connection");
409-
}
367+
let writer = StreamingWriter::new(streaming_tx.clone(), wrapper);
410368

411-
local_conn_established_tx.send(()).await.unwrap();
412-
413-
let read_transfer_send_to_local = async {
414-
while let Some(buf) = transfer_rx.recv().await {
415-
socket.send(&buf.data).await.unwrap();
416-
}
417-
};
418-
419-
let read_local_send_to_server = async {
420-
loop {
421-
let mut buf = vec![0u8; 65507];
422-
let result = socket.recv(&mut buf).await;
423-
match result {
424-
Ok(n) => {
425-
writer.write_all(&buf[..n]).await.unwrap();
426-
}
427-
Err(err) => {
428-
error!(err = ?err, "failed to read from local endpoint");
429-
break;
430-
}
431-
}
369+
tokio::spawn(async move {
370+
match dialer.dial().await {
371+
Ok((local_r, local_w)) => {
372+
local_conn_established_tx.send(()).await.unwrap();
373+
if let Err(err) =
374+
transfer(local_r, local_w, StreamingReader::new(transfer_rx), writer).await
375+
{
376+
debug!("failed to forward traffic to local: {:?}", err);
432377
}
433-
writer.shutdown().await.unwrap();
434-
};
435-
436-
tokio::join!(read_transfer_send_to_local, read_local_send_to_server);
437-
});
438-
} else {
439-
tokio::spawn(async move {
440-
// TODO(sword): use a connection pool to reuse the tcp connection
441-
let local_conn = TcpStream::connect(local_endpoint).await;
442-
if local_conn.is_err() {
443-
error!("failed to connect to local endpoint {}, so let's notify the server to close the user connection", local_endpoint);
378+
}
379+
Err(err) => {
380+
error!(
381+
local_endpoint = ?dialer.addr(),
382+
?err,
383+
"failed to connect to local endpoint, so let's notify the server to close the user connection",
384+
);
444385

445386
streaming_tx
446387
.send(TrafficToServer {
@@ -451,42 +392,26 @@ async fn handle_work_traffic(
451392
.await
452393
.context("terrible, the server may be crashed")
453394
.unwrap();
454-
return;
455-
}
456-
457-
let mut local_conn = local_conn.unwrap();
458-
let (local_r, local_w) = local_conn.split();
459-
local_conn_established_tx.send(()).await.unwrap();
460-
461-
if let Err(err) = forward_traffic_to_local(
462-
local_r,
463-
local_w,
464-
StreamingReader::new(transfer_rx),
465-
writer,
466-
)
467-
.await
468-
{
469-
debug!("failed to forward traffic to local: {:?}", err);
470395
}
471-
});
472-
}
396+
}
397+
});
473398

474399
Ok(())
475400
}
476401

477-
/// Forwards the traffic from the server to the local endpoint.
402+
/// transfer the traffic from the server to the local endpoint.
478403
///
479404
/// Try to imagine the current client is yourself,
480405
/// your mission is to forward the traffic from the server to the local,
481406
/// then write the original response back to the server.
482407
/// in this process, there are two underlying connections:
483408
/// 1. remote <=> me
484409
/// 2. me <=> local
485-
async fn forward_traffic_to_local(
410+
async fn transfer(
486411
local_r: impl AsyncRead + Unpin,
487412
mut local_w: impl AsyncWrite + Unpin,
488-
remote_r: StreamingReader<TrafficToClient>,
489-
mut remote_w: StreamingWriter<TrafficToServer>,
413+
remote_r: impl AsyncRead + Unpin,
414+
mut remote_w: impl AsyncWrite + Unpin,
490415
) -> Result<()> {
491416
let remote_to_me_to_local = async {
492417
// read from remote, write to local

src/client/tunnel.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@ use std::net::SocketAddr;
44

55
use bytes::Bytes;
66

7-
use crate::pb::{self, tunnel, HttpConfig, TcpConfig, UdpConfig};
7+
use crate::{
8+
pb::{self, tunnel, HttpConfig, TcpConfig, UdpConfig},
9+
socket::{dial_tcp, dial_udp, Dialer},
10+
};
811

912
/// Tunnel configuration for the client.
1013
#[derive(Debug)]
1114
pub struct Tunnel<'a> {
1215
pub(crate) name: &'a str,
13-
pub(crate) local_endpoint: SocketAddr,
16+
pub(crate) dialer: Dialer,
1417
pub(crate) config: RemoteConfig<'a>,
1518
}
1619

@@ -19,7 +22,14 @@ impl<'a> Tunnel<'a> {
1922
pub fn new(name: &'a str, local_endpoint: SocketAddr, config: RemoteConfig<'a>) -> Self {
2023
Self {
2124
name,
22-
local_endpoint,
25+
dialer: Dialer::new(
26+
match config {
27+
RemoteConfig::Tcp(_) => |endpoint| Box::pin(dial_tcp(endpoint)),
28+
RemoteConfig::Udp(_) => |endpoint| Box::pin(dial_udp(endpoint)),
29+
RemoteConfig::Http(_) => |endpoint| Box::pin(dial_tcp(endpoint)),
30+
},
31+
local_endpoint,
32+
),
2333
config,
2434
}
2535
}

src/helper.rs

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,8 @@
1-
use crate::pb::RegisterReq;
2-
3-
use tokio::net::{TcpListener, UdpSocket};
41
use tonic::Status;
5-
use tracing::error;
62

7-
pub(crate) async fn create_tcp_listener(port: u16) -> Result<TcpListener, Status> {
8-
TcpListener::bind(("0.0.0.0", port))
9-
.await
10-
.map_err(map_bind_error)
11-
}
12-
13-
pub(crate) async fn create_udp_socket(port: u16) -> Result<UdpSocket, Status> {
14-
UdpSocket::bind(("0.0.0.0", port))
15-
.await
16-
.map_err(map_bind_error)
17-
}
18-
19-
fn map_bind_error(err: std::io::Error) -> Status {
20-
match err.kind() {
21-
std::io::ErrorKind::AddrInUse => Status::already_exists("port already in use"),
22-
std::io::ErrorKind::PermissionDenied => Status::permission_denied("permission denied"),
23-
_ => {
24-
error!("failed to bind port: {}", err);
25-
Status::internal("failed to bind port")
26-
}
27-
}
28-
}
3+
use crate::pb::RegisterReq;
294

30-
pub fn validate_register_req(req: &RegisterReq) -> Option<Status> {
5+
pub(crate) fn validate_register_req(req: &RegisterReq) -> Option<Status> {
316
if req.tunnel.is_none() {
327
return Some(Status::invalid_argument("tunnel is required"));
338
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pub(crate) mod constant;
33
pub(crate) mod event;
44
pub(crate) mod helper;
55
pub(crate) mod io;
6+
pub(crate) mod socket;
67

78
pub mod pb {
89
include!("gen/message.rs");

src/server/control_server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type DataStream = Pin<Box<dyn Stream<Item = GrpcResult<TrafficToClient>> + Send>
3838
///
3939
/// We treat the control server is grpc server as well, in the concept,
4040
/// they are same thing.
41-
/// Although the grpc server provides a [`crate::protocol::pb::tunnel_service_server::TunnelService::data`],
41+
/// Although the grpc server provides a [`crate::pb::tunnel_service_server::TunnelService::data`],
4242
/// it's similar to the data server(a little), but in the `data` function body,
4343
/// the most of work is to forward the data from client to data server.
4444
/// We can understand this is a tunnel between the client and the data server.

src/server/tunnel/http.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::bridge::BridgeData;
22
use crate::event::{self, IncomingEventSender};
3-
use crate::helper::create_tcp_listener;
3+
use crate::socket::create_tcp_listener;
44

55
use super::{init_data_sender_bridge, BridgeResult};
66
use anyhow::{Context as _, Result};

src/server/tunnel/tcp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use crate::{
22
event,
3-
helper::create_tcp_listener,
43
io::{StreamingReader, StreamingWriter, VecWrapper},
54
server::tunnel::BridgeResult,
5+
socket::create_tcp_listener,
66
};
77
use anyhow::Context as _;
88
use tokio::{

src/server/tunnel/udp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{net::SocketAddr, sync::Arc};
22

3-
use crate::{bridge::BridgeData, event, helper::create_udp_socket, server::tunnel::BridgeResult};
3+
use crate::{bridge::BridgeData, event, server::tunnel::BridgeResult, socket::create_udp_socket};
44
use dashmap::DashMap;
55
use tokio::{net::UdpSocket, select, sync::mpsc};
66
use tokio_util::sync::CancellationToken;

0 commit comments

Comments
 (0)