diff --git a/Cargo.toml b/Cargo.toml index 4dbb883..4696cb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ datafusion = { version = "46", default-features = false } jiter = "0.9" paste = "1" log = "0.4" +jsonpath-rust = "1.0.0" [dev-dependencies] datafusion = { version = "46", default-features = false, features = ["nested_expressions"] } @@ -22,6 +23,7 @@ codspeed-criterion-compat = "2.6" criterion = "0.5.1" clap = "4" tokio = { version = "1.43", features = ["full"] } +rstest = "0.25.0" [lints.clippy] dbg_macro = "deny" diff --git a/src/common.rs b/src/common.rs index 66505cb..1afa9f4 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,6 +1,9 @@ use std::str::Utf8Error; use std::sync::Arc; +use crate::common_union::{ + is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL, +}; use datafusion::arrow::array::{ downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray, PrimitiveArray, PrimitiveBuilder, RunArray, StringArray, StringViewArray, @@ -11,10 +14,8 @@ use datafusion::arrow::datatypes::{ArrowNativeType, DataType, Int64Type, UInt64T use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::ColumnarValue; use jiter::{Jiter, JiterError, Peek}; - -use crate::common_union::{ - is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL, -}; +use jsonpath_rust::parser::model::{Segment, Selector}; +use jsonpath_rust::parser::parse_json_path; /// General implementation of `ScalarUDFImpl::return_type`. /// @@ -68,7 +69,7 @@ fn dict_key_type(d: &DataType) -> Option { None } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum JsonPath<'s> { Key(&'s str), Index(usize), @@ -140,6 +141,22 @@ impl<'s> JsonPathArgs<'s> { } } +pub(crate) fn parse_jsonpath(path: &str) -> Vec> { + let segments = parse_json_path(path).map(|it| it.segments).unwrap_or(Vec::new()); + + segments + .into_iter() + .map(|segment| match segment { + Segment::Selector(s) => match s { + Selector::Name(name) => JsonPath::Key(Box::leak(name.into_boxed_str())), + Selector::Index(idx) => JsonPath::Index(idx as usize), + _ => JsonPath::None, + }, + _ => JsonPath::None, + }) + .collect::>() +} + pub trait InvokeResult { type Item; type Builder; @@ -585,3 +602,21 @@ fn mask_dictionary_keys(keys: &PrimitiveArray, type_ids: &[i8]) -> Pr } PrimitiveArray::new(keys.values().clone(), Some(null_mask.into())) } + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + // Test cases for parse_jsonpath + #[rstest] + #[case("$.a.aa", vec![JsonPath::Key("a"), JsonPath::Key("aa")])] + #[case("$.a.ab[0].ac", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(0), JsonPath::Key("ac")])] + #[case("$.a.ab[1].ad", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(1), JsonPath::Key("ad")])] + #[case(r#"$.a["a b"].ad"#, vec![JsonPath::Key("a"), JsonPath::Key("\"a b\""), JsonPath::Key("ad")])] + #[tokio::test] + async fn test_parse_jsonpath(#[case] path: &str, #[case] expected: Vec>) { + let result = parse_jsonpath(path); + assert_eq!(result, expected); + } +} diff --git a/src/json_extract.rs b/src/json_extract.rs new file mode 100644 index 0000000..2c60a95 --- /dev/null +++ b/src/json_extract.rs @@ -0,0 +1,82 @@ +use crate::common::{invoke, parse_jsonpath, return_type_check}; +use crate::common_macros::make_udf_function; +use crate::json_get_json::jiter_json_get_json; +use datafusion::arrow::array::StringArray; +use datafusion::arrow::datatypes::{DataType, DataType::Utf8}; +use datafusion::common::{exec_err, Result as DataFusionResult, ScalarValue}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +make_udf_function!( + JsonExtract, + json_extract, + json_data path, + r#"Get a value from a JSON string by its "path" in JSONPath format"# +); + +#[derive(Debug)] +pub(super) struct JsonExtract { + signature: Signature, + aliases: [String; 1], +} + +impl Default for JsonExtract { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![Utf8, Utf8], // JSON data and JSONPath as strings + Volatility::Immutable, + ), + aliases: ["json_extract".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonExtract { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + return_type_check(arg_types, self.name(), Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + if args.args.len() != 2 { + return exec_err!( + "'{}' expects exactly 2 arguments (JSON data, path), got {}", + self.name(), + args.args.len() + ); + } + + let json_arg = &args.args[0]; + let path_arg = &args.args[1]; + + let path_str = match path_arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s, + _ => { + return exec_err!( + "'{}' expects a valid JSONPath string (e.g., '$.key[0]') as second argument", + self.name() + ) + } + }; + + let path = parse_jsonpath(path_str); + + invoke::(&[json_arg.clone()], |json, _| jiter_json_get_json(json, &path)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} diff --git a/src/json_extract_scalar.rs b/src/json_extract_scalar.rs new file mode 100644 index 0000000..0a8791b --- /dev/null +++ b/src/json_extract_scalar.rs @@ -0,0 +1,83 @@ +use std::any::Any; + +use datafusion::arrow::datatypes::DataType; +use datafusion::common::{exec_err, Result as DataFusionResult}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion::scalar::ScalarValue; + +use crate::common::parse_jsonpath; +use crate::common::{invoke, return_type_check}; +use crate::common_macros::make_udf_function; +use crate::common_union::JsonUnion; +use crate::json_get::jiter_json_get_union; + +make_udf_function!( + JsonExtractScalar, + json_extract_scalar, + json_data path, + r#"Get a value from a JSON string by its "path""# +); + +#[derive(Debug)] +pub(super) struct JsonExtractScalar { + signature: Signature, + aliases: [String; 1], +} + +impl Default for JsonExtractScalar { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: ["json_extract_scalar".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonExtractScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + return_type_check(arg_types, self.name(), JsonUnion::data_type()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + if args.args.len() != 2 { + return exec_err!( + "'{}' expects exactly 2 arguments (JSON data, path), got {}", + self.name(), + args.args.len() + ); + } + + let json_arg = &args.args[0]; + let path_arg = &args.args[1]; + + let path_str = match path_arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s, + _ => { + return exec_err!( + "'{}' expects a valid JSONPath string (e.g., '$.key[0]') as second argument", + self.name() + ) + } + }; + + let path = parse_jsonpath(path_str); + + invoke::(&[json_arg.clone()], |json, _| jiter_json_get_union(json, &path)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} diff --git a/src/json_get.rs b/src/json_get.rs index 097bae2..9b803da 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -93,7 +93,7 @@ impl InvokeResult for JsonUnion { } } -fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result { +pub(crate) fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result { if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { build_union(&mut jiter, peek) } else { diff --git a/src/json_get_json.rs b/src/json_get_json.rs index a8c6477..460fd23 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -56,7 +56,7 @@ impl ScalarUDFImpl for JsonGetJson { } } -fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result { +pub(crate) fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result { if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { let start = jiter.current_index(); jiter.known_skip(peek)?; diff --git a/src/lib.rs b/src/lib.rs index cb0f25a..0532450 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,8 @@ mod common_macros; mod common_union; mod json_as_text; mod json_contains; +mod json_extract; +mod json_extract_scalar; mod json_get; mod json_get_bool; mod json_get_float; @@ -25,6 +27,8 @@ pub use common_union::{JsonUnionEncoder, JsonUnionValue}; pub mod functions { pub use crate::json_as_text::json_as_text; pub use crate::json_contains::json_contains; + pub use crate::json_extract::json_extract; + pub use crate::json_extract_scalar::json_extract_scalar; pub use crate::json_get::json_get; pub use crate::json_get_bool::json_get_bool; pub use crate::json_get_float::json_get_float; @@ -60,6 +64,8 @@ pub mod udfs { pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ json_get::json_get_udf(), + json_extract::json_extract_udf(), + json_extract_scalar::json_extract_scalar_udf(), json_get_bool::json_get_bool_udf(), json_get_float::json_get_float_udf(), json_get_int::json_get_int_udf(), diff --git a/tests/json_extract_scalar_test.rs b/tests/json_extract_scalar_test.rs new file mode 100644 index 0000000..2df5b44 --- /dev/null +++ b/tests/json_extract_scalar_test.rs @@ -0,0 +1,43 @@ +use crate::utils::{display_val, run_query}; +use rstest::{fixture, rstest}; + +mod utils; + +#[fixture] +fn json_data() -> String { + let json = r#" + { + "store": { + "book name": "My Favorite Books", + "book": [ + {"title": "1984", "author": "George Orwell"}, + {"title": "Pride and Prejudice", "author": "Jane Austen"} + ] + } + } + "#; + json.to_string() +} + +#[rstest] +#[case("$.store.book[0].author", "{str=George Orwell}")] +#[tokio::test] +async fn test_json_extract_scalar(json_data: String, #[case] path: &str, #[case] expected: &str) { + let result = json_extract_scalar(&json_data, path).await; + assert_eq!(result, expected.to_string()); +} + +#[rstest] +#[case("[1, 2, 3]", "$[2]", "{int=3}")] +#[case("[1, 2, 3]", "$[3]", "{null=}")] +#[tokio::test] +async fn test_json_extract_scalar_simple(#[case] json: String, #[case] path: &str, #[case] expected: &str) { + let result = json_extract_scalar(&json, path).await; + assert_eq!(result, expected.to_string()); +} + +async fn json_extract_scalar(json: &str, path: &str) -> String { + let sql = format!("select json_extract_scalar('{}', '{}')", json, path); + let batches = run_query(sql.as_str()).await.unwrap(); + display_val(batches).await.1 +} diff --git a/tests/json_extract_test.rs b/tests/json_extract_test.rs new file mode 100644 index 0000000..62df348 --- /dev/null +++ b/tests/json_extract_test.rs @@ -0,0 +1,35 @@ +use crate::utils::{display_val, run_query}; +use rstest::{fixture, rstest}; + +mod utils; + +#[fixture] +fn json_data() -> String { + let json = r#"{"a": {"a a": "My Collection","ab": [{"ac": "Dune", "ca": "Frank Herbert"},{"ad": "Foundation", "da": "Isaac Asimov"}]}}"#; + json.to_string() +} + +#[rstest] +#[case( + "$.a.ab", + "[{\"ac\": \"Dune\", \"ca\": \"Frank Herbert\"},{\"ad\": \"Foundation\", \"da\": \"Isaac Asimov\"}]" +)] +#[tokio::test] +async fn test_json_paths(json_data: String, #[case] path: &str, #[case] expected: &str) { + let result = json_extract(&json_data, path).await; + assert_eq!(result, expected.to_string()); +} + +#[rstest] +#[tokio::test] +#[ignore] +async fn test_invalid_json_path(json_data: String) { + let result = json_extract(&json_data, "store.invalid.path").await; + assert_eq!(result, "".to_string()); +} + +async fn json_extract(json: &str, path: &str) -> String { + let sql = format!("select json_extract('{}', '{}')", json, path); + let batches = run_query(sql.as_str()).await.unwrap(); + display_val(batches).await.1 +} diff --git a/tests/main.rs b/tests/main.rs index f591385..82df51e 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -82,6 +82,28 @@ async fn test_json_get_union() { assert_batches_eq!(expected, &batches); } +#[tokio::test] +async fn test_json_extract_union() { + let batches = run_query("select name, json_extract(json_data, '$.foo') as foo from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-------------+", + "| name | foo |", + "+------------------+-------------+", + "| object_foo | {str=abc} |", + "| object_foo_array | {array=[1]} |", + "| object_foo_obj | {object={}} |", + "| object_foo_null | {null=} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+-------------+", + ]; + assert_batches_eq!(expected, &batches); +} + #[tokio::test] async fn test_json_get_array() { let sql = "select json_get('[1, 2, 3]', 2)";