Skip to content

Commit 40142e7

Browse files
committed
optimize stn_tun tcp
1 parent 6db8e97 commit 40142e7

File tree

4 files changed

+117
-216
lines changed

4 files changed

+117
-216
lines changed

stn/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ pin-project = "1.0"
4949
trust-dns-proto = { version = "0.20", default-features = false }
5050
lru = "0.6"
5151

52-
# stn_base = { version = "*", path = "../stn_base" }
5352
# stn_dns = { version = "*", path = "../stn_dns" }
5453
# stn_http_proxy_client = { version = "*", path = "../stn_http_proxy_client" }
5554
stn_http_proxy_server = { version = "*", path = "../stn_http_proxy_server" }

stn_tun/src/tcp.rs

Lines changed: 109 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use etherparse::{IpHeader, PacketBuilder, TcpHeader};
1+
use etherparse::{IpHeader, TcpHeader};
22
use std::{
33
io,
44
net::{SocketAddr, ToSocketAddrs},
@@ -24,229 +24,130 @@ impl super::Tun {
2424

2525
pub(crate) async fn handle_tcp(
2626
&self,
27-
ip_header: &IpHeader,
28-
tcp_header: &TcpHeader,
27+
ip_header: &mut IpHeader,
28+
tcp_header: TcpHeader,
2929
payload: &[u8],
30+
buf_len: usize,
3031
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
31-
let mut builder = {
32-
/*
33-
tun 10.1.2.3/24
34-
tcp_redirect4 10.1.2.3:1
35-
36-
raw 10.1.2.3:12345 -> 1.2.3.4:80 SYN
37-
modified 10.1.2.4:12345 -> 10.1.2.3:1 SYN
38-
raw 10.1.2.3:1 -> 10.1.2.4:12345 ACK SYN
39-
modified 1.2.3.4:80 -> 10.1.2.3:12345 ACK SYN
40-
raw 10.1.2.3:12345 -> 1.2.3.4:80 ACK
41-
modified 10.1.2.4:12345 -> 10.1.2.3:1 ACK
42-
*/
43-
match (ip_header, self.tcp_redirect4, self.tcp_redirect6) {
44-
(IpHeader::Version4(ip_header), Some((fake_saddr, fake_daddr)), _) => {
45-
let saddr = (fake_saddr.octets(), tcp_header.source_port).into();
46-
let daddr = (ip_header.destination, tcp_header.destination_port).into();
47-
let mut tcp_map_lock = self.tcp_map.lock().unwrap();
48-
49-
if tcp_header.syn && !tcp_header.ack {
50-
tcp_map_lock.insert(saddr, (daddr, TcpStatus::Established));
51-
}
32+
if let Some(buf) = self.handle_buf(ip_header, tcp_header, payload, buf_len)? {
33+
self.write(&buf).await?;
34+
}
5235

53-
if let Some((_, status)) = tcp_map_lock.get_mut(&saddr) {
54-
match *status {
55-
TcpStatus::Established => {
56-
if tcp_header.fin {
57-
*status = TcpStatus::FinWait;
58-
}
59-
}
60-
TcpStatus::FinWait => {
61-
if tcp_header.fin {
62-
*status = TcpStatus::LastAck;
63-
}
64-
}
65-
TcpStatus::LastAck => {
66-
if tcp_header.ack {
67-
tcp_map_lock.remove(&saddr);
68-
}
69-
}
70-
}
71-
if tcp_header.rst {
72-
tcp_map_lock.remove(&saddr);
73-
}
36+
Ok(())
37+
}
7438

75-
PacketBuilder::ipv4(
76-
fake_saddr.octets(),
77-
fake_daddr.ip().octets(),
78-
ip_header.time_to_live,
79-
)
80-
.tcp(
81-
saddr.port(),
82-
fake_daddr.port(),
83-
tcp_header.sequence_number,
84-
tcp_header.window_size,
85-
)
86-
} else if let Some((SocketAddr::V4(origin_daddr), status)) =
87-
tcp_map_lock.get_mut(&daddr)
88-
{
89-
let builder = PacketBuilder::ipv4(
90-
origin_daddr.ip().octets(),
91-
fake_daddr.ip().octets(),
92-
ip_header.time_to_live,
93-
)
94-
.tcp(
95-
origin_daddr.port(),
96-
daddr.port(),
97-
tcp_header.sequence_number,
98-
tcp_header.window_size,
99-
);
39+
fn handle_buf(
40+
&self,
41+
mut ip_header: &mut IpHeader,
42+
mut tcp_header: TcpHeader,
43+
payload: &[u8],
44+
buf_len: usize,
45+
) -> Result<Option<Vec<u8>>, Box<dyn std::error::Error + Send + Sync>> {
46+
let mut tcp_map_lock = self.tcp_map.lock().unwrap();
47+
48+
/*
49+
tun 10.1.2.3/24
50+
tcp_redirect4 10.1.2.3:1
51+
52+
raw 10.1.2.3:12345 -> 1.2.3.4:80 SYN
53+
modified 10.1.2.4:12345 -> 10.1.2.3:1 SYN
54+
raw 10.1.2.3:1 -> 10.1.2.4:12345 ACK SYN
55+
modified 1.2.3.4:80 -> 10.1.2.3:12345 ACK SYN
56+
raw 10.1.2.3:12345 -> 1.2.3.4:80 ACK
57+
modified 10.1.2.4:12345 -> 10.1.2.3:1 ACK
58+
*/
59+
60+
// modify
61+
let (status, saddr) = match (&mut ip_header, self.tcp_redirect4, self.tcp_redirect6) {
62+
(IpHeader::Version4(ip_header), Some((fake_saddr, fake_daddr)), _) => {
63+
let saddr = (fake_saddr.octets(), tcp_header.source_port).into();
64+
let daddr = (ip_header.destination, tcp_header.destination_port).into();
65+
66+
// syn
67+
if tcp_header.syn && !tcp_header.ack {
68+
tcp_map_lock.insert(saddr, (daddr, TcpStatus::Established));
69+
}
10070

101-
match *status {
102-
TcpStatus::Established => {
103-
if tcp_header.fin {
104-
*status = TcpStatus::FinWait;
105-
}
106-
}
107-
TcpStatus::FinWait => {
108-
if tcp_header.fin {
109-
*status = TcpStatus::LastAck;
110-
}
111-
}
112-
TcpStatus::LastAck => {
113-
if tcp_header.ack {
114-
tcp_map_lock.remove(&saddr);
115-
}
116-
}
117-
}
118-
if tcp_header.rst {
119-
tcp_map_lock.remove(&saddr);
120-
}
71+
if let Some((_, status)) = tcp_map_lock.get_mut(&saddr) {
72+
ip_header.source = fake_saddr.octets();
73+
ip_header.destination = fake_daddr.ip().octets();
74+
ip_header.header_checksum = ip_header.calc_header_checksum()?;
75+
76+
tcp_header.source_port = saddr.port();
77+
tcp_header.destination_port = fake_daddr.port();
78+
tcp_header.checksum = tcp_header.calc_checksum_ipv4(&ip_header, payload)?;
79+
80+
(status, saddr)
81+
} else if let Some((SocketAddr::V4(origin_daddr), status)) =
82+
tcp_map_lock.get_mut(&daddr)
83+
{
84+
ip_header.source = origin_daddr.ip().octets();
85+
ip_header.destination = fake_daddr.ip().octets();
86+
ip_header.header_checksum = ip_header.calc_header_checksum()?;
87+
88+
tcp_header.source_port = origin_daddr.port();
89+
tcp_header.destination_port = daddr.port();
90+
tcp_header.checksum = tcp_header.calc_checksum_ipv4(&ip_header, payload)?;
91+
92+
(status, saddr)
93+
} else {
94+
return Ok(None);
95+
}
96+
}
97+
(IpHeader::Version6(ip_header), _, Some((fake_saddr, fake_daddr))) => {
98+
let saddr = (fake_saddr.octets(), tcp_header.source_port).into();
99+
let daddr = (ip_header.destination, tcp_header.destination_port).into();
121100

122-
builder
123-
} else {
124-
return Ok(());
125-
}
101+
// syn
102+
if tcp_header.syn && !tcp_header.ack {
103+
tcp_map_lock.insert(saddr, (daddr, TcpStatus::Established));
126104
}
127-
(IpHeader::Version6(ip_header), _, Some((fake_saddr, fake_daddr))) => {
128-
let saddr = (fake_saddr.octets(), tcp_header.source_port).into();
129-
let daddr = (ip_header.destination, tcp_header.destination_port).into();
130-
let mut tcp_map_lock = self.tcp_map.lock().unwrap();
131105

132-
if tcp_header.syn && !tcp_header.ack {
133-
tcp_map_lock.insert(saddr, (daddr, TcpStatus::Established));
134-
}
106+
if let Some((_, status)) = tcp_map_lock.get_mut(&saddr) {
107+
ip_header.source = fake_saddr.octets();
108+
ip_header.destination = fake_daddr.ip().octets();
135109

136-
if let Some((_, status)) = tcp_map_lock.get_mut(&saddr) {
137-
match *status {
138-
TcpStatus::Established => {
139-
if tcp_header.fin {
140-
*status = TcpStatus::FinWait;
141-
}
142-
}
143-
TcpStatus::FinWait => {
144-
if tcp_header.fin {
145-
*status = TcpStatus::LastAck;
146-
}
147-
}
148-
TcpStatus::LastAck => {
149-
if tcp_header.ack {
150-
tcp_map_lock.remove(&saddr);
151-
}
152-
}
153-
}
154-
if tcp_header.rst {
155-
tcp_map_lock.remove(&saddr);
156-
}
110+
tcp_header.source_port = saddr.port();
111+
tcp_header.destination_port = fake_daddr.port();
112+
tcp_header.checksum = tcp_header.calc_checksum_ipv6(&ip_header, payload)?;
157113

158-
PacketBuilder::ipv6(
159-
fake_saddr.octets(),
160-
fake_daddr.ip().octets(),
161-
ip_header.hop_limit,
162-
)
163-
.tcp(
164-
saddr.port(),
165-
fake_daddr.port(),
166-
tcp_header.sequence_number,
167-
tcp_header.window_size,
168-
)
169-
} else if let Some((SocketAddr::V6(origin_daddr), status)) =
170-
tcp_map_lock.get_mut(&daddr)
171-
{
172-
let builder = PacketBuilder::ipv6(
173-
origin_daddr.ip().octets(),
174-
fake_daddr.ip().octets(),
175-
ip_header.hop_limit,
176-
)
177-
.tcp(
178-
origin_daddr.port(),
179-
daddr.port(),
180-
tcp_header.sequence_number,
181-
tcp_header.window_size,
182-
);
114+
(status, saddr)
115+
} else if let Some((SocketAddr::V6(origin_daddr), status)) =
116+
tcp_map_lock.get_mut(&daddr)
117+
{
118+
ip_header.source = origin_daddr.ip().octets();
119+
ip_header.destination = fake_daddr.ip().octets();
183120

184-
match *status {
185-
TcpStatus::Established => {
186-
if tcp_header.fin {
187-
*status = TcpStatus::FinWait;
188-
}
189-
}
190-
TcpStatus::FinWait => {
191-
if tcp_header.fin {
192-
*status = TcpStatus::LastAck;
193-
}
194-
}
195-
TcpStatus::LastAck => {
196-
if tcp_header.ack {
197-
tcp_map_lock.remove(&saddr);
198-
}
199-
}
200-
}
201-
if tcp_header.rst {
202-
tcp_map_lock.remove(&saddr);
203-
}
121+
tcp_header.source_port = origin_daddr.port();
122+
tcp_header.destination_port = daddr.port();
123+
tcp_header.checksum = tcp_header.calc_checksum_ipv6(&ip_header, payload)?;
204124

205-
builder
206-
} else {
207-
return Ok(());
208-
}
125+
(status, saddr)
126+
} else {
127+
return Ok(None);
209128
}
210-
_ => return Ok(()),
211129
}
130+
_ => return Ok(None),
212131
};
213-
if tcp_header.ns {
214-
builder = builder.ns();
215-
}
216-
if tcp_header.fin {
217-
builder = builder.fin();
218-
}
219-
if tcp_header.syn {
220-
builder = builder.syn();
221-
}
222-
if tcp_header.rst {
223-
builder = builder.rst();
224-
}
225-
if tcp_header.psh {
226-
builder = builder.psh();
227-
}
228-
if tcp_header.ack {
229-
builder = builder.ack(tcp_header.acknowledgment_number);
230-
}
231-
if tcp_header.urg {
232-
builder = builder.urg(tcp_header.urgent_pointer);
233-
}
234-
if tcp_header.ece {
235-
builder = builder.ece();
236-
}
237-
if tcp_header.cwr {
238-
builder = builder.cwr();
239-
}
240-
builder = builder
241-
.options_raw(tcp_header.options())
242-
.map_err(|e| format!("{:?}", e))?;
243132

244-
let mut buf = Vec::with_capacity(builder.size(payload.len()));
245-
builder.write(&mut buf, payload)?;
133+
// status
134+
if tcp_header.rst || (tcp_header.ack && *status == TcpStatus::LastAck) {
135+
tcp_map_lock.remove(&saddr);
136+
} else if tcp_header.fin {
137+
if *status == TcpStatus::Established {
138+
*status = TcpStatus::FinWait;
139+
} else if *status == TcpStatus::FinWait {
140+
*status = TcpStatus::LastAck;
141+
}
142+
}
246143

247-
self.write(&buf).await?;
144+
// generate buf
145+
let mut buf = Vec::with_capacity(buf_len);
146+
ip_header.write(&mut buf)?;
147+
tcp_header.write(&mut buf)?;
148+
buf.extend(payload);
248149

249-
Ok(())
150+
Ok(Some(buf))
250151
}
251152
}
252153

@@ -260,9 +161,9 @@ pub(crate) enum TcpStatus {
260161
// ip tuntap add mode tun tun123
261162
// ifconfig tun123 inet 10.1.2.3 netmask 255.255.255.0 up
262163
//
263-
// cargo test tcp::t1 -- --nocapture
164+
// cargo test --package stn_tun tcp::t1 -- --nocapture
264165
//
265-
// curl --interface tun123 8.8.8.8
166+
// curl --interface tun123 1.2.3.4
266167
#[tokio::test]
267168
async fn t1() {
268169
use tokio::io::AsyncReadExt;

stn_tun/src/tun.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ impl Tun {
5656
let nread = self.read(&mut buf).await?;
5757

5858
// parse
59-
let ph = PacketHeaders::from_ip_slice(&buf[..nread])?;
60-
let ip_header = match &ph.ip {
59+
let mut ph = PacketHeaders::from_ip_slice(&buf[..nread])?;
60+
let ip_header = match &mut ph.ip {
6161
Some(s) => s,
6262
None => Err(io::Error::new(
6363
io::ErrorKind::InvalidData,
@@ -68,10 +68,11 @@ impl Tun {
6868
// dispatch
6969
match ph.transport {
7070
Some(etherparse::TransportHeader::Udp(udp_header)) => {
71-
self.handle_udp(ip_header, &udp_header, ph.payload).await?;
71+
self.handle_udp(ip_header, udp_header, ph.payload).await?;
7272
}
7373
Some(etherparse::TransportHeader::Tcp(tcp_header)) => {
74-
self.handle_tcp(ip_header, &tcp_header, ph.payload).await?;
74+
self.handle_tcp(ip_header, tcp_header, ph.payload, nread)
75+
.await?;
7576
}
7677
_ => continue,
7778
}

0 commit comments

Comments
 (0)