diff --git a/Cargo.toml b/Cargo.toml
index 9b87554..0f5f578 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,7 @@ tracing-instrument = []
 chrono = { version = "0.4.38", features = ["serde"] }
 reqwest = { version = "0.12.12", default-features = false, features = [
     "json",
+    "zstd",
     "rustls-tls-native-roots",
 ] }
 serde = { version = "1.0.217", features = ["derive", "rc"] }
@@ -34,6 +35,7 @@ target-lexicon = "0.13.1"
 is_ci = "1.2.0"
 sys-locale = "0.3.2"
 iana-time-zone = "0.1.61"
+async-compression = { version = "0.4.18", features = ["zstd", "tokio"] }
 
 [dev-dependencies]
 tokio-test = "0.4.4"
diff --git a/src/checkin.rs b/src/checkin.rs
index 99f3b60..d18709b 100644
--- a/src/checkin.rs
+++ b/src/checkin.rs
@@ -8,9 +8,16 @@ pub(crate) type CoherentFeatureFlags = HashMap<String, Feature>;
 
 #[derive(Clone, Debug, Default, Deserialize)]
 pub struct Checkin {
     pub(crate) options: CoherentFeatureFlags,
+    #[serde(default)]
+    pub(crate) server_options: ServerOptions,
+}
+
+#[derive(Clone, Debug, Default, Deserialize)]
+pub struct ServerOptions {
+    #[serde(default)]
+    pub(crate) compression_algorithms: crate::compression_set::CompressionSet,
 }
 
 ) -> FeatureFacts {
     let mut feature_facts = Map::new();
@@ -44,7 +51,6 @@ pub struct Feature {
 
 #[cfg(test)]
 mod test {
-
     #[test]
     fn test_parse() {
         let json = r#"
diff --git a/src/compression_set.rs b/src/compression_set.rs
new file mode 100644
index 0000000..68e83a1
--- /dev/null
+++ b/src/compression_set.rs
@@ -0,0 +1,168 @@
+use serde::Deserialize;
+use tokio::io::AsyncWriteExt;
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub(crate) struct CompressionSet {
+    zstd: bool,
+}
+
+impl CompressionSet {
+    pub(crate) fn delete(&mut self, algo: &CompressionAlgorithm) {
+        match algo {
+            CompressionAlgorithm::Identity => {
+                // noop
+            }
+            CompressionAlgorithm::Zstd => {
+                self.zstd = false;
+            }
+        }
+    }
+
+    pub(crate) fn into_iter(self) -> std::vec::IntoIter<CompressionAlgorithm> {
+        let mut algos = Vec::with_capacity(2);
+        if self.zstd {
+            algos.push(CompressionAlgorithm::Zstd);
+        }
+
+        algos.push(CompressionAlgorithm::Identity);
+
+        algos.into_iter()
+    }
+}
+
+impl std::default::Default for CompressionSet {
+    fn default() -> Self {
+        Self { zstd: true }
+    }
+}
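+
+// The check-in response advertises compression support as a JSON array of
+// algorithm names (e.g. ["zstd", "identity"]). Names we do not recognize are
+// logged and skipped rather than failing deserialization, so a newer server
+// can advertise algorithms this client does not understand. An empty or
+// entirely unrecognized list disables zstd, leaving only identity.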
+impl<'de> Deserialize<'de> for CompressionSet {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        let algos: Vec<_> = Vec::<serde_json::Value>::deserialize(deserializer)?
+            .into_iter()
+            .filter_map(
+                |v| match serde_json::from_value::<CompressionAlgorithm>(v) {
+                    Ok(v) => Some(v),
+                    Err(e) => {
+                        tracing::trace!(%e, "Unsupported compression algorithm");
+                        None
+                    }
+                },
+            )
+            .collect();
+
+        if algos.is_empty() {
+            return Ok(CompressionSet { zstd: false });
+        }
+
+        let mut set = CompressionSet { zstd: false };
+
+        for algo in algos.into_iter() {
+            match algo {
+                CompressionAlgorithm::Zstd => {
+                    set.zstd = true;
+                }
+                CompressionAlgorithm::Identity => {
+                    // noop
+                }
+            }
+        }
+
+        Ok(set)
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Hash)]
+#[serde(rename_all = "kebab-case")]
+pub(crate) enum CompressionAlgorithm {
+    Identity,
+    Zstd,
+}
+
+impl CompressionAlgorithm {
+    pub(crate) fn content_encoding(&self) -> Option<String> {
+        match self {
+            CompressionAlgorithm::Identity => None,
+            CompressionAlgorithm::Zstd => Some("zstd".to_string()),
+        }
+    }
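+
+    // Identity returns the payload unchanged; zstd buffers the body through
+    // async-compression's ZstdEncoder. The shutdown() call is required to
+    // flush the final zstd frame before the buffer is handed to reqwest.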
+    pub(crate) async fn compress(&self, r: &[u8]) -> Result<Vec<u8>, std::io::Error> {
+        match self {
+            CompressionAlgorithm::Identity => Ok(r.into()),
+            CompressionAlgorithm::Zstd => {
+                let mut output: Vec<u8> = vec![];
+                let mut encoder = async_compression::tokio::write::ZstdEncoder::new(&mut output);
+                encoder.write_all(r).await?;
+                encoder.shutdown().await?;
+
+                Ok(output)
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::CompressionSet;
+
+    #[test]
+    fn test_parse_compression_empty_defaults_to_identity() {
+        let json = r#"
+        [
+        ]
+        "#;
+
+        assert_eq!(
+            serde_json::from_str::<CompressionSet>(json).unwrap(),
+            CompressionSet { zstd: false }
+        );
+    }
+
+    #[test]
+    fn test_parse_compression_few() {
+        let json = r#"
+        [
+            "zstd",
+            "identity"
+        ]
+        "#;
+
+        assert_eq!(
+            serde_json::from_str::<CompressionSet>(json).unwrap(),
+            CompressionSet { zstd: true }
+        );
+    }
+
+    #[test]
+    fn test_parse_compression_zstd_not_identity() {
+        let json = r#"
+        [
+            "zstd"
+        ]
+        "#;
+
+        assert_eq!(
+            serde_json::from_str::<CompressionSet>(json).unwrap(),
+            CompressionSet { zstd: true }
+        );
+    }
+
+    #[test]
+    fn test_parse_compression_zstd_with_bogus() {
+        let json = r#"
+        [
+            "zstd",
+            "abc123"
+        ]
+        "#;
+
+        assert_eq!(
+            serde_json::from_str::<CompressionSet>(json).unwrap(),
+            CompressionSet { zstd: true }
+        );
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index ef4d976..2a0b9ef 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,7 @@
 mod builder;
 pub mod checkin;
 mod collator;
+mod compression_set;
 mod configuration_proxy;
 mod ds_correlation;
 mod identity;
diff --git a/src/transport/mod.rs b/src/transport/mod.rs
index 5e1982f..efe3b3c 100644
--- a/src/transport/mod.rs
+++ b/src/transport/mod.rs
@@ -109,6 +109,7 @@ impl Transport for Transports {
         match self {
             Self::None => Ok(crate::checkin::Checkin {
                 options: std::collections::HashMap::new(),
+                ..Default::default()
             }),
             Self::File(t) => Ok(t.checkin(session_properties).await?),
             Self::Http(t) => Ok(t.checkin(session_properties).await?),
diff --git a/src/transport/srv_http.rs b/src/transport/srv_http.rs
index 8e34c8b..75fdd39 100644
--- a/src/transport/srv_http.rs
+++ b/src/transport/srv_http.rs
@@ -5,6 +5,8 @@ use reqwest::Certificate;
 use reqwest::Url;
 use tracing::Instrument;
 
+use crate::checkin::Checkin;
+use crate::checkin::ServerOptions;
 use crate::submitter::Batch;
 use crate::Map;
@@ -19,6 +21,7 @@ type Resolver = trust_dns_resolver::AsyncResolver<
 #[derive(Clone)]
 pub(crate) struct SrvHttpTransport {
     srv: Arc<detsys_srv::SrvClient<Resolver>>,
+    server_options: Arc<tokio::sync::RwLock<ServerOptions>>,
     reqwest: reqwest::Client,
 }
 impl SrvHttpTransport {
@@ -59,6 +62,9 @@ impl SrvHttpTransport {
         Ok(SrvHttpTransport {
             srv: Arc::new(srv),
             reqwest: builder.build()?,
+            server_options: Arc::new(tokio::sync::RwLock::new(
+                crate::checkin::ServerOptions::default(),
+            )),
         })
     }
 }
@@ -70,31 +76,20 @@ impl Transport for SrvHttpTransport {
 
     async fn submit<'b>(&mut self, batch: Batch<'b>) -> Result<(), Self::Error> {
         let payload = serde_json::to_string(&batch)?;
         let reqwest = self.reqwest.clone();
+        let server_opts = self.server_options.clone();
 
         let resp = self
             .srv
             .execute(move |mut url| {
-                let payload = payload.clone();
+                let payload: Vec<u8> = payload.as_bytes().into();
                 let reqwest = reqwest.clone();
+                let server_opts = server_opts.clone();
 
                 url.set_path("/events/batch");
-                let span = tracing::trace_span!("submission attempt", host = url.to_string());
-
-                async move {
-                    tracing::trace!("Submitting event logs.");
-
-                    reqwest
-                        .post(url)
-                        .header(
-                            http::header::CONTENT_TYPE,
-                            crate::transport::APPLICATION_JSON,
-                        )
-                        .body(payload)
-                        .send()
-                        .await
-                        .map_err(SrvHttpTransportError::from)
-                }
-                .instrument(span)
+
+                let span = tracing::debug_span!("submission", %url);
+
+                perform_request(reqwest, url, payload, server_opts).instrument(span)
             })
             .await?;
@@ -112,42 +107,91 @@
     ) -> Result<Checkin, Self::Error> {
         let payload = serde_json::to_string(&session_properties)?;
         let reqwest = self.reqwest.clone();
+        let server_opts = self.server_options.clone();
+
         let resp = self
             .srv
             .execute(move |mut url| {
-                let payload = payload.clone();
+                let payload: Vec<u8> = payload.as_bytes().into();
                 let reqwest = reqwest.clone();
+                let server_opts = server_opts.clone();
+
                 url.set_path("check-in");
-                let span = tracing::trace_span!("check-in attempt", host = url.to_string());
-
-                async move {
-                    tracing::trace!("Fetching check-in configuration.");
-
-                    reqwest
-                        .post(url)
-                        .header(
-                            http::header::CONTENT_TYPE,
-                            crate::transport::APPLICATION_JSON,
-                        )
-                        .body(payload)
-                        .send()
-                        .await
-                        .map_err(SrvHttpTransportError::from)
-                }
-                .instrument(span)
+                let span = tracing::trace_span!("check-in attempt", %url);
+
+                perform_request(reqwest, url, payload, server_opts).instrument(span)
             })
            .await?;
 
-        Ok(resp.json().await?)
+        let checkin: Checkin = resp.json().await?;
+
+        // Update server options to sync up compression options
+        {
+            let mut opts = self.server_options.write().await;
+            *opts = checkin.server_options.clone();
+        }
+
+        Ok(checkin)
    }
 }
 
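+// Attempt the request once per algorithm the server is believed to accept,
+// most preferred first and identity last. On a 415 Unsupported Media Type
+// response the algorithm is removed from the shared ServerOptions (so later
+// requests skip it) and the next candidate is tried; any other error
+// propagates. If every candidate is rejected, fail with NoCompressionMode.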
+#[tracing::instrument(skip(reqwest, payload, server_opts))]
+async fn perform_request(
+    reqwest: reqwest::Client,
+    url: url::Url,
+    payload: Vec<u8>,
+    server_opts: Arc<tokio::sync::RwLock<ServerOptions>>,
+) -> Result<reqwest::Response, SrvHttpTransportError> {
+    let algos = server_opts.read().await.compression_algorithms.into_iter();
+
+    for compression_algo in algos {
+        let span = tracing::debug_span!("requesting", ?compression_algo);
+
+        let mut req = reqwest
+            .post(url.clone())
+            .header(
+                http::header::CONTENT_TYPE,
+                crate::transport::APPLICATION_JSON,
+            )
+            .body(compression_algo.compress(&payload).await?);
+
+        if let Some(encoding) = compression_algo.content_encoding() {
+            req = req.header(http::header::CONTENT_ENCODING, encoding);
+        }
+
+        tracing::trace!(parent: &span, "Requesting");
+        match req.send().instrument(span.clone()).await {
+            Ok(resp) if resp.status() == http::StatusCode::UNSUPPORTED_MEDIA_TYPE => {
+                tracing::debug!(
+                    ?compression_algo,
+                    "Disabling compression algorithm because it is unsupported"
+                );
+                server_opts
+                    .write()
+                    .await
+                    .compression_algorithms
+                    .delete(&compression_algo);
+            }
+
+            Err(e) => {
+                return Err(SrvHttpTransportError::from(e));
+            }
+            Ok(resp) => return Ok(resp),
+        }
+    }
+
+    Err(SrvHttpTransportError::NoCompressionMode)
+}
+
 #[derive(thiserror::Error, Debug)]
 pub enum SrvHttpTransportError {
     #[error(transparent)]
     SrvError(#[from] detsys_srv::Error<<Resolver as detsys_srv::resolver::SrvResolver>::Error>),
 
+    #[error(transparent)]
+    Io(#[from] std::io::Error),
+
     #[error(transparent)]
     Reqwest(#[from] reqwest::Error),
 
@@ -159,4 +203,7 @@ pub enum SrvHttpTransportError {
 
     #[error(transparent)]
     UrlParse(#[from] url::ParseError),
+
+    #[error("The server has rejected all of our compression modes")]
+    NoCompressionMode,
 }
diff --git a/src/worker.rs b/src/worker.rs
index c09f9fe..dee4820 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -19,7 +19,14 @@ pub struct Worker {
 impl Worker {
     #[cfg_attr(
         feature = "tracing-instrument",
-        tracing::instrument(skip(system_snapshotter, transport))
+        tracing::instrument(skip(
+            distinct_id,
+            device_id,
+            facts,
+            groups,
+            system_snapshotter,
+            transport
+        ))
     )]
     pub(crate) async fn new(
         distinct_id: Option<String>,