Skip to content

Commit c60d8eb

Browse files
committed
refactor(server): port range and exclusive port
1 parent c1ba9ac commit c60d8eb

File tree

7 files changed

+172
-23
lines changed

7 files changed

+172
-23
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ fastrand = "2.1.0"
3838
prost = "0.12.6"
3939
async-shutdown = "0.2.2"
4040
httparse = "1.9.4"
41+
rand = "0.8.5"
42+
rand_chacha = "0.3.1"
4143

4244
[build-dependencies]
4345
tonic-build = "0.11.0"

src/bin/castled.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ struct Args {
4040
/// Maximum accepted port number.
4141
#[clap(long, default_value_t = 65535)]
4242
random_max_port: u16,
43+
44+
#[clap(long, default_value = "[]")]
45+
exclude_ports: Vec<u16>,
4346
}
4447

4548
#[tokio::main]
@@ -71,6 +74,7 @@ async fn main() {
7174
ip: args.ip,
7275
vhttp_behind_proxy_tls: args.vhttp_behind_proxy_tls,
7376
port_range: args.random_min_port..=args.random_max_port,
77+
exclude_ports: args.exclude_ports,
7478
},
7579
});
7680
if let Err(err) = server.run(cancel.cancelled()).await {

src/server/data_server.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::event::{self, ClientEventResponse, Payload};
1+
use crate::{
2+
event::{self, ClientEventResponse, Payload},
3+
server::port::{Available, PortManager},
4+
};
25

36
use super::{
47
tunnel::{
@@ -29,14 +32,20 @@ pub(crate) struct DataServer {
2932
http_tunnel: Http,
3033
http_registry: DynamicRegistry,
3134
entrypoint_config: EntrypointConfig,
35+
port_manager: PortManager,
3236
}
3337

3438
impl DataServer {
3539
pub(crate) fn new(vhttp_port: u16, entrypoint_config: EntrypointConfig) -> Self {
3640
let http_registry = DynamicRegistry::new();
41+
let port_manager = PortManager::new(
42+
entrypoint_config.port_range.clone(),
43+
entrypoint_config.exclude_ports.clone(),
44+
);
3745
Self {
3846
http_registry: http_registry.clone(),
3947
http_tunnel: Http::new(vhttp_port, Arc::new(Box::new(http_registry))),
48+
port_manager,
4049
entrypoint_config,
4150
}
4251
}
@@ -62,8 +71,8 @@ impl DataServer {
6271
while let Some(event) = receiver.recv().await {
6372
match event.payload {
6473
event::Payload::RegisterTcp { port } => {
65-
let result: Result<(u16, TcpListener), tonic::Status> =
66-
create_socket::<Tcp>(port, this.entrypoint_config.port_range.clone()).await;
74+
let result: Result<(Available, TcpListener), tonic::Status> =
75+
create_socket::<Tcp>(port, &mut self.port_manager).await;
6776
match result {
6877
Ok((port, listener)) => {
6978
let cancel = event.close_listener;
@@ -72,12 +81,13 @@ impl DataServer {
7281
Tcp::new(listener, conn_event_chan.clone())
7382
.serve(cancel)
7483
.await;
75-
info!(port, "tcp server closed");
84+
info!(port = *port, "tcp server closed");
7685
});
7786
event
7887
.resp
7988
.send(ClientEventResponse::registered(
80-
this.entrypoint_config.make_entrypoint(&event.payload, port),
89+
this.entrypoint_config
90+
.make_entrypoint(&event.payload, *port),
8191
))
8292
.unwrap(); // success
8393
}
@@ -90,8 +100,8 @@ impl DataServer {
90100
}
91101
}
92102
event::Payload::RegisterUdp { port } => {
93-
let result: Result<(u16, UdpSocket), tonic::Status> =
94-
create_socket::<Udp>(port, this.entrypoint_config.port_range.clone()).await;
103+
let result: Result<(Available, UdpSocket), tonic::Status> =
104+
create_socket::<Udp>(port, &mut self.port_manager).await;
95105

96106
match result {
97107
Ok((port, socket)) => {
@@ -102,12 +112,13 @@ impl DataServer {
102112
Udp::new(socket, conn_event_chan.clone())
103113
.serve(cancel)
104114
.await;
105-
info!(port, "udp server closed");
115+
info!(port = *port, "udp server closed");
106116
});
107117
event
108118
.resp
109119
.send(ClientEventResponse::registered(
110-
this.entrypoint_config.make_entrypoint(&event.payload, port),
120+
this.entrypoint_config
121+
.make_entrypoint(&event.payload, *port),
111122
))
112123
.unwrap(); // success
113124
}
@@ -179,7 +190,7 @@ impl DataServer {
179190
}
180191

181192
async fn register_http(
182-
&self,
193+
mut self,
183194
shutdown: CancellationToken,
184195
domain: Bytes,
185196
subdomain: &mut Bytes,
@@ -234,15 +245,14 @@ impl DataServer {
234245
None
235246
}
236247
} else {
237-
let result =
238-
create_socket::<Tcp>(*port, self.entrypoint_config.port_range.clone()).await;
248+
let result = create_socket::<Tcp>(*port, &mut self.port_manager).await;
239249
match result {
240-
Ok((random_port, listener)) => {
241-
*port = random_port;
250+
Ok((available_port, listener)) => {
251+
*port = *available_port;
242252
let conn_event_chan = conn_event_chan.clone();
243253
spawn(async move {
244254
Http::new(
245-
random_port,
255+
*available_port,
246256
Arc::new(Box::new(FixedRegistry::new(conn_event_chan))),
247257
)
248258
.serve_with_listener(listener, shutdown);

src/server/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod control_server;
22
mod data_server;
3+
mod port;
34
mod tunnel;
45
pub use control_server::Server;
56

@@ -24,6 +25,7 @@ pub struct EntrypointConfig {
2425
pub ip: Vec<IpAddr>,
2526
pub vhttp_behind_proxy_tls: bool,
2627
pub port_range: RangeInclusive<u16>,
28+
pub exclude_ports: Vec<u16>,
2729
}
2830

2931
impl Default for Config {
@@ -43,6 +45,7 @@ impl Default for EntrypointConfig {
4345
ip: Vec::new(),
4446
vhttp_behind_proxy_tls: false,
4547
port_range: 1024..=65535,
48+
exclude_ports: Vec::new(),
4649
}
4750
}
4851
}

src/server/port.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
use std::sync::Arc;
2+
3+
use dashmap::DashSet;
4+
use rand::prelude::*;
5+
use rand::seq::IteratorRandom;
6+
use rand_chacha::ChaCha20Rng;
7+
8+
pub struct PortManager {
9+
rng: ChaCha20Rng,
10+
pool: Arc<DashSet<u16>>,
11+
}
12+
13+
impl PortManager {
14+
pub fn new(port_range: std::ops::RangeInclusive<u16>, exclude_ports: Vec<u16>) -> Self {
15+
let ports: Vec<_> = port_range
16+
.filter(|port| !exclude_ports.contains(port))
17+
.collect();
18+
19+
let pool = DashSet::from_iter(ports);
20+
let rng = ChaCha20Rng::from_entropy();
21+
Self {
22+
rng,
23+
pool: Arc::new(pool),
24+
}
25+
}
26+
27+
pub fn get(&mut self) -> Option<Available> {
28+
let port = *self.pool.iter().choose(&mut self.rng)?;
29+
self.pool.remove(&port);
30+
Some(Available {
31+
port,
32+
pool: Some(Arc::clone(&self.pool)),
33+
})
34+
}
35+
36+
pub fn remove(&mut self, port: u16) {
37+
self.pool.remove(&port);
38+
}
39+
}
40+
41+
#[derive(Debug)]
42+
pub struct Available {
43+
port: u16,
44+
/// the pool may not exist if creates `Available` directly.
45+
pool: Option<Arc<DashSet<u16>>>,
46+
}
47+
48+
impl Drop for Available {
49+
fn drop(&mut self) {
50+
if let Some(pool) = &self.pool {
51+
pool.insert(self.port);
52+
}
53+
}
54+
}
55+
56+
impl std::ops::Deref for Available {
57+
type Target = u16;
58+
59+
fn deref(&self) -> &Self::Target {
60+
&self.port
61+
}
62+
}
63+
64+
impl From<u16> for Available {
65+
fn from(port: u16) -> Self {
66+
Self { port, pool: None }
67+
}
68+
}
69+
70+
impl From<Available> for u16 {
71+
fn from(available: Available) -> Self {
72+
available.port
73+
}
74+
}
75+
76+
#[cfg(test)]
77+
mod test {
78+
use std::ops::RangeInclusive;
79+
80+
use super::*;
81+
82+
#[test]
83+
fn test_port_manager() {
84+
let port_range: RangeInclusive<u16> = 2000..=2100;
85+
let exclude_ports = vec![2010, 2020];
86+
let len = port_range.len() - exclude_ports.len();
87+
let mut port_manager = PortManager::new(port_range, exclude_ports.clone());
88+
89+
for _ in 0..10000 {
90+
let port = port_manager.get();
91+
assert!(port.is_some());
92+
// auto drop
93+
}
94+
95+
let ports: DashSet<u16> = Default::default();
96+
for _ in 0..1000 {
97+
let port = port_manager.get();
98+
if port.is_none() {
99+
break;
100+
}
101+
ports.insert(*port.unwrap());
102+
}
103+
assert_eq!(ports.len(), len);
104+
assert!(!ports.contains(&2010));
105+
assert!(!ports.contains(&2020));
106+
107+
drop(ports);
108+
109+
port_manager.remove(2000);
110+
let ports: DashSet<u16> = Default::default();
111+
for _ in 0..1000 {
112+
let port = port_manager.get();
113+
if port.is_none() {
114+
break;
115+
}
116+
ports.insert(*port.unwrap());
117+
}
118+
assert_eq!(ports.len(), len - 1);
119+
assert!(!ports.contains(&2010));
120+
assert!(!ports.contains(&2020));
121+
assert!(!ports.contains(&2000));
122+
}
123+
}

src/server/tunnel/mod.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::ops::RangeInclusive;
2-
31
use crate::{
42
bridge::{self, DataSenderBridge, IdDataSenderBridge},
53
event,
@@ -11,6 +9,8 @@ use tokio_util::sync::CancellationToken;
119
use tonic::Status;
1210
use uuid::Uuid;
1311

12+
use super::port::{Available, PortManager};
13+
1414
pub(crate) mod http;
1515
pub(crate) mod tcp;
1616
pub(crate) mod udp;
@@ -87,21 +87,26 @@ pub(crate) trait SocketCreator {
8787

8888
pub(crate) async fn create_socket<T: SocketCreator>(
8989
port: u16,
90-
free_port_range: RangeInclusive<u16>,
91-
) -> anyhow::Result<(u16, T::Output), Status> {
90+
port_manager: &mut PortManager,
91+
) -> anyhow::Result<(Available, T::Output), Status> {
9292
if port > 0 {
9393
let socket = T::create_socket(port).await?;
94-
Ok((port, socket))
94+
Ok((port.into(), socket))
9595
} else {
9696
// refer: https://github.com/ekzhang/bore/blob/v0.5.1/src/server.rs#L88
9797
// todo: a better way to find a free port
9898
for _ in 0..150 {
99-
let freeport = fastrand::u16(free_port_range.clone());
100-
let result = T::create_socket(freeport).await;
99+
let port: Available = match port_manager.get().await {
100+
None => {
101+
return Err(Status::resource_exhausted("no available port"));
102+
}
103+
Some(port) => port,
104+
};
105+
let result = T::create_socket(*port).await;
101106
if result.is_err() {
102107
continue;
103108
}
104-
return Ok((freeport, result.unwrap()));
109+
return Ok((port, result.unwrap()));
105110
}
106111
Err(Status::internal("failed to find a free port"))
107112
}

0 commit comments

Comments
 (0)