Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for newtype structs and serde flatten #110

Merged
merged 2 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "wmi"
version = "0.14.5"
version = "0.15.0"
authors = ["Ohad Ravid <[email protected]>"]
edition = "2021"
license = "MIT OR Apache-2.0"
Expand Down
38 changes: 28 additions & 10 deletions src/de/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use serde::forward_to_deserialize_any;
/// Return the fields of a struct.
/// Taken directly from <https://github.com/serde-rs/serde/issues/1110>
///
pub fn struct_name_and_fields<'de, T>() -> Result<(&'static str, &'static [&'static str]), Error>
pub fn struct_name_and_fields<'de, T>(
) -> Result<(&'static str, Option<&'static [&'static str]>), Error>
where
T: Deserialize<'de>,
{
Expand All @@ -25,12 +26,13 @@ where

fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
*self.name = Some(name);
visitor.visit_newtype_struct(self)
}

Expand Down Expand Up @@ -83,7 +85,7 @@ where
validate_identifier(field)?;
}

Ok((name, fields.unwrap()))
Ok((name, fields))
}
}
}
Expand Down Expand Up @@ -132,6 +134,7 @@ fn validate_identifier<E: de::Error>(s: &str) -> Result<&str, E> {
}

#[cfg(test)]
#[allow(dead_code)]
mod tests {
use super::*;
use crate::Variant;
Expand All @@ -142,16 +145,14 @@ mod tests {
fn it_works() {
#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem {
#[allow(dead_code)]
Caption: String,
#[allow(dead_code)]
Name: String,
}

let (name, fields) = struct_name_and_fields::<Win32_OperatingSystem>().unwrap();

assert_eq!(name, "Win32_OperatingSystem");
assert_eq!(fields, ["Caption", "Name"]);
assert_eq!(fields.unwrap(), ["Caption", "Name"]);
}

#[test]
Expand All @@ -160,16 +161,34 @@ mod tests {
#[serde(rename = "Win32_OperatingSystem")]
#[serde(rename_all = "PascalCase")]
struct Win32OperatingSystem {
#[allow(dead_code)]
caption: String,
#[allow(dead_code)]
name: String,
}

let (name, fields) = struct_name_and_fields::<Win32OperatingSystem>().unwrap();

assert_eq!(name, "Win32_OperatingSystem");
assert_eq!(fields, ["Caption", "Name"]);
assert_eq!(fields.unwrap(), ["Caption", "Name"]);
}

#[test]
fn it_works_with_flatten() {
#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem_inner {
Caption: String,
Name: String,

#[serde(flatten)]
extra: HashMap<String, Variant>,
}

#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem(pub Win32_OperatingSystem_inner);

let (name, fields) = struct_name_and_fields::<Win32_OperatingSystem>().unwrap();

assert_eq!(name, "Win32_OperatingSystem");
assert_eq!(fields, None);
}

#[test]
Expand All @@ -181,7 +200,6 @@ mod tests {
#[derive(Deserialize, Debug)]
struct EvilFieldName {
#[serde(rename = "Evil\"Field\"Name")]
#[allow(dead_code)]
field: String,
}

Expand Down
81 changes: 64 additions & 17 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ where
}
};

Ok((name, fields, optional_where_clause))
Ok((name, fields.unwrap_or(&["*"]), optional_where_clause))
}

/// Quote/escape a string for WQL.
Expand Down Expand Up @@ -689,13 +689,73 @@ mod tests {
assert!(results_as_json.starts_with(r#"[{"Name":"Microsoft Windows"#));
}

#[test]
fn it_builds_correct_query_for_newtype_struct() {
#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem(pub HashMap<String, Variant>);

let query = build_query::<Win32_OperatingSystem>(None).unwrap();
let select_part = r#"SELECT * FROM Win32_OperatingSystem "#.to_owned();

assert_eq!(query, select_part);
}

#[test]
fn it_can_query_a_newtype_struct() {
let wmi_con = wmi_con();

#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem(pub HashMap<String, Variant>);

let results = wmi_con.query::<Win32_OperatingSystem>().unwrap();

for os in results {
match os.0.get("Caption").unwrap() {
Variant::String(s) => assert!(s.starts_with("Microsoft Windows")),
_ => assert!(false),
}
}
}

#[test]
fn con_query_flatten() {
// Due to serde#1346, it's not possible to use `query` with a struct that has a `flatten` field,
// so we need to either use `raw_query` or a newtype struct.

let wmi_con = wmi_con();

#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem {
Caption: String,
Name: String,

#[serde(flatten)]
extra: HashMap<String, Variant>,
}

let system: Vec<Win32_OperatingSystem> = wmi_con
.raw_query("SELECT * FROM Win32_OperatingSystem")
.unwrap();
let system = system.into_iter().next().unwrap();
assert_ne!(system.Name, "");
assert!(system.extra.contains_key("BuildNumber"));

#[derive(Deserialize, Debug)]
#[serde(rename = "Win32_OperatingSystem")]
struct Win32_OperatingSystemWrapper(pub Win32_OperatingSystem);
let system = wmi_con.query::<Win32_OperatingSystemWrapper>().unwrap();
let system = system.into_iter().next().unwrap();

assert_ne!(system.0.Name, "");
assert!(system.0.extra.contains_key("BuildNumber"));
}

#[test]
fn it_fails_gracefully_when_querying_a_struct() {
let wmi_con = wmi_con();

#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem {
#[allow(dead_code)]
NoSuchField: String,
}

Expand All @@ -708,7 +768,6 @@ mod tests {
fn it_builds_correct_query_without_filters() {
#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem {
#[allow(dead_code)]
Caption: String,
}

Expand All @@ -722,7 +781,6 @@ mod tests {
fn it_builds_correct_notification_query_without_filters() {
#[derive(Deserialize, Debug)]
struct Win32_ProcessStartTrace {
#[allow(dead_code)]
Caption: String,
}

Expand All @@ -736,7 +794,6 @@ mod tests {
fn it_builds_correct_query() {
#[derive(Deserialize, Debug)]
struct Win32_OperatingSystem {
#[allow(dead_code)]
Caption: String,
}

Expand Down Expand Up @@ -769,7 +826,6 @@ mod tests {
fn it_builds_correct_notification_query() {
#[derive(Deserialize, Debug)]
struct Win32_ProcessStartTrace {
#[allow(dead_code)]
Caption: String,
}

Expand Down Expand Up @@ -930,7 +986,6 @@ mod tests {
#[derive(Deserialize, Debug)]
struct Win32_DiskDrive {
__Path: String,
#[allow(dead_code)]
Caption: String,
}

Expand Down Expand Up @@ -962,11 +1017,11 @@ mod tests {
fn it_can_query_correct_variant_types() {
let wmi_con = wmi_con();
let mut results: Vec<HashMap<String, Variant>> = wmi_con
.raw_query("SELECT SystemStabilityIndex FROM Win32_ReliabilityStabilityMetrics")
.raw_query("SELECT CPUScore FROM Win32_WinSAT")
.unwrap();

match results.pop().unwrap().values().next() {
Some(&Variant::R8(_v)) => assert!(true),
Some(&Variant::R4(_v)) => assert!(true),
_ => assert!(false),
}

Expand Down Expand Up @@ -1027,14 +1082,6 @@ mod tests {
fn it_can_query_floats() {
let wmi_con = wmi_con();

#[derive(Deserialize, Debug)]
struct Win32_ReliabilityStabilityMetrics {
SystemStabilityIndex: f64,
}

let metric = wmi_con.get::<Win32_ReliabilityStabilityMetrics>().unwrap();
assert!(metric.SystemStabilityIndex >= 0.0);

#[derive(Deserialize, Debug)]
struct Win32_WinSAT {
CPUScore: f32,
Expand Down