Skip to content

feat: implement json_get_array udf v2 #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ SELECT id, json_col->'a' as json_col_a FROM test_table
* [x] `json_get_float(json: str, *keys: str | int) -> float` - Get a float value from a JSON string by its "path"
* [x] `json_get_bool(json: str, *keys: str | int) -> bool` - Get a boolean value from a JSON string by its "path"
* [x] `json_get_json(json: str, *keys: str | int) -> str` - Get a nested raw JSON string from a JSON string by its "path"
* [x] `json_get_array(json: str, *keys: str | int) -> array` - Get a JSON array from a JSON string by its "path", returned as an Arrow list whose items are the raw JSON text of each element
* [x] `json_as_text(json: str, *keys: str | int) -> str` - Get any value from a JSON string by its "path", represented as a string (used for the `->>` operator)
* [x] `json_length(json: str, *keys: str | int) -> int` - Get the length of a JSON string or array

Expand Down
134 changes: 134 additions & 0 deletions src/json_get_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, ListBuilder, StringBuilder};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath};
use crate::common_macros::make_udf_function;

// Invokes the crate's `make_udf_function!` macro for `JsonGetArray`.
// Arguments, in order: the implementing struct, the function name, the
// argument names (`json_data`, `path`), and the user-facing description.
// NOTE(review): presumably this is what generates `json_get_array` and
// `json_get_array_udf`, which `lib.rs` re-exports from this module — confirm
// against `common_macros`.
make_udf_function!(
JsonGetArray,
json_get_array,
json_data path,
r#"Get an arrow array from a JSON string by its "path""#
);

/// Scalar UDF implementation backing `json_get_array`.
#[derive(Debug)]
pub(super) struct JsonGetArray {
/// Argument signature handed out by `ScalarUDFImpl::signature`.
signature: Signature,
/// Single-entry alias list; its first element is also returned by `name()`.
aliases: [String; 1],
}

impl Default for JsonGetArray {
    /// Creates the UDF with a variadic-any, immutable signature and the single
    /// alias `json_get_array`.
    fn default() -> Self {
        let alias = String::from("json_get_array");
        let signature = Signature::variadic_any(Volatility::Immutable);
        Self {
            signature,
            aliases: [alias],
        }
    }
}

impl ScalarUDFImpl for JsonGetArray {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
self.aliases[0].as_str()
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
return_type_check(
arg_types,
self.name(),
DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
"item",
DataType::Utf8,
true,
))),
)
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
invoke::<BuildArrayList>(&args.args, jiter_json_get_array)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

/// Marker type whose `InvokeResult` implementation collects `Vec<String>`
/// results into a `ListBuilder<StringBuilder>`.
#[derive(Debug)]
struct BuildArrayList;

impl InvokeResult for BuildArrayList {
    /// One result row: the elements of a JSON array, each as a string.
    type Item = Vec<String>;

    /// Rows are accumulated into a list-of-strings builder.
    type Builder = ListBuilder<StringBuilder>;

    const ACCEPT_DICT_RETURN: bool = false;

    /// Creates a list builder with room for `capacity` rows.
    fn builder(capacity: usize) -> Self::Builder {
        ListBuilder::with_capacity(StringBuilder::new(), capacity)
    }

    /// Appends one row: a list of non-null strings for `Some`, a null entry
    /// for `None`.
    fn append_value(builder: &mut Self::Builder, value: Option<Self::Item>) {
        match value {
            Some(elements) => builder.append_value(elements.into_iter().map(Some)),
            None => builder.append_null(),
        }
    }

    /// Finalizes the builder into an `ArrayRef`.
    fn finish(mut builder: Self::Builder) -> DataFusionResult<ArrayRef> {
        let finished: ArrayRef = Arc::new(builder.finish());
        Ok(finished)
    }

    /// Wraps a single optional row in a one-entry list array and returns it as
    /// `ScalarValue::List`; `None` becomes a null entry.
    fn scalar(value: Option<Self::Item>) -> ScalarValue {
        let mut builder = ListBuilder::new(StringBuilder::new());

        // Remember validity before the option is consumed by the loop.
        let is_valid = value.is_some();
        for element in value.into_iter().flatten() {
            builder.values().append_value(element);
        }
        builder.append(is_valid);

        ScalarValue::List(Arc::new(builder.finish()))
    }
}

/// Extracts the JSON array found at `path` inside `opt_json`.
///
/// Each element is returned as the raw JSON text it occupies in the input
/// (strings keep their quotes, nested arrays and objects are left unparsed).
///
/// # Errors
///
/// Yields the `get_err!()` error when `jiter_json_find` finds nothing or the
/// located value is not an array, and propagates any error raised while
/// walking the array or decoding an element slice as UTF-8.
fn jiter_json_get_array(opt_json: Option<&str>, path: &[JsonPath]) -> Result<Vec<String>, GetError> {
    match jiter_json_find(opt_json, path) {
        Some((mut jiter, Peek::Array)) => {
            let mut elements: Vec<String> = Vec::new();
            let mut next = jiter.known_array()?;

            while let Some(peek) = next {
                // Record where the element begins, skip over it, then copy
                // out the bytes that were consumed.
                let begin = jiter.current_index();
                jiter.known_skip(peek)?;
                let raw = jiter.slice_to_current(begin);
                elements.push(std::str::from_utf8(raw)?.to_string());

                next = jiter.array_step()?;
            }

            Ok(elements)
        }
        _ => get_err!(),
    }
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod common_union;
mod json_as_text;
mod json_contains;
mod json_get;
mod json_get_array;
mod json_get_bool;
mod json_get_float;
mod json_get_int;
Expand All @@ -26,6 +27,7 @@ pub mod functions {
pub use crate::json_as_text::json_as_text;
pub use crate::json_contains::json_contains;
pub use crate::json_get::json_get;
pub use crate::json_get_array::json_get_array;
pub use crate::json_get_bool::json_get_bool;
pub use crate::json_get_float::json_get_float;
pub use crate::json_get_int::json_get_int;
Expand All @@ -39,6 +41,7 @@ pub mod udfs {
pub use crate::json_as_text::json_as_text_udf;
pub use crate::json_contains::json_contains_udf;
pub use crate::json_get::json_get_udf;
pub use crate::json_get_array::json_get_array_udf;
pub use crate::json_get_bool::json_get_bool_udf;
pub use crate::json_get_float::json_get_float_udf;
pub use crate::json_get_int::json_get_int_udf;
Expand All @@ -64,6 +67,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
json_get_float::json_get_float_udf(),
json_get_int::json_get_int_udf(),
json_get_json::json_get_json_udf(),
json_get_array::json_get_array_udf(),
json_as_text::json_as_text_udf(),
json_get_str::json_get_str_udf(),
json_contains::json_contains_udf(),
Expand Down
65 changes: 64 additions & 1 deletion tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,77 @@ async fn test_json_get_union() {
}

// Indexing a JSON array through `json_get` with an integer key should return
// the element as a union value (`{int=3}` for index 2 of `[1, 2, 3]`).
// NOTE(review): two consecutive `async fn` headers follow; this looks like a
// rename rendered from a diff (old name, then new name) — confirm which one
// is current.
#[tokio::test]
async fn test_json_get_array() {
async fn test_json_get_array_elem() {
let sql = "select json_get('[1, 2, 3]', 2)";
let batches = run_query(sql).await.unwrap();
let (value_type, value_repr) = display_val(batches).await;
assert!(matches!(value_type, DataType::Union(_, _)));
assert_eq!(value_repr, "{int=3}");
}

/// A flat numeric JSON array is returned as a list rendered `[1, 2, 3]`.
#[tokio::test]
async fn test_json_get_array_basic_numbers() {
    let query = "select json_get_array('[1, 2, 3]')";
    let record_batches = run_query(query).await.unwrap();
    let (data_type, rendered) = display_val(record_batches).await;
    assert!(matches!(data_type, DataType::List(_)));
    assert_eq!(rendered, "[1, 2, 3]");
}

/// Elements of different JSON types are all kept, each in its original JSON
/// spelling.
#[tokio::test]
async fn test_json_get_array_mixed_types() {
    let query = r#"select json_get_array('["hello", 42, true, null, 3.14]')"#;
    let record_batches = run_query(query).await.unwrap();
    let (data_type, rendered) = display_val(record_batches).await;
    assert!(matches!(data_type, DataType::List(_)));
    assert_eq!(rendered, r#"["hello", 42, true, null, 3.14]"#);
}

/// Object elements are returned as their raw JSON text.
#[tokio::test]
async fn test_json_get_array_nested_objects() {
    let query = r#"select json_get_array('[{"name": "John"}, {"age": 30}]')"#;
    let record_batches = run_query(query).await.unwrap();
    let (data_type, rendered) = display_val(record_batches).await;
    assert!(matches!(data_type, DataType::List(_)));
    assert_eq!(rendered, r#"[{"name": "John"}, {"age": 30}]"#);
}

/// Inner arrays are not flattened; each one is a single list element.
#[tokio::test]
async fn test_json_get_array_nested_arrays() {
    let query = r#"select json_get_array('[[1, 2], [3, 4]]')"#;
    let record_batches = run_query(query).await.unwrap();
    let (data_type, rendered) = display_val(record_batches).await;
    assert!(matches!(data_type, DataType::List(_)));
    assert_eq!(rendered, "[[1, 2], [3, 4]]");
}

/// An empty JSON array produces an empty list.
#[tokio::test]
async fn test_json_get_array_empty() {
    let query = "select json_get_array('[]')";
    let record_batches = run_query(query).await.unwrap();
    let (data_type, rendered) = display_val(record_batches).await;
    assert!(matches!(data_type, DataType::List(_)));
    assert_eq!(rendered, "[]");
}

/// An empty input string does not fail the query: the column is still typed
/// as a list and the value renders as an empty string.
#[tokio::test]
async fn test_json_get_array_invalid_json() {
    let query = "select json_get_array('')";
    let record_batches = run_query(query).await.unwrap();
    let (data_type, rendered) = display_val(record_batches).await;
    assert!(matches!(data_type, DataType::List(_)));
    assert_eq!(rendered, "");
}

/// A key path selects the array nested under `items` before extraction.
#[tokio::test]
async fn test_json_get_array_with_path() {
    let query = r#"select json_get_array('{"items": [1, 2, 3]}', 'items')"#;
    let record_batches = run_query(query).await.unwrap();
    let (data_type, rendered) = display_val(record_batches).await;
    assert!(matches!(data_type, DataType::List(_)));
    assert_eq!(rendered, "[1, 2, 3]");
}

#[tokio::test]
async fn test_json_get_equals() {
let e = run_query(r"select name, json_get(json_data, 'foo')='abc' from test")
Expand Down
Loading