From 43891d0596aa7195dba5a34f23e8ada69d521ac6 Mon Sep 17 00:00:00 2001 From: Tomek Karwowski Date: Mon, 7 Aug 2023 20:33:58 +0200 Subject: [PATCH] feat: SetHost and Http1RequestTarget middlewares --- src/client/mod.rs | 4 + src/client/services/http1_request_target.rs | 86 +++++++++++++++++++++ src/client/services/mod.rs | 5 ++ src/client/services/set_host.rs | 52 +++++++++++++ 4 files changed, 147 insertions(+) create mode 100644 src/client/services/http1_request_target.rs create mode 100644 src/client/services/mod.rs create mode 100644 src/client/services/set_host.rs diff --git a/src/client/mod.rs b/src/client/mod.rs index e921542..29e014f 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,3 +3,7 @@ /// Legacy implementations of `connect` module and `Client` #[cfg(feature = "client-legacy")] pub mod legacy; + +/// Client services +#[cfg(any(feature = "http1", feature = "http2"))] +pub mod services; diff --git a/src/client/services/http1_request_target.rs b/src/client/services/http1_request_target.rs new file mode 100644 index 0000000..96c6c32 --- /dev/null +++ b/src/client/services/http1_request_target.rs @@ -0,0 +1,86 @@ +use http::{uri::Scheme, Method, Request, Uri}; +use hyper::service::Service; +use tracing::warn; + +/// A `Service` that normalizes the request target. +pub struct Http1RequestTarget { + inner: S, + is_proxied: bool, +} + +impl Http1RequestTarget { + /// Create a new `Http1RequestTarget` service. + pub fn new(inner: S, is_proxied: bool) -> Self { + Self { inner, is_proxied } + } +} + +impl Service> for Http1RequestTarget +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn call(&self, mut req: Request) -> Self::Future { + // CONNECT always sends authority-form, so check it first... + if req.method() == Method::CONNECT { + authority_form(req.uri_mut()); + } else if self.is_proxied { + absolute_form(req.uri_mut()); + } else { + origin_form(req.uri_mut()); + } + self.inner.call(req) + } +} + +fn origin_form(uri: &mut Uri) { + let path = match uri.path_and_query() { + Some(path) if path.as_str() != "/" => { + let mut parts = ::http::uri::Parts::default(); + parts.path_and_query = Some(path.clone()); + Uri::from_parts(parts).expect("path is valid uri") + } + _none_or_just_slash => { + debug_assert!(Uri::default() == "/"); + Uri::default() + } + }; + *uri = path +} + +fn absolute_form(uri: &mut Uri) { + debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); + debug_assert!( + uri.authority().is_some(), + "absolute_form needs an authority" + ); + // If the URI is to HTTPS, and the connector claimed to be a proxy, + // then it *should* have tunneled, and so we don't want to send + // absolute-form in that case. + if uri.scheme() == Some(&Scheme::HTTPS) { + origin_form(uri); + } +} + +fn authority_form(uri: &mut Uri) { + if let Some(path) = uri.path_and_query() { + // `https://hyper.rs` would parse with `/` path, don't + // annoy people about that... + if path != "/" { + warn!("HTTP/1.1 CONNECT request stripping path: {:?}", path); + } + } + *uri = match uri.authority() { + Some(auth) => { + let mut parts = ::http::uri::Parts::default(); + parts.authority = Some(auth.clone()); + Uri::from_parts(parts).expect("authority is valid") + } + None => { + unreachable!("authority_form with relative uri"); + } + }; +} diff --git a/src/client/services/mod.rs b/src/client/services/mod.rs new file mode 100644 index 0000000..4a0d335 --- /dev/null +++ b/src/client/services/mod.rs @@ -0,0 +1,5 @@ +mod http1_request_target; +mod set_host; + +pub use http1_request_target::Http1RequestTarget; +pub use set_host::SetHost; diff --git a/src/client/services/set_host.rs b/src/client/services/set_host.rs new file mode 100644 index 0000000..a5435b3 --- /dev/null +++ b/src/client/services/set_host.rs @@ -0,0 +1,52 @@ +use http::{header::HOST, uri::Port, HeaderValue, Request, Uri}; +use hyper::service::Service; + +/// A `Service` that sets the `Host` header, if it's missing, based on the request URI. +pub struct SetHost { + inner: S, +} + +impl SetHost { + /// Create a new `SetHost` service. + pub fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service> for SetHost +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn call(&self, mut req: Request) -> Self::Future { + let uri = req.uri().clone(); + req.headers_mut().entry(HOST).or_insert_with(|| { + let hostname = uri.host().expect("authority implies host"); + if let Some(port) = get_non_default_port(&uri) { + let s = format!("{}:{}", hostname, port); + HeaderValue::from_str(&s) + } else { + HeaderValue::from_str(hostname) + } + .expect("uri host is valid header value") + }); + self.inner.call(req) + } +} + +fn get_non_default_port(uri: &Uri) -> Option> { + match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) { + (Some(443), true) => None, + (Some(80), false) => None, + _ => uri.port(), + } +} + +fn is_schema_secure(uri: &Uri) -> bool { + uri.scheme_str() + .map(|scheme_str| matches!(scheme_str, "wss" | "https")) + .unwrap_or_default() +}