From 00eb2d080e2cff1212ed10438f01462ad4799f4c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 15 Apr 2024 12:03:44 -0700 Subject: [PATCH] Coerce Dictionary types for scalar functions (#10077) * Coerce Dictionary types for scalar functions * Fix * Fix format * Add test --- datafusion/expr/src/built_in_function.rs | 15 +++++++++++ .../expr/src/type_coercion/functions.rs | 25 +++++++++++++++++-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index dc1fc98a5c02..e8763657ceb3 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -448,4 +448,19 @@ mod tests { .unwrap(); assert_eq!(return_type, DataType::Date32); } + + #[test] + fn test_coalesce_return_types_dictionary() { + let coalesce = BuiltinScalarFunction::Coalesce; + let return_type = coalesce + .return_type(&[ + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Utf8, + ]) + .unwrap(); + assert_eq!( + return_type, + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + ); + } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 34b607d0884d..37eeb7d464b8 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -314,8 +314,13 @@ fn coerced_from<'a>( // match Dictionary first match (type_into, type_from) { // coerced dictionary first - (cur_type, Dictionary(_, value_type)) | (Dictionary(_, value_type), cur_type) - if coerced_from(cur_type, value_type).is_some() => + (_, Dictionary(_, value_type)) + if coerced_from(type_into, value_type).is_some() => + { + Some(type_into.clone()) + } + (Dictionary(_, value_type), _) + if coerced_from(value_type, type_from).is_some() => { Some(type_into.clone()) } @@ -624,4 +629,20 @@ mod tests { Ok(()) } + + #[test] + fn test_coerced_from_dictionary() { + let type_into = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32)); + let type_from = DataType::Int64; + assert_eq!(coerced_from(&type_into, &type_from), None); + + let type_from = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32)); + let type_into = DataType::Int64; + assert_eq!( + coerced_from(&type_into, &type_from), + Some(type_into.clone()) + ); + } }