diff --git a/src/client/legacy/connect/mod.rs b/src/client/legacy/connect/mod.rs index bd00baa..0d3fc92 100644 --- a/src/client/legacy/connect/mod.rs +++ b/src/client/legacy/connect/mod.rs @@ -74,6 +74,8 @@ pub mod dns; #[cfg(feature = "tokio")] mod http; +pub mod proxy; + pub(crate) mod capture; pub use capture::{capture_connection, CaptureConnection}; diff --git a/src/client/legacy/connect/proxy/mod.rs b/src/client/legacy/connect/proxy/mod.rs new file mode 100644 index 0000000..b7a7c14 --- /dev/null +++ b/src/client/legacy/connect/proxy/mod.rs @@ -0,0 +1,5 @@ +//! Proxy helpers + +mod tunnel; + +pub use self::tunnel::Tunnel; diff --git a/src/client/legacy/connect/proxy/tunnel.rs b/src/client/legacy/connect/proxy/tunnel.rs new file mode 100644 index 0000000..3c22b48 --- /dev/null +++ b/src/client/legacy/connect/proxy/tunnel.rs @@ -0,0 +1,181 @@ +use std::future::Future; +use std::marker::{PhantomData, Unpin}; +use std::pin::Pin; +use std::task::{self, Poll}; + +use http::{HeaderValue, Uri}; +use hyper::rt::{Read, Write}; +use pin_project_lite::pin_project; +use tower_service::Service; + +/// Tunnel Proxy via HTTP CONNECT +#[derive(Debug)] +pub struct Tunnel { + auth: Option, + inner: C, + proxy_dst: Uri, +} + +#[derive(Debug)] +pub enum TunnelError { + Inner(C), + Io(std::io::Error), + MissingHost, + ProxyAuthRequired, + ProxyHeadersTooLong, + TunnelUnexpectedEof, + TunnelUnsuccessful, +} + +pin_project! { + // Not publicly exported (so missing_docs doesn't trigger). + // + // We return this `Future` instead of the `Pin>` directly + // so that users don't rely on it fitting in a `Pin>` slot + // (and thus we can change the type in the future). + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct Tunneling { + #[pin] + fut: BoxTunneling, + _marker: PhantomData, + } +} + +type BoxTunneling = Pin>> + Send>>; + +impl Tunnel { + /// Create a new Tunnel service. + pub fn new(proxy_dst: Uri, connector: C) -> Self { + Self { + auth: None, + inner: connector, + proxy_dst, + } + } + + /// Add `proxy-authorization` header value to the CONNECT request. + pub fn with_auth(mut self, mut auth: HeaderValue) -> Self { + // just in case the user forgot + auth.set_sensitive(true); + self.auth = Some(auth); + self + } +} + +impl Service for Tunnel +where + C: Service, + C::Future: Send + 'static, + C::Response: Read + Write + Unpin + Send + 'static, + C::Error: Send + 'static, +{ + type Response = C::Response; + type Error = TunnelError; + type Future = Tunneling; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + futures_util::ready!(self.inner.poll_ready(cx)).map_err(TunnelError::Inner)?; + Poll::Ready(Ok(())) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let connecting = self.inner.call(self.proxy_dst.clone()); + + Tunneling { + fut: Box::pin(async move { + let conn = connecting.await.map_err(TunnelError::Inner)?; + tunnel( + conn, + dst.host().ok_or(TunnelError::MissingHost)?, + dst.port().map(|p| p.as_u16()).unwrap_or(443), + None, + None, + ) + .await + }), + _marker: PhantomData, + } + } +} + +impl Future for Tunneling +where + F: Future>, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + self.project().fut.poll(cx) + } +} + +async fn tunnel( + mut conn: T, + host: &str, + port: u16, + user_agent: Option, + auth: Option, +) -> Result> +where + T: Read + Write + Unpin, +{ + let mut buf = format!( + "\ + CONNECT {host}:{port} HTTP/1.1\r\n\ + Host: {host}:{port}\r\n\ + " + ) + .into_bytes(); + + // user-agent + if let Some(user_agent) = user_agent { + buf.extend_from_slice(b"User-Agent: "); + buf.extend_from_slice(user_agent.as_bytes()); + buf.extend_from_slice(b"\r\n"); + } + + // proxy-authorization + if let Some(value) = auth { + //log::debug!("tunnel to {host}:{port} using basic auth"); + buf.extend_from_slice(b"Proxy-Authorization: "); + buf.extend_from_slice(value.as_bytes()); + buf.extend_from_slice(b"\r\n"); + } + + // headers end + buf.extend_from_slice(b"\r\n"); + + crate::rt::write_all(&mut conn, &buf) + .await + .map_err(TunnelError::Io)?; + + let mut buf = [0; 8192]; + let mut pos = 0; + + loop { + let n = crate::rt::read(&mut conn, &mut buf[pos..]) + .await + .map_err(TunnelError::Io)?; + + if n == 0 { + return Err(TunnelError::TunnelUnexpectedEof); + } + pos += n; + + let recvd = &buf[..pos]; + if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") { + if recvd.ends_with(b"\r\n\r\n") { + return Ok(conn); + } + if pos == buf.len() { + return Err(TunnelError::ProxyHeadersTooLong); + } + // else read more + } else if recvd.starts_with(b"HTTP/1.1 407") { + return Err(TunnelError::ProxyAuthRequired); + } else { + return Err(TunnelError::TunnelUnsuccessful); + } + } +} diff --git a/src/rt/io.rs b/src/rt/io.rs new file mode 100644 index 0000000..0ce3ea9 --- /dev/null +++ b/src/rt/io.rs @@ -0,0 +1,33 @@ +use std::marker::Unpin; +use std::pin::Pin; +use std::task::Poll; + +use futures_util::future; +use futures_util::ready; +use hyper::rt::{Read, ReadBuf, Write}; + +pub(crate) async fn read(io: &mut T, buf: &mut [u8]) -> Result +where + T: Read + Unpin, +{ + future::poll_fn(move |cx| { + let mut buf = ReadBuf::new(buf); + ready!(Pin::new(&mut *io).poll_read(cx, buf.unfilled()))?; + Poll::Ready(Ok(buf.filled().len())) + }) + .await +} + +pub(crate) async fn write_all(io: &mut T, buf: &[u8]) -> Result<(), std::io::Error> +where + T: Write + Unpin, +{ + let mut n = 0; + future::poll_fn(move |cx| { + while n < buf.len() { + n += ready!(Pin::new(&mut *io).poll_write(cx, &buf[n..])?); + } + Poll::Ready(Ok(())) + }) + .await +} diff --git a/src/rt/mod.rs b/src/rt/mod.rs index 3ed8628..71363cc 100644 --- a/src/rt/mod.rs +++ b/src/rt/mod.rs @@ -1,5 +1,10 @@ //! Runtime utilities +#[cfg(feature = "client-legacy")] +mod io; +#[cfg(feature = "client-legacy")] +pub(crate) use self::io::{read, write_all}; + #[cfg(feature = "tokio")] pub mod tokio; diff --git a/tests/proxy.rs b/tests/proxy.rs new file mode 100644 index 0000000..f828bc1 --- /dev/null +++ b/tests/proxy.rs @@ -0,0 +1,37 @@ +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tower_service::Service; + +use hyper_util::client::legacy::connect::{proxy::Tunnel, HttpConnector}; + +#[cfg(not(miri))] +#[tokio::test] +async fn test_tunnel_works() { + let tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = tcp.local_addr().expect("local_addr"); + + let proxy_dst = format!("http://{}", addr).parse().expect("uri"); + let mut connector = Tunnel::new(proxy_dst, HttpConnector::new()); + let t1 = tokio::spawn(async move { + let _conn = connector + .call("https://hyper.rs".parse().unwrap()) + .await + .expect("tunnel"); + }); + + let t2 = tokio::spawn(async move { + let (mut io, _) = tcp.accept().await.expect("accept"); + let mut buf = [0u8; 64]; + let n = io.read(&mut buf).await.expect("read 1"); + assert_eq!( + &buf[..n], + b"CONNECT hyper.rs:443 HTTP/1.1\r\nHost: hyper.rs:443\r\n\r\n" + ); + io.write_all(b"HTTP/1.1 200 OK\r\n\r\n") + .await + .expect("write 1"); + }); + + t1.await.expect("task 1"); + t2.await.expect("task 2"); +}