Skip to content

Commit

Permalink
Merge pull request #3 from DeterminateSystems/zstd
Browse files Browse the repository at this point in the history
Add zstd compression support
  • Loading branch information
grahamc authored Feb 11, 2025
2 parents a492732 + 72e3a37 commit f96b649
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 39 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ tracing-instrument = []
chrono = { version = "0.4.38", features = ["serde"] }
reqwest = { version = "0.12.12", default-features = false, features = [
"json",
"zstd",
"rustls-tls-native-roots",
] }
serde = { version = "1.0.217", features = ["derive", "rc"] }
Expand All @@ -34,6 +35,7 @@ target-lexicon = "0.13.1"
is_ci = "1.2.0"
sys-locale = "0.3.2"
iana-time-zone = "0.1.61"
async-compression = { version = "0.4.18", features = ["zstd", "tokio"] }

[dev-dependencies]
tokio-test = "0.4.4"
Expand Down
8 changes: 7 additions & 1 deletion src/checkin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@ pub(crate) type CoherentFeatureFlags = HashMap<String, Arc<Feature<serde_json::V

/// The parsed response of a checkin request: transport options the server
/// sends down, plus the feature flags for this session.
#[derive(Clone, Debug, Deserialize, Default)]
pub struct Checkin {
    /// Server-provided options. `#[serde(default)]` keeps responses from
    /// older servers that omit this field deserializing successfully.
    #[serde(default)]
    pub(crate) server_options: ServerOptions,
    /// Feature flags keyed by name (see `CoherentFeatureFlags`).
    pub(crate) options: CoherentFeatureFlags,
}

/// Options the server returns to configure how the client talks to it.
#[derive(Clone, Debug, Deserialize, Default)]
pub(crate) struct ServerOptions {
    /// Compression algorithms advertised by the server.
    /// NOTE(review): presumably these apply to request bodies we send —
    /// confirm against the HTTP transport's usage.
    pub(crate) compression_algorithms: crate::compression_set::CompressionSet,
}

impl Checkin {
pub(crate) fn as_feature_facts(&self) -> FeatureFacts {
let mut feature_facts = Map::new();
Expand Down Expand Up @@ -44,7 +51,6 @@ pub struct Feature<T: serde::de::DeserializeOwned> {

#[cfg(test)]
mod test {

#[test]
fn test_parse() {
let json = r#"
Expand Down
168 changes: 168 additions & 0 deletions src/compression_set.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
use serde::Deserialize;
use tokio::io::AsyncWriteExt;

/// The set of compression algorithms the server advertises and this client
/// recognizes. `Identity` (no compression) is implicitly always a member,
/// so only optional algorithms get a flag here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct CompressionSet {
    // Whether zstd may be used.
    zstd: bool,
}

impl CompressionSet {
    /// Remove `algo` from the set.
    ///
    /// Removing `Identity` is a no-op: identity (uncompressed) is always
    /// implicitly available as the fallback and cannot be disabled.
    pub(crate) fn delete(&mut self, algo: &CompressionAlgorithm) {
        match algo {
            CompressionAlgorithm::Identity => {
                // noop: identity is not tracked by a flag.
            }
            CompressionAlgorithm::Zstd => {
                self.zstd = false;
            }
        }
    }
}

// Implementing the trait (rather than an inherent `into_iter`, which clippy
// flags as `should_implement_trait`) keeps existing `set.into_iter()` call
// sites working and additionally allows `for algo in set { … }`.
impl IntoIterator for CompressionSet {
    type Item = CompressionAlgorithm;
    type IntoIter = std::vec::IntoIter<CompressionAlgorithm>;

    /// Iterate over the usable algorithms in preference order.
    ///
    /// `Identity` always terminates the sequence, so callers can fall back
    /// to sending the payload uncompressed.
    fn into_iter(self) -> Self::IntoIter {
        let mut algos = Vec::with_capacity(2);
        if self.zstd {
            algos.push(CompressionAlgorithm::Zstd);
        }
        algos.push(CompressionAlgorithm::Identity);
        algos.into_iter()
    }
}

impl std::default::Default for CompressionSet {
fn default() -> Self {
Self { zstd: true }
}
}

impl<'de> Deserialize<'de> for CompressionSet {
    /// Deserialize from a JSON array of algorithm names, e.g. `["zstd"]`.
    ///
    /// Unrecognized names are logged at trace level and skipped rather than
    /// failing the whole document, so newer servers can advertise algorithms
    /// this client does not yet support. An empty (or entirely
    /// unrecognized) list yields a set with every optional algorithm
    /// disabled — iteration still offers `Identity` as the fallback.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Start with everything off; only algorithms the server explicitly
        // advertises get enabled. (The original code special-cased an empty
        // list, but that branch produced exactly the value the loop below
        // yields for zero elements, so it was redundant.)
        let mut set = CompressionSet { zstd: false };

        // Deserialize into raw values first: deserializing straight into
        // `Vec<CompressionAlgorithm>` would hard-fail on unknown names.
        for value in Vec::<serde_json::Value>::deserialize(deserializer)? {
            match serde_json::from_value::<CompressionAlgorithm>(value) {
                Ok(CompressionAlgorithm::Zstd) => {
                    set.zstd = true;
                }
                Ok(CompressionAlgorithm::Identity) => {
                    // noop: identity is always implicitly available.
                }
                Err(e) => {
                    tracing::trace!(%e, "Unsupported compression algorithm");
                }
            }
        }

        Ok(set)
    }
}

/// A single compression algorithm, deserialized from its kebab-case wire
/// name (`"identity"`, `"zstd"`).
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "kebab-case")]
pub(crate) enum CompressionAlgorithm {
    /// No compression; the payload is passed through unchanged.
    Identity,
    /// Zstandard compression.
    Zstd,
}

impl CompressionAlgorithm {
    /// The HTTP `Content-Encoding` token for this algorithm, or `None`
    /// when the payload is sent unmodified.
    pub(crate) fn content_encoding(&self) -> Option<String> {
        match self {
            Self::Identity => None,
            Self::Zstd => Some("zstd".to_string()),
        }
    }

    /// Compress `input` with this algorithm, returning the encoded bytes.
    pub(crate) async fn compress(&self, input: &[u8]) -> Result<Vec<u8>, std::io::Error> {
        match self {
            Self::Identity => Ok(input.to_vec()),
            Self::Zstd => {
                let mut buffer: Vec<u8> = Vec::new();
                let mut encoder =
                    async_compression::tokio::write::ZstdEncoder::new(&mut buffer);
                encoder.write_all(input).await?;
                // shutdown() flushes the encoder and writes the final zstd
                // frame; skipping it would leave the output truncated.
                encoder.shutdown().await?;

                Ok(buffer)
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::CompressionSet;

    /// Parse a JSON array into a `CompressionSet`, panicking on failure.
    fn parse(json: &str) -> CompressionSet {
        serde_json::from_str::<CompressionSet>(json).unwrap()
    }

    #[test]
    fn test_parse_compression_empty_defaults_to_identity() {
        // An empty advertisement disables every optional algorithm.
        let json = r#"
            [
            ]
        "#;

        assert_eq!(parse(json), CompressionSet { zstd: false });
    }

    #[test]
    fn test_parse_compression_few() {
        // Both known names are accepted together.
        let json = r#"
            [
                "zstd",
                "identity"
            ]
        "#;

        assert_eq!(parse(json), CompressionSet { zstd: true });
    }

    #[test]
    fn test_parse_compression_zstd_not_identity() {
        // zstd alone is enough; identity is implicit.
        let json = r#"
            [
                "zstd"
            ]
        "#;

        assert_eq!(parse(json), CompressionSet { zstd: true });
    }

    #[test]
    fn test_parse_compression_zstd_with_bogus() {
        // Unknown names are skipped without failing the parse.
        let json = r#"
            [
                "zstd",
                "abc123"
            ]
        "#;

        assert_eq!(parse(json), CompressionSet { zstd: true });
    }
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod builder;
pub mod checkin;
mod collator;
mod compression_set;
mod configuration_proxy;
mod ds_correlation;
mod identity;
Expand Down
1 change: 1 addition & 0 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ impl Transport for Transports {
match self {
Self::None => Ok(crate::checkin::Checkin {
options: std::collections::HashMap::new(),
..Default::default()
}),
Self::File(t) => Ok(t.checkin(session_properties).await?),
Self::Http(t) => Ok(t.checkin(session_properties).await?),
Expand Down
Loading

0 comments on commit f96b649

Please sign in to comment.