diff --git a/Cargo.toml b/Cargo.toml index c55f8ff..a70579f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,4 @@ snafu = "0.4.1" csv = "1" serde = "1" serde_derive = "1" +tokio = { version = "1", features = ["full"] } diff --git a/src/lib.rs b/src/lib.rs index 0c376d6..54c30c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,13 +6,9 @@ extern crate log; use snafu::Snafu; use std::error::Error; -use std::io::copy; -use std::io::prelude::*; -use std::net::{ - Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, - ToSocketAddrs, -}; -use std::thread; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; /// Version of socks const SOCKS_VERSION: u8 = 0x05; @@ -118,7 +114,7 @@ pub struct Merino { impl Merino { /// Create a new Merino instance - pub fn new( + pub async fn new( port: u16, ip: &str, auth_methods: Vec, @@ -126,48 +122,46 @@ impl Merino { ) -> Result> { info!("Listening on {}:{}", ip, port); Ok(Merino { - listener: TcpListener::bind((ip, port))?, + listener: TcpListener::bind((ip, port)).await?, auth_methods, users, }) } - pub fn serve(&mut self) -> Result<(), Box> { + pub async fn serve(&mut self) { info!("Serving Connections..."); - loop { - if let Ok((stream, _remote)) = self.listener.accept() { - // TODO Optimize this - let mut client = - SOCKClient::new(stream, self.users.clone(), self.auth_methods.clone()); - thread::spawn(move || { - match client.init() { - Ok(_) => {} - Err(error) => { - error!("Error! {}", error); - let error_text = format!("{}", error); - - let response: ResponseCode; - - if error_text.contains("Host") { - response = ResponseCode::HostUnreachable; - } else if error_text.contains("Network") { - response = ResponseCode::NetworkUnreachable; - } else if error_text.contains("ttl") { - response = ResponseCode::TtlExpired - } else { - response = ResponseCode::Failure - } - - if client.error(response).is_err() { - warn!("Failed to send error code"); - }; - if client.shutdown().is_err() { - warn!("Failed to shutdown TcpStream"); - }; + while let Ok((stream, client_addr)) = self.listener.accept().await { + let users = self.users.clone(); + let auth_methods = self.auth_methods.clone(); + tokio::spawn(async move { + let mut client = SOCKClient::new(stream, users, auth_methods); + match client.init().await { + Ok(_) => {} + Err(error) => { + error!("Error! {:?}, client: {:?}", error, client_addr); + let error_text = format!("{}", error); + + let response: ResponseCode; + + if error_text.contains("Host") { + response = ResponseCode::HostUnreachable; + } else if error_text.contains("Network") { + response = ResponseCode::NetworkUnreachable; + } else if error_text.contains("ttl") { + response = ResponseCode::TtlExpired + } else { + response = ResponseCode::Failure } - }; - }); - } + + if let Err(e) = client.error(response).await { + warn!("Failed to send error code: {:?}", e); + }; + if let Err(e) = client.shutdown().await { + warn!("Failed to shutdown TcpStream: {:?}", e); + }; + } + }; + }); } } } @@ -198,22 +192,22 @@ impl SOCKClient { } /// Send an error to the client - pub fn error(&mut self, r: ResponseCode) -> Result<(), Box> { - self.stream.write_all(&[5, r as u8])?; + pub async fn error(&mut self, r: ResponseCode) -> Result<(), Box> { + self.stream.write_all(&[5, r as u8]).await?; Ok(()) } /// Shutdown a client - pub fn shutdown(&mut self) -> Result<(), Box> { - self.stream.shutdown(Shutdown::Both)?; + pub async fn shutdown(&mut self) -> Result<(), Box> { + self.stream.shutdown().await?; Ok(()) } - fn init(&mut self) -> Result<(), Box> { + async fn init(&mut self) -> Result<(), Box> { debug!("New connection from: {}", self.stream.peer_addr()?.ip()); let mut header = [0u8; 2]; // Read a byte from the stream and determine the version being requested - self.stream.read_exact(&mut header)?; + self.stream.read_exact(&mut header).await?; self.socks_version = header[0]; self.auth_nmethods = header[1]; @@ -224,26 +218,26 @@ impl SOCKClient { self.auth_nmethods ); - // Handle SOCKS4 requests - if header[0] != SOCKS_VERSION { - warn!("Init: Unsupported version: SOCKS{}", self.socks_version); - self.shutdown()?; - } - // Valid SOCKS5 - else { - // Authenticate w/ client - self.auth()?; - // Handle requests - self.handle_client()?; + match self.socks_version { + SOCKS_VERSION => { + // Authenticate w/ client + self.auth().await?; + // Handle requests + self.handle_client().await?; + } + _ => { + warn!("Init: Unsupported version: SOCKS{}", self.socks_version); + self.shutdown().await?; + } } Ok(()) } - fn auth(&mut self) -> Result<(), Box> { + async fn auth(&mut self) -> Result<(), Box> { debug!("Authenticating w/ {}", self.stream.peer_addr()?.ip()); // Get valid auth methods - let methods = self.get_avalible_methods()?; + let methods = self.get_avalible_methods().await?; trace!("methods: {:?}", methods); let mut response = [0u8; 2]; @@ -256,12 +250,12 @@ impl SOCKClient { response[1] = AuthMethods::UserPass as u8; debug!("Sending USER/PASS packet"); - self.stream.write_all(&response)?; + self.stream.write_all(&response).await?; let mut header = [0u8; 2]; // Read a byte from the stream and determine the version being requested - self.stream.read_exact(&mut header)?; + self.stream.read_exact(&mut header).await?; // debug!("Auth Header: [{}, {}]", header[0], header[1]); @@ -270,11 +264,11 @@ impl SOCKClient { let mut username = vec![0; ulen]; - self.stream.read_exact(&mut username)?; + self.stream.read_exact(&mut username).await?; // Password Parsing let mut plen = [0u8; 1]; - self.stream.read_exact(&mut plen)?; + self.stream.read_exact(&mut plen).await?; let mut password = vec![0; plen[0] as usize]; @@ -283,7 +277,7 @@ impl SOCKClient { password.push(0); } - self.stream.read_exact(&mut password)?; + self.stream.read_exact(&mut password).await?; let username_str = String::from_utf8(username)?; let password_str = String::from_utf8(password)?; @@ -297,14 +291,14 @@ impl SOCKClient { if self.authed(&user) { debug!("Access Granted. User: {}", user.username); let response = [1, ResponseCode::Success as u8]; - self.stream.write_all(&response)?; + self.stream.write_all(&response).await?; } else { debug!("Access Denied. User: {}", user.username); let response = [1, ResponseCode::Failure as u8]; - self.stream.write_all(&response)?; + self.stream.write_all(&response).await?; // Shutdown - self.shutdown()?; + self.shutdown().await?; } Ok(()) @@ -312,24 +306,24 @@ impl SOCKClient { // set the default auth method (no auth) response[1] = AuthMethods::NoAuth as u8; debug!("Sending NOAUTH packet"); - self.stream.write_all(&response)?; + self.stream.write_all(&response).await?; Ok(()) } else { warn!("Client has no suitable Auth methods!"); response[1] = AuthMethods::NoMethods as u8; - self.stream.write_all(&response)?; - self.shutdown()?; + self.stream.write_all(&response).await?; + self.shutdown().await?; Err(Box::new(ResponseCode::Failure)) } } /// Handles a client - pub fn handle_client(&mut self) -> Result<(), Box> { + pub async fn handle_client(&mut self) -> Result> { debug!("Handling requests for {}", self.stream.peer_addr()?.ip()); // Read request // loop { // Parse Request - let req = SOCKSReq::from_stream(&mut self.stream)?; + let req = SOCKSReq::from_stream(&mut self.stream).await?; if req.addr_type == AddrType::V6 {} @@ -353,7 +347,7 @@ impl SOCKClient { trace!("Connecting to: {:?}", sock_addr); - let target = TcpStream::connect(&sock_addr[..])?; + let mut target = TcpStream::connect(&sock_addr[..]).await?; trace!("Connected!"); @@ -404,50 +398,42 @@ impl SOCKClient { 0, 0, ]) - .unwrap(); - - // Copy it all - let mut outbound_in = target.try_clone()?; - let mut outbound_out = target.try_clone()?; - let mut inbound_in = self.stream.try_clone()?; - let mut inbound_out = self.stream.try_clone()?; - - // Download Thread - thread::spawn(move || { - copy(&mut outbound_in, &mut inbound_out).unwrap_or_else(|e| { - warn!("error while proxy download stream: {:?}", e); - 0 - }); - outbound_in.shutdown(Shutdown::Read).unwrap_or(()); - inbound_out.shutdown(Shutdown::Write).unwrap_or(()); - }); - - // Upload Thread - thread::spawn(move || { - copy(&mut inbound_in, &mut outbound_out).unwrap_or_else(|e| { - warn!("error while proxy upload stream: {:?}", e); - 0 - }); - inbound_in.shutdown(Shutdown::Read).unwrap_or(()); - outbound_out.shutdown(Shutdown::Write).unwrap_or(()); - }); + .await?; + + trace!("copy bidirectional"); + match tokio::io::copy_bidirectional(&mut self.stream, &mut target).await { + // ignore not connected for shutdown error + Err(e) if e.kind() == std::io::ErrorKind::NotConnected => Ok(0), + Err(e) => Err(Box::new(e)), + Ok((_s_to_t, t_to_s)) => { + debug!( + "{:?} bytes proxy from {:?} to {:?}:{:?}", + t_to_s, + self.stream.peer_addr()?.ip(), + displayed_addr, + req.port + ); + Ok(t_to_s as usize) + } + } } - SockCommand::Bind => {} - SockCommand::UdpAssosiate => {} + SockCommand::Bind => Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Bind not supported", + ))), + SockCommand::UdpAssosiate => Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "UdpAssosiate not supported", + ))), } - - // connected = false; - // } - - Ok(()) } /// Return the avalible methods based on `self.auth_nmethods` - fn get_avalible_methods(&mut self) -> Result, Box> { + async fn get_avalible_methods(&mut self) -> Result, Box> { let mut methods: Vec = Vec::with_capacity(self.auth_nmethods as usize); for _ in 0..self.auth_nmethods { let mut method = [0u8; 1]; - self.stream.read_exact(&mut method)?; + self.stream.read_exact(&mut method).await?; if self.auth_methods.contains(&method[0]) { methods.append(&mut method.to_vec()); } @@ -461,7 +447,7 @@ fn addr_to_socket( addr_type: &AddrType, addr: &[u8], port: u16, -) -> Result, Box> { +) -> Result, Box> { match addr_type { AddrType::V6 => { let new_addr = (0..8) @@ -536,7 +522,7 @@ struct SOCKSReq { impl SOCKSReq { /// Parse a SOCKS Req from a TcpStream - fn from_stream(stream: &mut TcpStream) -> Result> { + async fn from_stream(stream: &mut TcpStream) -> Result> { // From rfc 1928 (S4), the SOCKS request is formed as follows: // // +----+-----+-------+------+----------+----------+ @@ -562,11 +548,11 @@ impl SOCKSReq { // order let mut packet = [0u8; 4]; // Read a byte from the stream and determine the version being requested - stream.read_exact(&mut packet)?; + stream.read_exact(&mut packet).await?; if packet[0] != SOCKS_VERSION { warn!("from_stream Unsupported version: SOCKS{}", packet[0]); - stream.shutdown(Shutdown::Both)?; + stream.shutdown().await?; } // Get command @@ -574,7 +560,7 @@ impl SOCKSReq { Some(com) => Ok(com), None => { warn!("Invalid Command"); - stream.shutdown(Shutdown::Both)?; + stream.shutdown().await?; Err(ResponseCode::CommandNotSupported) } }?; @@ -585,31 +571,31 @@ impl SOCKSReq { Some(addr) => Ok(addr), None => { error!("No Addr"); - stream.shutdown(Shutdown::Both)?; + stream.shutdown().await?; Err(ResponseCode::AddrTypeNotSupported) } }?; trace!("Getting Addr"); // Get Addr from addr_type and stream - let addr: Result, Box> = match addr_type { + let addr: Result, Box> = match addr_type { AddrType::Domain => { let mut dlen = [0u8; 1]; - stream.read_exact(&mut dlen)?; + stream.read_exact(&mut dlen).await?; let mut domain = vec![0u8; dlen[0] as usize]; - stream.read_exact(&mut domain)?; + stream.read_exact(&mut domain).await?; Ok(domain) } AddrType::V4 => { let mut addr = [0u8; 4]; - stream.read_exact(&mut addr)?; + stream.read_exact(&mut addr).await?; Ok(addr.to_vec()) } AddrType::V6 => { let mut addr = [0u8; 16]; - stream.read_exact(&mut addr)?; + stream.read_exact(&mut addr).await?; Ok(addr.to_vec()) } }; @@ -618,7 +604,7 @@ impl SOCKSReq { // read DST.port let mut port = [0u8; 2]; - stream.read_exact(&mut port)?; + stream.read_exact(&mut port).await?; // Merge two u8s into u16 let port = (u16::from(port[0]) << 8) | u16::from(port[1]); diff --git a/src/main.rs b/src/main.rs index 002b516..15313b7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,7 +40,8 @@ struct Opt { users: Option, } -fn main() -> Result<(), Box> { +#[tokio::main] +async fn main() -> Result<(), Box> { println!("{}", LOGO); let opt = Opt::from_args(); @@ -91,10 +92,10 @@ fn main() -> Result<(), Box> { } // Create proxy server - let mut merino = Merino::new(opt.port, &opt.ip, auth_methods, authed_users)?; + let mut merino = Merino::new(opt.port, &opt.ip, auth_methods, authed_users).await?; // Start Proxies - merino.serve()?; + merino.serve().await; Ok(()) }