Skip to content

Feature/json extract scalar #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ datafusion = { version = "46", default-features = false }
jiter = "0.9"
paste = "1"
log = "0.4"
jsonpath-rust = "1.0.0"

[dev-dependencies]
datafusion = { version = "46", default-features = false, features = ["nested_expressions"] }
codspeed-criterion-compat = "2.6"
criterion = "0.5.1"
clap = "4"
tokio = { version = "1.43", features = ["full"] }
rstest = "0.25.0"

[lints.clippy]
dbg_macro = "deny"
Expand Down
45 changes: 40 additions & 5 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::str::Utf8Error;
use std::sync::Arc;

use crate::common_union::{
is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
};
use datafusion::arrow::array::{
downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray,
PrimitiveArray, PrimitiveBuilder, RunArray, StringArray, StringViewArray,
Expand All @@ -11,10 +14,8 @@ use datafusion::arrow::datatypes::{ArrowNativeType, DataType, Int64Type, UInt64T
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{
is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
};
use jsonpath_rust::parser::model::{Segment, Selector};
use jsonpath_rust::parser::parse_json_path;

/// General implementation of `ScalarUDFImpl::return_type`.
///
Expand Down Expand Up @@ -68,7 +69,7 @@ fn dict_key_type(d: &DataType) -> Option<DataType> {
None
}

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub enum JsonPath<'s> {
Key(&'s str),
Index(usize),
Expand Down Expand Up @@ -140,6 +141,22 @@ impl<'s> JsonPathArgs<'s> {
}
}

/// Parse a JSONPath expression (e.g. `$.a.ab[0]`) into this crate's internal
/// [`JsonPath`] segment list.
///
/// Only name and index selectors are supported; any other selector
/// (wildcards, slices, filters, …) and any non-selector segment maps to
/// [`JsonPath::None`]. A path that fails to parse yields an empty vector.
///
/// NOTE(review): `Box::leak` intentionally leaks each key string to obtain the
/// `'static` lifetime required by `JsonPath<'static>`. This is a small
/// per-query leak — acceptable while paths are scalar arguments, but worth
/// revisiting if paths ever become per-row values.
pub(crate) fn parse_jsonpath(path: &str) -> Vec<JsonPath<'static>> {
    let segments = parse_json_path(path).map(|it| it.segments).unwrap_or_default();

    segments
        .into_iter()
        .map(|segment| match segment {
            Segment::Selector(s) => match s {
                Selector::Name(name) => JsonPath::Key(Box::leak(name.into_boxed_str())),
                // Reject negative indices instead of letting `as usize` wrap
                // them into huge positive offsets.
                Selector::Index(idx) => usize::try_from(idx).map_or(JsonPath::None, JsonPath::Index),
                _ => JsonPath::None,
            },
            _ => JsonPath::None,
        })
        .collect()
}

pub trait InvokeResult {
type Item;
type Builder;
Expand Down Expand Up @@ -585,3 +602,21 @@ fn mask_dictionary_keys(keys: &PrimitiveArray<Int64Type>, type_ids: &[i8]) -> Pr
}
PrimitiveArray::new(keys.values().clone(), Some(null_mask.into()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // `parse_jsonpath` is purely synchronous, so these run as plain tests —
    // the previous `#[tokio::test]` + `async fn` spun up an async runtime for
    // no reason and pulled tokio into this module's test build.
    //
    // NOTE(review): the bracket-notation case below expects the key to RETAIN
    // its surrounding quotes (`"a b"`), i.e. `$["a b"]` does not currently
    // match a literal `a b` key — confirm this is the intended behavior of
    // `parse_jsonpath` before relying on bracket notation.
    #[rstest]
    #[case("$.a.aa", vec![JsonPath::Key("a"), JsonPath::Key("aa")])]
    #[case("$.a.ab[0].ac", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(0), JsonPath::Key("ac")])]
    #[case("$.a.ab[1].ad", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(1), JsonPath::Key("ad")])]
    #[case(r#"$.a["a b"].ad"#, vec![JsonPath::Key("a"), JsonPath::Key("\"a b\""), JsonPath::Key("ad")])]
    fn test_parse_jsonpath(#[case] path: &str, #[case] expected: Vec<JsonPath<'static>>) {
        let result = parse_jsonpath(path);
        assert_eq!(result, expected);
    }
}
82 changes: 82 additions & 0 deletions src/json_extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use crate::common::{invoke, parse_jsonpath, return_type_check};
use crate::common_macros::make_udf_function;
use crate::json_get_json::jiter_json_get_json;
use datafusion::arrow::array::StringArray;
use datafusion::arrow::datatypes::{DataType, DataType::Utf8};
use datafusion::common::{exec_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;

// Registers `json_extract` as a DataFusion scalar UDF taking (json_data, path);
// the final argument is the user-facing documentation string.
make_udf_function!(
    JsonExtract,
    json_extract,
    json_data path,
    r#"Get a value from a JSON string by its "path" in JSONPath format"#
);

/// DataFusion scalar UDF implementing `json_extract(json_data, path)`, where
/// `path` is a JSONPath-format string.
#[derive(Debug)]
pub(super) struct JsonExtract {
    signature: Signature,
    // The first (and only) entry doubles as the function's primary name;
    // see `name()` in the `ScalarUDFImpl` impl.
    aliases: [String; 1],
}

impl Default for JsonExtract {
    /// Builds the UDF with an exact `(Utf8, Utf8)` signature: a JSON document
    /// string and a JSONPath string.
    fn default() -> Self {
        let signature = Signature::exact(vec![Utf8, Utf8], Volatility::Immutable);
        let aliases = ["json_extract".to_string()];
        Self { signature, aliases }
    }
}

impl ScalarUDFImpl for JsonExtract {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The primary name is stored as the first (and only) alias.
    fn name(&self) -> &str {
        &self.aliases[0]
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
        return_type_check(arg_types, self.name(), Utf8)
    }

    /// Extracts the JSON fragment addressed by a scalar JSONPath string and
    /// returns it as a `Utf8` column of raw JSON text.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        let arg_count = args.args.len();
        if arg_count != 2 {
            return exec_err!(
                "'{}' expects exactly 2 arguments (JSON data, path), got {}",
                self.name(),
                arg_count
            );
        }

        let (json_arg, path_arg) = (&args.args[0], &args.args[1]);

        // The path must be a literal (scalar) string; per-row paths are rejected.
        let ColumnarValue::Scalar(ScalarValue::Utf8(Some(path_str))) = path_arg else {
            return exec_err!(
                "'{}' expects a valid JSONPath string (e.g., '$.key[0]') as second argument",
                self.name()
            );
        };

        let segments = parse_jsonpath(path_str);

        invoke::<StringArray>(&[json_arg.clone()], |json, _| jiter_json_get_json(json, &segments))
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}
83 changes: 83 additions & 0 deletions src/json_extract_scalar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use std::any::Any;

use datafusion::arrow::datatypes::DataType;
use datafusion::common::{exec_err, Result as DataFusionResult};
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use datafusion::scalar::ScalarValue;

use crate::common::parse_jsonpath;
use crate::common::{invoke, return_type_check};
use crate::common_macros::make_udf_function;
use crate::common_union::JsonUnion;
use crate::json_get::jiter_json_get_union;

// Registers `json_extract_scalar` as a DataFusion scalar UDF taking
// (json_data, path); the final argument is the user-facing documentation string.
make_udf_function!(
    JsonExtractScalar,
    json_extract_scalar,
    json_data path,
    r#"Get a value from a JSON string by its "path""#
);

/// DataFusion scalar UDF implementing `json_extract_scalar(json_data, path)`,
/// returning a JSON union value rather than raw JSON text.
#[derive(Debug)]
pub(super) struct JsonExtractScalar {
    signature: Signature,
    // The first (and only) entry doubles as the function's primary name;
    // see `name()` in the `ScalarUDFImpl` impl.
    aliases: [String; 1],
}

impl Default for JsonExtractScalar {
    /// Builds the UDF. Argument validation is deferred to `return_type_check`
    /// and `invoke_with_args` rather than encoded in the signature.
    fn default() -> Self {
        let signature = Signature::variadic_any(Volatility::Immutable);
        let aliases = ["json_extract_scalar".to_string()];
        Self { signature, aliases }
    }
}

impl ScalarUDFImpl for JsonExtractScalar {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The primary name is stored as the first (and only) alias.
    fn name(&self) -> &str {
        &self.aliases[0]
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
        return_type_check(arg_types, self.name(), JsonUnion::data_type())
    }

    /// Extracts the scalar value addressed by a scalar JSONPath string and
    /// returns it as a JSON union column.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        let arg_count = args.args.len();
        if arg_count != 2 {
            return exec_err!(
                "'{}' expects exactly 2 arguments (JSON data, path), got {}",
                self.name(),
                arg_count
            );
        }

        let (json_arg, path_arg) = (&args.args[0], &args.args[1]);

        // The path must be a literal (scalar) string; per-row paths are rejected.
        let ColumnarValue::Scalar(ScalarValue::Utf8(Some(path_str))) = path_arg else {
            return exec_err!(
                "'{}' expects a valid JSONPath string (e.g., '$.key[0]') as second argument",
                self.name()
            );
        };

        let segments = parse_jsonpath(path_str);

        invoke::<JsonUnion>(&[json_arg.clone()], |json, _| jiter_json_get_union(json, &segments))
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}
2 changes: 1 addition & 1 deletion src/json_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl InvokeResult for JsonUnion {
}
}

fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result<JsonUnionField, GetError> {
pub(crate) fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result<JsonUnionField, GetError> {
if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
build_union(&mut jiter, peek)
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/json_get_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl ScalarUDFImpl for JsonGetJson {
}
}

fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result<String, GetError> {
pub(crate) fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result<String, GetError> {
if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
let start = jiter.current_index();
jiter.known_skip(peek)?;
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ mod common_macros;
mod common_union;
mod json_as_text;
mod json_contains;
mod json_extract;
mod json_extract_scalar;
mod json_get;
mod json_get_bool;
mod json_get_float;
Expand All @@ -25,6 +27,8 @@ pub use common_union::{JsonUnionEncoder, JsonUnionValue};
pub mod functions {
pub use crate::json_as_text::json_as_text;
pub use crate::json_contains::json_contains;
pub use crate::json_extract::json_extract;
pub use crate::json_extract_scalar::json_extract_scalar;
pub use crate::json_get::json_get;
pub use crate::json_get_bool::json_get_bool;
pub use crate::json_get_float::json_get_float;
Expand Down Expand Up @@ -60,6 +64,8 @@ pub mod udfs {
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
let functions: Vec<Arc<ScalarUDF>> = vec![
json_get::json_get_udf(),
json_extract::json_extract_udf(),
json_extract_scalar::json_extract_scalar_udf(),
json_get_bool::json_get_bool_udf(),
json_get_float::json_get_float_udf(),
json_get_int::json_get_int_udf(),
Expand Down
43 changes: 43 additions & 0 deletions tests/json_extract_scalar_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use crate::utils::{display_val, run_query};
use rstest::{fixture, rstest};

mod utils;

// Shared fixture: a nested JSON document with an object, a key containing a
// space, and an array of objects — exercised by the JSONPath cases below.
#[fixture]
fn json_data() -> String {
    let json = r#"
{
"store": {
"book name": "My Favorite Books",
"book": [
{"title": "1984", "author": "George Orwell"},
{"title": "Pride and Prejudice", "author": "Jane Austen"}
]
}
}
"#;
    json.to_string()
}

// Extracted scalar values are display-formatted as the union variant,
// e.g. `{str=...}`.
#[rstest]
#[case("$.store.book[0].author", "{str=George Orwell}")]
#[tokio::test]
async fn test_json_extract_scalar(json_data: String, #[case] path: &str, #[case] expected: &str) {
    let actual = json_extract_scalar(&json_data, path).await;
    assert_eq!(actual, expected);
}

// Array indexing: an in-bounds index yields the element's union variant; an
// out-of-bounds index yields the union's null variant rather than an error.
#[rstest]
#[case("[1, 2, 3]", "$[2]", "{int=3}")]
#[case("[1, 2, 3]", "$[3]", "{null=}")]
#[tokio::test]
async fn test_json_extract_scalar_simple(#[case] json: String, #[case] path: &str, #[case] expected: &str) {
    let actual = json_extract_scalar(&json, path).await;
    assert_eq!(actual, expected);
}

/// Runs `json_extract_scalar` through a SQL query and returns the displayed
/// result value.
async fn json_extract_scalar(json: &str, path: &str) -> String {
    let sql = format!("select json_extract_scalar('{json}', '{path}')");
    let batches = run_query(&sql).await.unwrap();
    display_val(batches).await.1
}
35 changes: 35 additions & 0 deletions tests/json_extract_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use crate::utils::{display_val, run_query};
use rstest::{fixture, rstest};

mod utils;

// Shared fixture: a compact nested document with a spaced key and an array of
// objects, matching the JSONPath cases below.
#[fixture]
fn json_data() -> String {
    r#"{"a": {"a a": "My Collection","ab": [{"ac": "Dune", "ca": "Frank Herbert"},{"ad": "Foundation", "da": "Isaac Asimov"}]}}"#.to_string()
}

// Extracting a non-leaf path returns the raw JSON text of that subtree.
#[rstest]
#[case(
    "$.a.ab",
    "[{\"ac\": \"Dune\", \"ca\": \"Frank Herbert\"},{\"ad\": \"Foundation\", \"da\": \"Isaac Asimov\"}]"
)]
#[tokio::test]
async fn test_json_paths(json_data: String, #[case] path: &str, #[case] expected: &str) {
    let actual = json_extract(&json_data, path).await;
    assert_eq!(actual, expected);
}

// NOTE(review): ignored pending a decision on invalid-path behavior. The path
// here lacks the `$.` prefix, so it presumably fails JSONPath parsing and
// produces an empty segment list — in which case json_extract would return the
// whole document, not the empty string asserted below. Confirm the intended
// behavior for invalid paths before un-ignoring.
#[rstest]
#[tokio::test]
#[ignore]
async fn test_invalid_json_path(json_data: String) {
    let result = json_extract(&json_data, "store.invalid.path").await;
    assert_eq!(result, "".to_string());
}

/// Runs `json_extract` through a SQL query and returns the displayed result
/// value.
async fn json_extract(json: &str, path: &str) -> String {
    let sql = format!("select json_extract('{json}', '{path}')");
    let batches = run_query(&sql).await.unwrap();
    display_val(batches).await.1
}
Loading
Loading