1
- use anyhow:: { Context , Result } ;
1
+ use anyhow:: { Context as _ , Result } ;
2
2
use async_shutdown:: { ShutdownManager , ShutdownSignal } ;
3
- use std:: net:: SocketAddr ;
3
+ use std:: { net:: SocketAddr , sync :: Arc } ;
4
4
use tokio_stream:: { wrappers:: ReceiverStream , StreamExt } ;
5
5
use tonic:: { transport:: Channel , Response , Status , Streaming } ;
6
6
use tracing:: { debug, error, info, instrument, span} ;
7
7
8
8
use tokio:: {
9
9
io:: { self , AsyncRead , AsyncWrite , AsyncWriteExt , BufReader } ,
10
- net:: { TcpStream , UdpSocket } ,
11
10
select,
12
11
sync:: { mpsc, oneshot} ,
13
12
} ;
14
13
14
+ use crate :: socket:: Dialer ;
15
15
use crate :: {
16
16
constant,
17
17
io:: { StreamingReader , StreamingWriter , TrafficToServerWrapper } ,
18
18
pb:: {
19
- self , control_command:: Payload , traffic_to_server, tunnel :: Config ,
19
+ self , control_command:: Payload , traffic_to_server,
20
20
tunnel_service_client:: TunnelServiceClient , ControlCommand , RegisterReq , TrafficToClient ,
21
21
TrafficToServer ,
22
22
} ,
@@ -83,13 +83,13 @@ impl Client {
83
83
) -> Result < Vec < String > > {
84
84
let ( entrypoint_tx, entrypoint_rx) = oneshot:: channel ( ) ;
85
85
let pb_tunnel = tunnel. config . to_pb_tunnel ( tunnel. name ) ;
86
- let local_endpoint = tunnel. local_endpoint ;
86
+ let dialer = tunnel. dialer ;
87
87
88
88
tokio:: spawn ( async move {
89
89
let run_tunnel = self . handle_tunnel (
90
90
shutdown. wait_shutdown_triggered ( ) ,
91
91
pb_tunnel,
92
- local_endpoint ,
92
+ dialer ,
93
93
Some ( move |entrypoint| {
94
94
let _ = entrypoint_tx. send ( entrypoint) ;
95
95
} ) ,
@@ -184,11 +184,10 @@ impl Client {
184
184
& self ,
185
185
shutdown : ShutdownSignal < i8 > ,
186
186
tunnel : pb:: Tunnel ,
187
- local_endpoint : SocketAddr ,
187
+ dial : Dialer ,
188
188
hook : Option < impl FnOnce ( Vec < String > ) + Send + ' static > ,
189
189
) -> Result < ( ) > {
190
190
let mut rpc_client = self . grpc_client . clone ( ) ;
191
- let is_udp = matches ! ( tunnel. config, Some ( Config :: Udp ( _) ) ) ;
192
191
let register = self . register_tunnel ( & mut rpc_client, tunnel) ;
193
192
194
193
tokio:: select! {
@@ -201,8 +200,7 @@ impl Client {
201
200
shutdown. clone( ) ,
202
201
rpc_client,
203
202
register_resp,
204
- local_endpoint,
205
- is_udp,
203
+ dial,
206
204
hook,
207
205
) . await
208
206
}
@@ -216,8 +214,7 @@ impl Client {
216
214
shutdown : ShutdownSignal < i8 > ,
217
215
rpc_client : TunnelServiceClient < Channel > ,
218
216
register_resp : tonic:: Response < Streaming < ControlCommand > > ,
219
- local_endpoint : SocketAddr ,
220
- is_udp : bool ,
217
+ dialer : Dialer ,
221
218
mut hook : Option < impl FnOnce ( Vec < String > ) + Send + ' static > ,
222
219
) -> Result < ( ) > {
223
220
let mut control_stream = register_resp. into_inner ( ) ;
@@ -229,24 +226,18 @@ impl Client {
229
226
hook ( entrypoint) ;
230
227
}
231
228
232
- self . start_streaming (
233
- shutdown,
234
- & mut control_stream,
235
- rpc_client,
236
- local_endpoint,
237
- is_udp,
238
- )
239
- . await
229
+ self . start_streaming ( shutdown, & mut control_stream, rpc_client, dialer)
230
+ . await
240
231
}
241
232
242
233
async fn start_streaming (
243
234
& self ,
244
235
shutdown : ShutdownSignal < i8 > ,
245
236
control_stream : & mut Streaming < ControlCommand > ,
246
237
rpc_client : TunnelServiceClient < Channel > ,
247
- local_endpoint : SocketAddr ,
248
- is_udp : bool ,
238
+ dialer : Dialer ,
249
239
) -> Result < ( ) > {
240
+ let dialer = Arc :: new ( dialer) ;
250
241
loop {
251
242
tokio:: select! {
252
243
result = control_stream. next( ) => {
@@ -268,8 +259,7 @@ impl Client {
268
259
if let Err ( err) = handle_work_traffic(
269
260
rpc_client. clone( ) /* cheap clone operation */ ,
270
261
& work. connection_id,
271
- local_endpoint,
272
- is_udp,
262
+ dialer. clone( ) ,
273
263
) . await {
274
264
error!( err = ?err, "failed to handle work traffic" ) ;
275
265
} else {
@@ -307,8 +297,7 @@ async fn new_rpc_client(control_addr: SocketAddr) -> Result<TunnelServiceClient<
307
297
async fn handle_work_traffic (
308
298
mut rpc_client : TunnelServiceClient < Channel > ,
309
299
connection_id : & str ,
310
- local_endpoint : SocketAddr ,
311
- is_udp : bool ,
300
+ dialer : Arc < Dialer > ,
312
301
) -> Result < ( ) > {
313
302
// write response to the streaming_tx
314
303
// rpc_client sends the data from reading the streaming_rx
@@ -331,7 +320,7 @@ async fn handle_work_traffic(
331
320
332
321
// write the data streaming response to transfer_tx,
333
322
// then forward_traffic_to_local can read the data from transfer_rx
334
- let ( transfer_tx, mut transfer_rx) = mpsc:: channel :: < TrafficToClient > ( 64 ) ;
323
+ let ( transfer_tx, transfer_rx) = mpsc:: channel :: < TrafficToClient > ( 64 ) ;
335
324
336
325
let ( local_conn_established_tx, local_conn_established_rx) = mpsc:: channel :: < ( ) > ( 1 ) ;
337
326
let mut local_conn_established_rx = Some ( local_conn_established_rx) ;
@@ -375,72 +364,24 @@ async fn handle_work_traffic(
375
364
} ) ;
376
365
377
366
let wrapper = TrafficToServerWrapper :: new ( connection_id. clone ( ) ) ;
378
- let mut writer = StreamingWriter :: new ( streaming_tx. clone ( ) , wrapper) ;
379
-
380
- if is_udp {
381
- tokio:: spawn ( async move {
382
- let local_addr: SocketAddr = if local_endpoint. is_ipv4 ( ) {
383
- "0.0.0.0:0"
384
- } else {
385
- "[::]:0"
386
- }
387
- . parse ( )
388
- . unwrap ( ) ;
389
- let socket = UdpSocket :: bind ( local_addr) . await ;
390
- if socket. is_err ( ) {
391
- error ! ( err = ?socket. err( ) , "failed to init udp socket, so let's notify the server to close the user connection" ) ;
392
-
393
- streaming_tx
394
- . send ( TrafficToServer {
395
- connection_id : connection_id. to_string ( ) ,
396
- action : traffic_to_server:: Action :: Close as i32 ,
397
- ..Default :: default ( )
398
- } )
399
- . await
400
- . context ( "terrible, the server may be crashed" )
401
- . unwrap ( ) ;
402
- return ;
403
- }
404
-
405
- let socket = socket. unwrap ( ) ;
406
- let result = socket. connect ( local_endpoint) . await ;
407
- if let Err ( err) = result {
408
- error ! ( err = ?err, "failed to connect to local endpoint, so let's notify the server to close the user connection" ) ;
409
- }
367
+ let writer = StreamingWriter :: new ( streaming_tx. clone ( ) , wrapper) ;
410
368
411
- local_conn_established_tx. send ( ( ) ) . await . unwrap ( ) ;
412
-
413
- let read_transfer_send_to_local = async {
414
- while let Some ( buf) = transfer_rx. recv ( ) . await {
415
- socket. send ( & buf. data ) . await . unwrap ( ) ;
416
- }
417
- } ;
418
-
419
- let read_local_send_to_server = async {
420
- loop {
421
- let mut buf = vec ! [ 0u8 ; 65507 ] ;
422
- let result = socket. recv ( & mut buf) . await ;
423
- match result {
424
- Ok ( n) => {
425
- writer. write_all ( & buf[ ..n] ) . await . unwrap ( ) ;
426
- }
427
- Err ( err) => {
428
- error ! ( err = ?err, "failed to read from local endpoint" ) ;
429
- break ;
430
- }
431
- }
369
+ tokio:: spawn ( async move {
370
+ match dialer. dial ( ) . await {
371
+ Ok ( ( local_r, local_w) ) => {
372
+ local_conn_established_tx. send ( ( ) ) . await . unwrap ( ) ;
373
+ if let Err ( err) =
374
+ transfer ( local_r, local_w, StreamingReader :: new ( transfer_rx) , writer) . await
375
+ {
376
+ debug ! ( "failed to forward traffic to local: {:?}" , err) ;
432
377
}
433
- writer. shutdown ( ) . await . unwrap ( ) ;
434
- } ;
435
-
436
- tokio:: join!( read_transfer_send_to_local, read_local_send_to_server) ;
437
- } ) ;
438
- } else {
439
- tokio:: spawn ( async move {
440
- // TODO(sword): use a connection pool to reuse the tcp connection
441
- let local_conn = TcpStream :: connect ( local_endpoint) . await ;
442
- if local_conn. is_err ( ) {
443
- error ! ( "failed to connect to local endpoint {}, so let's notify the server to close the user connection" , local_endpoint) ;
378
+ }
379
+ Err ( err) => {
380
+ error ! (
381
+ local_endpoint = ?dialer. addr( ) ,
382
+ ?err,
383
+ "failed to connect to local endpoint, so let's notify the server to close the user connection" ,
384
+ ) ;
444
385
445
386
streaming_tx
446
387
. send ( TrafficToServer {
@@ -451,42 +392,26 @@ async fn handle_work_traffic(
451
392
. await
452
393
. context ( "terrible, the server may be crashed" )
453
394
. unwrap ( ) ;
454
- return ;
455
- }
456
-
457
- let mut local_conn = local_conn. unwrap ( ) ;
458
- let ( local_r, local_w) = local_conn. split ( ) ;
459
- local_conn_established_tx. send ( ( ) ) . await . unwrap ( ) ;
460
-
461
- if let Err ( err) = forward_traffic_to_local (
462
- local_r,
463
- local_w,
464
- StreamingReader :: new ( transfer_rx) ,
465
- writer,
466
- )
467
- . await
468
- {
469
- debug ! ( "failed to forward traffic to local: {:?}" , err) ;
470
395
}
471
- } ) ;
472
- }
396
+ }
397
+ } ) ;
473
398
474
399
Ok ( ( ) )
475
400
}
476
401
477
- /// Forwards the traffic from the server to the local endpoint.
402
+ /// transfer the traffic from the server to the local endpoint.
478
403
///
479
404
/// Try to imagine the current client is yourself,
480
405
/// your mission is to forward the traffic from the server to the local,
481
406
/// then write the original response back to the server.
482
407
/// in this process, there are two underlying connections:
483
408
/// 1. remote <=> me
484
409
/// 2. me <=> local
485
- async fn forward_traffic_to_local (
410
+ async fn transfer (
486
411
local_r : impl AsyncRead + Unpin ,
487
412
mut local_w : impl AsyncWrite + Unpin ,
488
- remote_r : StreamingReader < TrafficToClient > ,
489
- mut remote_w : StreamingWriter < TrafficToServer > ,
413
+ remote_r : impl AsyncRead + Unpin ,
414
+ mut remote_w : impl AsyncWrite + Unpin ,
490
415
) -> Result < ( ) > {
491
416
let remote_to_me_to_local = async {
492
417
// read from remote, write to local
0 commit comments