From d6446364439b428d2ffa216fc54dd756158068f9 Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Sat, 17 Jun 2023 01:04:12 +0400 Subject: [PATCH] feat: Improve subquery support --- Cargo.lock | 2 +- datafusion-cli/Cargo.lock | 2 +- datafusion/common/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- .../core/src/datasource/listing/helpers.rs | 1 + datafusion/core/src/logical_plan/builder.rs | 7 +- .../core/src/logical_plan/expr_rewriter.rs | 17 +- .../core/src/logical_plan/expr_schema.rs | 3 +- .../core/src/logical_plan/expr_visitor.rs | 4 + datafusion/core/src/logical_plan/mod.rs | 2 +- datafusion/core/src/logical_plan/plan.rs | 109 +++- .../src/optimizer/common_subexpr_eliminate.rs | 4 + .../core/src/optimizer/projection_drop_out.rs | 2 + .../src/optimizer/projection_push_down.rs | 10 +- .../src/optimizer/simplify_expressions.rs | 1 + datafusion/core/src/optimizer/utils.rs | 17 +- datafusion/core/src/physical_plan/planner.rs | 33 +- datafusion/core/src/physical_plan/subquery.rs | 183 ++++-- datafusion/core/src/sql/planner.rs | 239 ++++++-- datafusion/core/src/sql/utils.rs | 17 +- datafusion/core/tests/sql/expr.rs | 132 ++++- datafusion/core/tests/sql/subquery.rs | 162 ++++- datafusion/expr/Cargo.toml | 2 +- datafusion/expr/src/expr.rs | 49 +- .../physical-expr/src/expressions/any.rs | 554 +++++++++++------- 25 files changed, 1206 insertions(+), 350 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 382e2f2a2308..f43c482b6217 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2274,7 +2274,7 @@ dependencies = [ [[package]] name = "sqlparser" version = "0.16.0" -source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=b3b40586d4c32a218ffdfcb0462e7e216cf3d6eb#b3b40586d4c32a218ffdfcb0462e7e216cf3d6eb" +source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=2229652dc8fae8f45cbec344b4a1e40cf1bb69d9#2229652dc8fae8f45cbec344b4a1e40cf1bb69d9" dependencies = [ "log", ] diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 4d7a14443e1b..0f05f29b0633 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1477,7 +1477,7 @@ checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451" [[package]] name = "sqlparser" version = "0.16.0" -source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=b3b40586d4c32a218ffdfcb0462e7e216cf3d6eb#b3b40586d4c32a218ffdfcb0462e7e216cf3d6eb" +source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=2229652dc8fae8f45cbec344b4a1e40cf1bb69d9#2229652dc8fae8f45cbec344b4a1e40cf1bb69d9" dependencies = [ "log", ] diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 41ed0f5fdbae..22c21a71260d 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -44,4 +44,4 @@ cranelift-module = { version = "0.82.0", optional = true } ordered-float = "2.10" parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "096ef28dde6b1ae43ce89ba2c3a9d98295f2972e", features = ["arrow"], optional = true } pyo3 = { version = "0.16", optional = true } -sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "b3b40586d4c32a218ffdfcb0462e7e216cf3d6eb" } +sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "2229652dc8fae8f45cbec344b4a1e40cf1bb69d9" } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index ab00521a853a..62b6a79678aa 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -79,7 +79,7 @@ pin-project-lite= "^0.2.7" pyo3 = { version = "0.16", optional = true } rand = "0.8" smallvec = { version = "1.6", features = ["union"] } -sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "b3b40586d4c32a218ffdfcb0462e7e216cf3d6eb" } +sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "2229652dc8fae8f45cbec344b4a1e40cf1bb69d9" } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 2f36a86c0e84..331d049c6900 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -97,6 +97,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { | Expr::ILike { .. } | Expr::SimilarTo { .. } | Expr::InList { .. } + | Expr::InSubquery { .. } | Expr::GetIndexedField { .. } | Expr::Case { .. } => Recursion::Continue(self), diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index 1f8237d244a7..6196909c4d61 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -26,7 +26,7 @@ use crate::error::{DataFusionError, Result}; use crate::logical_plan::expr_schema::ExprSchemable; use crate::logical_plan::plan::{ Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort, Subquery, - TableScan, TableUDFs, ToStringifiedPlan, Union, Window, + SubqueryType, TableScan, TableUDFs, ToStringifiedPlan, Union, Window, }; use crate::optimizer::utils; use crate::prelude::*; @@ -528,12 +528,15 @@ impl LogicalPlanBuilder { pub fn subquery( &self, subqueries: impl IntoIterator>, + types: impl IntoIterator, ) -> Result { let subqueries = subqueries.into_iter().map(|l| l.into()).collect::>(); - let schema = Arc::new(Subquery::merged_schema(&self.plan, &subqueries)); + let types = types.into_iter().collect::>(); + let schema = Arc::new(Subquery::merged_schema(&self.plan, &subqueries, &types)); Ok(Self::from(LogicalPlan::Subquery(Subquery { input: Arc::new(self.plan.clone()), subqueries, + types, schema, }))) } diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs index 426de5028391..8e65962771b5 100644 --- a/datafusion/core/src/logical_plan/expr_rewriter.rs +++ b/datafusion/core/src/logical_plan/expr_rewriter.rs @@ -122,10 +122,16 @@ impl ExprRewritable for Expr { op, right: rewrite_boxed(right, rewriter)?, }, - Expr::AnyExpr { left, op, right } => Expr::AnyExpr { + Expr::AnyExpr { + left, + op, + right, + all, + } => Expr::AnyExpr { left: rewrite_boxed(left, rewriter)?, op, right: rewrite_boxed(right, rewriter)?, + all, }, Expr::Like(Like { negated, @@ -263,6 +269,15 @@ impl ExprRewritable for Expr { list: rewrite_vec(list, rewriter)?, negated, }, + Expr::InSubquery { + expr, + subquery, + negated, + } => Expr::InSubquery { + expr: rewrite_boxed(expr, rewriter)?, + subquery: rewrite_boxed(subquery, rewriter)?, + negated, + }, Expr::Wildcard => Expr::Wildcard, Expr::QualifiedWildcard { qualifier } => { Expr::QualifiedWildcard { qualifier } diff --git a/datafusion/core/src/logical_plan/expr_schema.rs b/datafusion/core/src/logical_plan/expr_schema.rs index 85e2c603be19..e3f4d4ab2c7e 100644 --- a/datafusion/core/src/logical_plan/expr_schema.rs +++ b/datafusion/core/src/logical_plan/expr_schema.rs @@ -111,6 +111,7 @@ impl ExprSchemable for Expr { | Expr::IsNull(_) | Expr::Between { .. } | Expr::InList { .. } + | Expr::InSubquery { .. } | Expr::AnyExpr { .. } | Expr::IsNotNull(_) => Ok(DataType::Boolean), Expr::BinaryExpr { @@ -158,7 +159,7 @@ impl ExprSchemable for Expr { | Expr::Between { expr, .. } | Expr::InList { expr, .. } => expr.nullable(input_schema), Expr::Column(c) => input_schema.nullable(c), - Expr::OuterColumn(_, _) => Ok(true), + Expr::OuterColumn(_, _) | Expr::InSubquery { .. } => Ok(true), Expr::Literal(value) => Ok(value.is_null()), Expr::Case { when_then_expr, diff --git a/datafusion/core/src/logical_plan/expr_visitor.rs b/datafusion/core/src/logical_plan/expr_visitor.rs index 41294cf2cf54..9296848ea8ab 100644 --- a/datafusion/core/src/logical_plan/expr_visitor.rs +++ b/datafusion/core/src/logical_plan/expr_visitor.rs @@ -191,6 +191,10 @@ impl ExprVisitable for Expr { list.iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)) } + Expr::InSubquery { expr, subquery, .. } => { + let visitor = expr.accept(visitor)?; + subquery.accept(visitor) + } }?; visitor.post_visit(self) diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index f2c5f5370289..5609a468638b 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -68,6 +68,6 @@ pub use plan::{ CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CrossJoin, Distinct, DropTable, EmptyRelation, Filter, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, PlanVisitor, Repartition, StringifiedPlan, Subquery, - TableScan, ToStringifiedPlan, Union, Values, + SubqueryType, TableScan, ToStringifiedPlan, Union, Values, }; pub use registry::FunctionRegistry; diff --git a/datafusion/core/src/logical_plan/plan.rs b/datafusion/core/src/logical_plan/plan.rs index 3c8d5a72fa0e..00a4bc0716be 100644 --- a/datafusion/core/src/logical_plan/plan.rs +++ b/datafusion/core/src/logical_plan/plan.rs @@ -26,7 +26,7 @@ use crate::error::DataFusionError; use crate::logical_plan::dfschema::DFSchemaRef; use crate::sql::parser::FileType; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::DFSchema; +use datafusion_common::{DFField, DFSchema}; use std::fmt::Formatter; use std::{ collections::HashSet, @@ -267,22 +267,97 @@ pub struct Limit { /// Evaluates correlated sub queries #[derive(Clone)] pub struct Subquery { - /// The list of sub queries - pub subqueries: Vec, /// The incoming logical plan pub input: Arc, + /// The list of sub queries + pub subqueries: Vec, + /// The list of subquery types + pub types: Vec, /// The schema description of the output pub schema: DFSchemaRef, } +/// Subquery type +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum SubqueryType { + /// Scalar (SELECT, WHERE) evaluating to one value + Scalar, + /// EXISTS(...) evaluating to true if at least one row was produced + Exists, + /// ANY(...) / ALL(...) + AnyAll, + // [NOT] IN(...) is not defined as it is implicitly evaluated as ANY = (...) / ALL <> (...) +} + impl Subquery { /// Merge schema of main input and correlated subquery columns - pub fn merged_schema(input: &LogicalPlan, subqueries: &[LogicalPlan]) -> DFSchema { - subqueries.iter().fold((**input.schema()).clone(), |a, b| { - let mut res = a; - res.merge(b.schema()); - res - }) + pub fn merged_schema( + input: &LogicalPlan, + subqueries: &[LogicalPlan], + types: &[SubqueryType], + ) -> DFSchema { + subqueries.iter().zip(types.iter()).fold( + (**input.schema()).clone(), + |schema, (plan, typ)| { + let mut schema = schema; + schema.merge(&Self::transform_dfschema(plan.schema(), *typ)); + schema + }, + ) + } + + /// Transform DataFusion schema according to subquery type + pub fn transform_dfschema(schema: &DFSchema, typ: SubqueryType) -> DFSchema { + match typ { + SubqueryType::Scalar => schema.clone(), + SubqueryType::Exists | SubqueryType::AnyAll => { + let new_fields = schema + .fields() + .iter() + .map(|field| { + let new_field = Subquery::transform_field(field.field(), typ); + if let Some(qualifier) = field.qualifier() { + DFField::from_qualified(qualifier, new_field) + } else { + DFField::from(new_field) + } + }) + .collect(); + DFSchema::new_with_metadata(new_fields, schema.metadata().clone()) + .unwrap() + } + } + } + + /// Transform Arrow field according to subquery type + pub fn transform_field(field: &Field, typ: SubqueryType) -> Field { + match typ { + SubqueryType::Scalar => field.clone(), + SubqueryType::Exists => Field::new(field.name(), DataType::Boolean, false), + // ANY/ALL subquery converts subquery result rows into a list + // and uses existing code evaluating ANY with a list to evaluate the result + SubqueryType::AnyAll => { + let item = Field::new_dict( + "item", + field.data_type().clone(), + true, + field.dict_id().unwrap_or(0), + field.dict_is_ordered().unwrap_or(false), + ); + Field::new(field.name(), DataType::List(Box::new(item)), false) + } + } + } +} + +impl Display for SubqueryType { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let name = match self { + Self::Scalar => "Scalar", + Self::Exists => "Exists", + Self::AnyAll => "AnyAll", + }; + write!(f, "{}", name) } } @@ -475,13 +550,23 @@ impl LogicalPlan { LogicalPlan::Values(Values { schema, .. }) => vec![schema], LogicalPlan::Window(Window { input, schema, .. }) | LogicalPlan::Projection(Projection { input, schema, .. }) - | LogicalPlan::Subquery(Subquery { input, schema, .. }) | LogicalPlan::Aggregate(Aggregate { input, schema, .. }) | LogicalPlan::TableUDFs(TableUDFs { input, schema, .. }) => { let mut schemas = input.all_schemas(); schemas.insert(0, schema); schemas } + LogicalPlan::Subquery(Subquery { + input, + subqueries, + schema, + .. + }) => { + let mut schemas = input.all_schemas(); + schemas.extend(subqueries.iter().map(|s| s.schema())); + schemas.insert(0, schema); + schemas + } LogicalPlan::Join(Join { left, right, @@ -1063,7 +1148,9 @@ impl LogicalPlan { } Ok(()) } - LogicalPlan::Subquery(Subquery { .. }) => write!(f, "Subquery"), + LogicalPlan::Subquery(Subquery { types, .. }) => { + write!(f, "Subquery: types={:?}", types) + } LogicalPlan::Filter(Filter { predicate: ref expr, .. diff --git a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs index f9002fe69c71..9fa278a8b688 100644 --- a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs @@ -508,6 +508,10 @@ impl ExprIdentifierVisitor<'_> { desc.push_str("InList-"); desc.push_str(&negated.to_string()); } + Expr::InSubquery { negated, .. } => { + desc.push_str("InSubquery-"); + desc.push_str(&negated.to_string()); + } Expr::Wildcard => { desc.push_str("Wildcard-"); } diff --git a/datafusion/core/src/optimizer/projection_drop_out.rs b/datafusion/core/src/optimizer/projection_drop_out.rs index e96d365d1e9d..479c9ca917f2 100644 --- a/datafusion/core/src/optimizer/projection_drop_out.rs +++ b/datafusion/core/src/optimizer/projection_drop_out.rs @@ -254,6 +254,7 @@ fn optimize_plan( LogicalPlan::Subquery(Subquery { input, subqueries, + types, schema, }) => { // TODO: subqueries are not optimized @@ -269,6 +270,7 @@ fn optimize_plan( .map(|(p, _)| p)?, ), subqueries: subqueries.clone(), + types: types.clone(), schema: schema.clone(), }), None, diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs index 060f10c30ba5..1fc8ac83e3a2 100644 --- a/datafusion/core/src/optimizer/projection_push_down.rs +++ b/datafusion/core/src/optimizer/projection_push_down.rs @@ -453,7 +453,10 @@ fn optimize_plan( })) } LogicalPlan::Subquery(Subquery { - input, subqueries, .. + input, + subqueries, + types, + .. }) => { let mut subquery_required_columns = HashSet::new(); for subquery in subqueries.iter() { @@ -484,11 +487,12 @@ fn optimize_plan( has_projection, _optimizer_config, )?; - let new_schema = Subquery::merged_schema(&input, subqueries); + let new_schema = Subquery::merged_schema(&input, subqueries, types); Ok(LogicalPlan::Subquery(Subquery { input: Arc::new(input), - schema: Arc::new(new_schema), subqueries: subqueries.clone(), + types: types.clone(), + schema: Arc::new(new_schema), })) } // all other nodes: Add any additional columns used by diff --git a/datafusion/core/src/optimizer/simplify_expressions.rs b/datafusion/core/src/optimizer/simplify_expressions.rs index 8ee40bff6ed8..d23a72ca2558 100644 --- a/datafusion/core/src/optimizer/simplify_expressions.rs +++ b/datafusion/core/src/optimizer/simplify_expressions.rs @@ -392,6 +392,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::OuterColumn(_, _) | Expr::WindowFunction { .. } | Expr::Sort { .. } + | Expr::InSubquery { .. } | Expr::Wildcard | Expr::QualifiedWildcard { .. } => false, Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 2e741eb892b3..aa35e2772c9c 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -93,6 +93,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { | Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } | Expr::InList { .. } + | Expr::InSubquery { .. } | Expr::Wildcard | Expr::QualifiedWildcard { .. } | Expr::GetIndexedField { .. } => {} @@ -161,10 +162,11 @@ pub fn from_plan( alias: alias.clone(), })) } - LogicalPlan::Subquery(Subquery { schema, .. }) => { + LogicalPlan::Subquery(Subquery { types, schema, .. }) => { Ok(LogicalPlan::Subquery(Subquery { - subqueries: inputs[1..inputs.len()].to_vec(), input: Arc::new(inputs[0].clone()), + subqueries: inputs[1..inputs.len()].to_vec(), + types: types.clone(), schema: schema.clone(), })) } @@ -390,6 +392,9 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { } Ok(expr_list) } + Expr::InSubquery { expr, subquery, .. } => { + Ok(vec![expr.as_ref().to_owned(), subquery.as_ref().to_owned()]) + } Expr::Wildcard { .. } => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -410,10 +415,11 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { op: *op, right: Box::new(expressions[1].clone()), }), - Expr::AnyExpr { op, .. } => Ok(Expr::AnyExpr { + Expr::AnyExpr { op, all, .. } => Ok(Expr::AnyExpr { left: Box::new(expressions[0].clone()), op: *op, right: Box::new(expressions[1].clone()), + all: *all, }), Expr::Like(Like { negated, @@ -597,6 +603,11 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { Ok(expr) } } + Expr::InSubquery { negated, .. } => Ok(Expr::InSubquery { + expr: Box::new(expressions[0].clone()), + subquery: Box::new(expressions[1].clone()), + negated: *negated, + }), Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 8f3b9c261a65..b81bb2162ecb 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -114,10 +114,16 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let right = create_physical_name(right, false)?; Ok(format!("{} {:?} {}", left, op, right)) } - Expr::AnyExpr { left, op, right } => { + Expr::AnyExpr { + left, + op, + right, + all, + } => { let left = create_physical_name(left, false)?; let right = create_physical_name(right, false)?; - Ok(format!("{} {:?} ANY({})", left, op, right)) + let keyword = if *all { "ALL" } else { "ANY" }; + Ok(format!("{} {:?} {}({})", left, op, keyword, right)) } Expr::Case { expr, @@ -204,6 +210,16 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(format!("{} IN ({:?})", expr, list)) } } + Expr::InSubquery { + expr, + subquery, + negated, + } => { + let expr = create_physical_name(expr, false)?; + let subquery = create_physical_name(subquery, false)?; + let negated = if *negated { "NOT " } else { "" }; + Ok(format!("{} {}IN ({})", expr, negated, subquery)) + } Expr::Between { expr, negated, @@ -917,7 +933,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch))) } - LogicalPlan::Subquery(Subquery { subqueries, input, schema }) => { + LogicalPlan::Subquery(Subquery { input, subqueries, types, schema }) => { let cursor = Arc::new(OuterQueryCursor::new(schema.as_ref().to_owned().into())); let mut new_session_state = session_state.clone(); new_session_state.execution_props = new_session_state.execution_props.with_outer_query_cursor(cursor.clone()); @@ -931,7 +947,7 @@ impl DefaultPhysicalPlanner { }) .collect::>(); let input = self.create_initial_plan(input, &new_session_state).await?; - Ok(Arc::new(SubqueryExec::try_new(subqueries, input, cursor)?)) + Ok(Arc::new(SubqueryExec::try_new(input, subqueries, types.clone(), cursor)?)) } LogicalPlan::CreateExternalTable(_) => { // There is no default plan for "CREATE EXTERNAL @@ -1290,7 +1306,12 @@ pub fn create_physical_expr( binary_expr } } - Expr::AnyExpr { left, op, right } => { + Expr::AnyExpr { + left, + op, + right, + all, + } => { let lhs = create_physical_expr( left, input_dfschema, @@ -1303,7 +1324,7 @@ pub fn create_physical_expr( input_schema, execution_props, )?; - any(lhs, *op, rhs, input_schema) + any(lhs, *op, rhs, *all, input_schema) } Expr::InList { expr, diff --git a/datafusion/core/src/physical_plan/subquery.rs b/datafusion/core/src/physical_plan/subquery.rs index c7ffad4cbd55..912c36cf737e 100644 --- a/datafusion/core/src/physical_plan/subquery.rs +++ b/datafusion/core/src/physical_plan/subquery.rs @@ -28,8 +28,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::{Subquery, SubqueryType}; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; -use arrow::array::new_null_array; +use arrow::array::{new_null_array, BooleanArray}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; @@ -38,19 +39,21 @@ use super::expressions::PhysicalSortExpr; use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; use crate::execution::context::TaskContext; use async_trait::async_trait; -use datafusion_common::OuterQueryCursor; +use datafusion_common::{OuterQueryCursor, ScalarValue}; use futures::stream::Stream; use futures::stream::StreamExt; /// Execution plan for a sub query #[derive(Debug)] pub struct SubqueryExec { + /// The input plan + input: Arc, /// Sub queries subqueries: Vec>, + /// Subquery types + types: Vec, /// Merged schema schema: SchemaRef, - /// The input plan - input: Arc, /// Cursor used to send outer query column values to sub queries cursor: Arc, } @@ -58,15 +61,23 @@ pub struct SubqueryExec { impl SubqueryExec { /// Create a projection on an input pub fn try_new( - subqueries: Vec>, input: Arc, + subqueries: Vec>, + types: Vec, cursor: Arc, ) -> Result { let input_schema = input.schema(); let mut total_fields = input_schema.fields().clone(); - for q in subqueries.iter() { - total_fields.append(&mut q.schema().fields().clone()); + for (q, t) in subqueries.iter().zip(types.iter()) { + total_fields.append( + &mut q + .schema() + .fields() + .iter() + .map(|f| Subquery::transform_field(f, *t)) + .collect(), + ); } let merged_schema = Schema::new_with_metadata(total_fields, HashMap::new()); @@ -78,9 +89,10 @@ impl SubqueryExec { } Ok(Self { + input, subqueries, + types, schema: Arc::new(merged_schema), - input, cursor, }) } @@ -134,8 +146,9 @@ impl ExecutionPlan for SubqueryExec { } Ok(Arc::new(SubqueryExec::try_new( - children.iter().skip(1).cloned().collect(), children[0].clone(), + children.iter().skip(1).cloned().collect(), + self.types.clone(), self.cursor.clone(), )?)) } @@ -148,74 +161,116 @@ impl ExecutionPlan for SubqueryExec { let stream = self.input.execute(partition, context.clone()).await?; let cursor = self.cursor.clone(); let subqueries = self.subqueries.clone(); + let types = self.types.clone(); let context = context.clone(); let size_hint = stream.size_hint(); let schema = self.schema.clone(); - let res_stream = - stream.then(move |batch| { - let cursor = cursor.clone(); - let context = context.clone(); - let subqueries = subqueries.clone(); - let schema = schema.clone(); - async move { - let batch = batch?; - let b = Arc::new(batch.clone()); - cursor.set_batch(b)?; - let mut subquery_arrays = vec![Vec::new(); subqueries.len()]; - for i in 0..batch.num_rows() { - cursor.set_position(i)?; - for (subquery_i, subquery) in subqueries.iter().enumerate() { - let null_array = || { - let schema = subquery.schema(); - let fields = schema.fields(); - if fields.len() != 1 { - return Err(ArrowError::ComputeError(format!( - "Sub query should have only one column but got {}", - fields.len() - ))); - } - - let data_type = fields.get(0).unwrap().data_type(); - Ok(new_null_array(data_type, 1)) - }; + let res_stream = stream.then(move |batch| { + let cursor = cursor.clone(); + let context = context.clone(); + let subqueries = subqueries.clone(); + let types = types.clone(); + let schema = schema.clone(); + async move { + let batch = batch?; + let b = Arc::new(batch.clone()); + cursor.set_batch(b)?; + let mut subquery_arrays = vec![Vec::new(); subqueries.len()]; + for i in 0..batch.num_rows() { + cursor.set_position(i)?; + for (subquery_i, (subquery, subquery_type)) in + subqueries.iter().zip(types.iter()).enumerate() + { + let schema = subquery.schema(); + let fields = schema.fields(); + if fields.len() != 1 { + return Err(ArrowError::ComputeError(format!( + "Sub query should have only one column but got {}", + fields.len() + ))); + } + let data_type = fields.get(0).unwrap().data_type(); + let null_array = || new_null_array(data_type, 1); - if subquery.output_partitioning().partition_count() != 1 { - return Err(ArrowError::ComputeError(format!( - "Sub query should have only one partition but got {}", - subquery.output_partitioning().partition_count() - ))); - } - let mut stream = subquery.execute(0, context.clone()).await?; - let res = stream.next().await; - if let Some(subquery_batch) = res { - let subquery_batch = subquery_batch?; - match subquery_batch.column(0).len() { - 0 => subquery_arrays[subquery_i].push(null_array()?), + if subquery.output_partitioning().partition_count() != 1 { + return Err(ArrowError::ComputeError(format!( + "Sub query should have only one partition but got {}", + subquery.output_partitioning().partition_count() + ))); + } + let mut stream = subquery.execute(0, context.clone()).await?; + let res = stream.next().await; + if let Some(subquery_batch) = res { + let subquery_batch = subquery_batch?; + match subquery_type { + SubqueryType::Scalar => match subquery_batch + .column(0) + .len() + { + 0 => subquery_arrays[subquery_i].push(null_array()), 1 => subquery_arrays[subquery_i] .push(subquery_batch.column(0).clone()), _ => return Err(ArrowError::ComputeError( "Sub query should return no more than one row" .to_string(), )), - }; - } else { - subquery_arrays[subquery_i].push(null_array()?); - } + }, + SubqueryType::Exists => match subquery_batch + .column(0) + .len() + { + 0 => subquery_arrays[subquery_i] + .push(Arc::new(BooleanArray::from(vec![false]))), + _ => subquery_arrays[subquery_i] + .push(Arc::new(BooleanArray::from(vec![true]))), + }, + SubqueryType::AnyAll => { + let array_ref = subquery_batch.column(0); + // TODO: optimize? + let mut scalars = vec![]; + for i in 0..array_ref.len() { + scalars.push(ScalarValue::try_from_array( + array_ref, i, + )?); + } + let list = ScalarValue::List( + Some(Box::new(scalars)), + Box::new(data_type.clone()), + ); + subquery_arrays[subquery_i].push(list.to_array()); + } + }; + } else { + match subquery_type { + SubqueryType::Scalar => { + subquery_arrays[subquery_i].push(null_array()) + } + SubqueryType::Exists => subquery_arrays[subquery_i] + .push(Arc::new(BooleanArray::from(vec![false]))), + SubqueryType::AnyAll => { + let list = ScalarValue::List( + Some(Box::new(vec![])), + Box::new(data_type.clone()), + ); + subquery_arrays[subquery_i].push(list.to_array()); + } + }; } } - let mut new_columns = batch.columns().to_vec(); - for subquery_array in subquery_arrays { - new_columns.push(concat( - subquery_array - .iter() - .map(|a| a.as_ref()) - .collect::>() - .as_slice(), - )?); - } - RecordBatch::try_new(schema.clone(), new_columns) } - }); + let mut new_columns = batch.columns().to_vec(); + for subquery_array in subquery_arrays { + new_columns.push(concat( + subquery_array + .iter() + .map(|a| a.as_ref()) + .collect::>() + .as_slice(), + )?); + } + RecordBatch::try_new(schema.clone(), new_columns) + } + }); Ok(Box::pin(SubQueryStream { schema: self.schema.clone(), stream: Box::pin(res_stream), diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 440154151154..fa459db43048 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -19,8 +19,9 @@ use std::collections::HashSet; use std::iter; +use std::ops::RangeFrom; use std::str::FromStr; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use std::{convert::TryInto, vec}; use crate::catalog::TableReference; @@ -32,7 +33,7 @@ use crate::logical_plan::{ and, builder::expand_qualified_wildcard, builder::expand_wildcard, col, lit, normalize_col, rewrite_udtfs_to_columns, Column, CreateMemoryTable, DFSchema, DFSchemaRef, DropTable, Expr, ExprSchemable, Like, LogicalPlan, LogicalPlanBuilder, - Operator, PlanType, ToDFSchema, ToStringifiedPlan, + Operator, PlanType, SubqueryType, ToDFSchema, ToStringifiedPlan, }; use crate::optimizer::utils::exprlist_to_columns; use crate::prelude::JoinType; @@ -97,13 +98,14 @@ pub struct SqlToRel<'a, S: ContextProvider> { schema_provider: &'a S, table_columns_precedence_over_projection: bool, context: SqlToRelContext, + subquery_alias_iter: Arc>>, } /// Planning context #[derive(Default)] pub struct SqlToRelContext { outer_query_context_schema: Vec, - subqueries_plans: Option>>, + subqueries_plans: Option>>, } impl SqlToRelContext { @@ -115,12 +117,12 @@ impl SqlToRelContext { } } - fn add_subquery_plan(&self, plan: LogicalPlan) -> Result<()> { - self.subqueries_plans.as_ref().ok_or_else(|| DataFusionError::Plan(format!("Sub query {:?} planned outside of sub query context. This type of sub query isn't supported", plan)))?.write().unwrap().push(plan); + fn add_subquery_plan(&self, plan: LogicalPlan, typ: SubqueryType) -> Result<()> { + self.subqueries_plans.as_ref().ok_or_else(|| DataFusionError::Plan(format!("Sub query {:?} planned outside of sub query context. This type of sub query isn't supported", plan)))?.write().unwrap().push((plan, typ)); Ok(()) } - fn subqueries_plans(&self) -> Result>> { + fn subqueries_plans(&self) -> Result>> { Ok(if let Some(subqueries) = self.subqueries_plans.as_ref() { Some( subqueries @@ -170,6 +172,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema_provider, table_columns_precedence_over_projection, context: SqlToRelContext::default(), + subquery_alias_iter: Arc::new(Mutex::new(0..)), } } @@ -182,6 +185,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_columns_precedence_over_projection: self .table_columns_precedence_over_projection, context, + subquery_alias_iter: Arc::clone(&self.subquery_alias_iter), } } @@ -877,6 +881,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { selection: Option, plans: Vec, ) -> Result { + // TODO: enable subqueries for joins let plan = match selection { Some(predicate_expr) => { // build join schema @@ -978,6 +983,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // remove join expressions from filter match remove_join_expressions(&filter_expr, &all_join_keys)? { Some(filter_expr) => { + let left = self.wrap_with_subquery_plan_if_necessary(left)?; LogicalPlanBuilder::from(left).filter(filter_expr)?.build() } _ => Ok(left), @@ -1011,7 +1017,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_from = matches!(plans.first(), Some(LogicalPlan::EmptyRelation(_))); // process `where` clause - let plan = self.plan_selection(select.selection, plans)?; + let with_where_outer_query_context = + self.with_context(|c| c.subqueries_plans = Some(RwLock::new(Vec::new()))); + let plan = + with_where_outer_query_context.plan_selection(select.selection, plans)?; // process the SELECT expressions, with wildcards expanded. let with_outer_query_context = @@ -1200,8 +1209,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .read() .map_err(|e| DataFusionError::Plan(e.to_string()))?; if !subqueries.is_empty() { + let (subqueries, plans): (Vec<_>, Vec<_>) = + subqueries.clone().into_iter().unzip(); LogicalPlanBuilder::from(plan) - .subquery(subqueries.clone())? + .subquery(subqueries, plans)? .build()? } else { plan @@ -1439,7 +1450,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::Column(col) => match &col.relation { Some(r) => { if let Some(plans) = self.context.subqueries_plans()? { - if plans.into_iter().any(|p| { + if plans.into_iter().any(|(p, _)| { p.schema().field_with_qualified_name(r, &col.name).is_ok() }) { return Ok(()); @@ -1450,7 +1461,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => { if let Some(plans) = self.context.subqueries_plans()? { - if plans.into_iter().any(|p| { + if plans.into_iter().any(|(p, _)| { !p.schema() .fields_with_unqualified_name(&col.name) .is_empty() @@ -1561,13 +1572,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { left: SQLExpr, op: BinaryOperator, right: SQLExpr, + all: bool, schema: &DFSchema, ) -> Result { let operator = match op { BinaryOperator::Eq => Ok(Operator::Eq), BinaryOperator::NotEq => Ok(Operator::NotEq), + BinaryOperator::Lt => Ok(Operator::Lt), + BinaryOperator::LtEq => Ok(Operator::LtEq), + BinaryOperator::Gt => Ok(Operator::Gt), + BinaryOperator::GtEq => Ok(Operator::GtEq), _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported SQL ANY operator {:?}", + "Unsupported SQL ANY/ALL operator {:?}", op ))), }?; @@ -1576,6 +1592,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { left: Box::new(self.sql_expr_to_logical_expr(left, schema)?), op: operator, right: Box::new(self.sql_expr_to_logical_expr(right, schema)?), + all, }) } @@ -1588,13 +1605,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match right { SQLExpr::AnyOp(any_expr) => { - return self.parse_sql_binary_any(left, op, *any_expr, schema); + return self.parse_sql_binary_any(left, op, *any_expr, false, schema); } - SQLExpr::AllOp(_) => { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported SQL ALL operator {:?}", - right - ))); + SQLExpr::AllOp(any_expr) => { + return self.parse_sql_binary_any(left, op, *any_expr, true, schema); } _ => {} }; @@ -2277,19 +2291,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema), - SQLExpr::Subquery(q) => { - let with_outer_query_context = self.with_context(|c| c.outer_query_context_schema.push(Arc::new(schema.clone()))); - let alias_name = format!("subquery-{}", self.context.subqueries_plans().unwrap_or_default().unwrap_or_default().len()); - let plan = with_outer_query_context.query_to_plan_with_alias(*q, Some(alias_name), &mut HashMap::new())?; + SQLExpr::Subquery(q) => self.subquery_to_plan(q, SubqueryType::Scalar, schema), - let fields = plan.schema().fields(); - if fields.len() != 1 { - return Err(DataFusionError::Plan(format!("Correlated sub query requires only one column in result set but found: {:?}", fields))); - } - let column = fields.iter().next().unwrap().qualified_column(); - self.context.add_subquery_plan(plan)?; - Ok(Expr::Column(column)) - } + SQLExpr::AnyAllSubquery(q) => self.subquery_to_plan(q, SubqueryType::AnyAll, schema), + + // InSubquery uses `AnyAll` since it's expected to be replaced + SQLExpr::InSubquery { expr, subquery, negated } => Ok(Expr::InSubquery { + expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?), + subquery: Box::new(self.subquery_to_plan(subquery, SubqueryType::AnyAll, schema)?), + negated, + }), SQLExpr::DotExpr { expr, field } => { Ok(Expr::GetIndexedField { @@ -2298,11 +2309,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) } - // FIXME: Exists is unsupported but all the queries we need return false - SQLExpr::Exists(_) => { - warn!("EXISTS(...) is not supported yet. Replacing with scalar `false` value."); - Ok(Expr::Literal(ScalarValue::Boolean(Some(false)))) - } + SQLExpr::Exists(q) => self.subquery_to_plan(q, SubqueryType::Exists, schema), // FIXME: ArraySubquery is unsupported but all the queries we need return empty array SQLExpr::ArraySubquery(_) => { @@ -2714,6 +2721,46 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } } + + fn subquery_to_plan( + &self, + query: Box, + subquery_type: SubqueryType, + schema: &DFSchema, + ) -> Result { + let with_outer_query_context = self.with_context(|c| { + c.outer_query_context_schema.push(Arc::new(schema.clone())) + }); + let alias_name = { + let mut subquery_alias_iter = with_outer_query_context + .subquery_alias_iter + .lock() + .map_err(|_| { + DataFusionError::Plan( + "Unable to lock subquery alias iterator".to_string(), + ) + })?; + let alias_index = subquery_alias_iter.next().ok_or_else(|| { + DataFusionError::Plan( + "Unable to assign an alias to a subquery".to_string(), + ) + })?; + format!("__subquery-{}", alias_index) + }; + let plan = with_outer_query_context.query_to_plan_with_alias( + *query, + Some(alias_name), + &mut HashMap::new(), + )?; + + let fields = plan.schema().fields(); + if fields.len() != 1 { + return Err(DataFusionError::Plan(format!("Correlated sub query requires only one column in result set but found: {:?}", fields))); + } + let column = fields.iter().next().unwrap().qualified_column(); + self.context.add_subquery_plan(plan, subquery_type)?; + Ok(Expr::Column(column)) + } } /// Normalize a SQL object name @@ -4876,29 +4923,137 @@ mod tests { } #[test] - fn subquery() { + fn subquery_select() { let sql = "select person.id, (select lineitem.l_item_id from lineitem where person.id = lineitem.l_item_id limit 1) from person"; - let expected = "Projection: #person.id, #subquery-0.l_item_id\ - \n Subquery\ + let expected = "Projection: #person.id, #__subquery-0.l_item_id\ + \n Subquery: types=[Scalar]\ \n TableScan: person projection=None\ \n Limit: skip=None, fetch=1\ - \n Projection: #lineitem.l_item_id, alias=subquery-0\ + \n Projection: #lineitem.l_item_id, alias=__subquery-0\ \n Filter: ^#person.id = #lineitem.l_item_id\ \n TableScan: lineitem projection=None"; quick_test(sql, expected); } #[test] - fn subquery_no_from() { + fn subquery_select_without_from() { let sql = "select person.id, (select person.age + 1) from person"; - let expected = "Projection: #person.id, #subquery-0.person.age + Int64(1)\ - \n Subquery\ + let expected = "Projection: #person.id, #__subquery-0.person.age + Int64(1)\ + \n Subquery: types=[Scalar]\ \n TableScan: person projection=None\ - \n Projection: ^#person.age + Int64(1), alias=subquery-0\ + \n Projection: ^#person.age + Int64(1), alias=__subquery-0\ \n EmptyRelation"; quick_test(sql, expected); } + #[test] + fn subquery_where() { + let sql = "select person.id from person where person.id > (select lineitem.l_item_id from lineitem limit 1)"; + let expected = "Projection: #person.id\ + \n Filter: #person.id > #__subquery-0.l_item_id\ + \n Subquery: types=[Scalar]\ + \n TableScan: person projection=None\ + \n Limit: skip=None, fetch=1\ + \n Projection: #lineitem.l_item_id, alias=__subquery-0\ + \n TableScan: lineitem projection=None"; + quick_test(sql, expected); + } + + #[test] + fn subquery_where_without_from() { + let sql = "select person.id from person where person.id = (select person.id)"; + let expected = "Projection: #person.id\ + \n Filter: #person.id = #__subquery-0.person.id\ + \n Subquery: types=[Scalar]\ + \n TableScan: person projection=None\ + \n Projection: ^#person.id, alias=__subquery-0\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn subquery_select_and_where() { + let sql = "select person.id, (select person.id) from person where person.id > (select lineitem.l_item_id from lineitem limit 1)"; + let expected = "Projection: #person.id, #__subquery-1.person.id\ + \n Subquery: types=[Scalar]\ + \n Filter: #person.id > #__subquery-0.l_item_id\ + \n Subquery: types=[Scalar]\ + \n TableScan: person projection=None\ + \n Limit: skip=None, fetch=1\ + \n Projection: #lineitem.l_item_id, alias=__subquery-0\ + \n TableScan: lineitem projection=None\ + \n Projection: ^#person.id, alias=__subquery-1\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn subquery_select_and_where_without_from() { + let sql = "select person.id, (select person.id) from person where person.id = (select person.id)"; + let expected = "Projection: #person.id, #__subquery-1.person.id\ + \n Subquery: types=[Scalar]\ + \n Filter: #person.id = #__subquery-0.person.id\ + \n Subquery: types=[Scalar]\ + \n TableScan: person projection=None\ + \n Projection: ^#person.id, alias=__subquery-0\ + \n EmptyRelation\ + \n Projection: ^#person.id, alias=__subquery-1\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn subquery_exists() { + let sql = "select person.id, exists(select person.id) from person where exists(select 1 where false)"; + let expected = "Projection: #person.id, #__subquery-1.person.id\ + \n Subquery: types=[Exists]\ + \n Filter: #__subquery-0.Int64(1)\ + \n Subquery: types=[Exists]\ + \n TableScan: person projection=None\ + \n Projection: Int64(1), alias=__subquery-0\ + \n Filter: Boolean(false)\ + \n EmptyRelation\ + \n Projection: ^#person.id, alias=__subquery-1\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn subquery_any() { + let sql = "select person.id from person where person.id = any(select person.id from person)"; + let expected = "Projection: #person.id\ + \n Filter: #person.id = ANY(#__subquery-0.person.id)\ + \n Subquery: types=[AnyAll]\ + \n TableScan: person projection=None\ + \n Projection: ^#person.id, alias=__subquery-0\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn subquery_all() { + let sql = "select person.id, person.id = all(select person.id) from person"; + let expected = + "Projection: #person.id, #person.id = ALL(#__subquery-0.person.id)\ + \n Subquery: types=[AnyAll]\ + \n TableScan: person projection=None\ + \n Projection: ^#person.id, alias=__subquery-0\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn subquery_in() { + let sql = + "select person.id, person.id in (select person.id from person) from person"; + let expected = "Projection: #person.id, #person.id IN (#__subquery-0.person.id)\ + \n Subquery: types=[AnyAll]\ + \n TableScan: person projection=None\ + \n Projection: ^#person.id, alias=__subquery-0\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + #[test] fn join_on_disjunction_condition() { let sql = "SELECT id, order_id \ diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs index aad21190a66a..deb8d700d619 100644 --- a/datafusion/core/src/sql/utils.rs +++ b/datafusion/core/src/sql/utils.rs @@ -293,15 +293,30 @@ where .collect::>>()?, negated: *negated, }), + Expr::InSubquery { + expr: nested_expr, + subquery, + negated, + } => Ok(Expr::InSubquery { + expr: Box::new(clone_with_replacement(&**nested_expr, replacement_fn)?), + subquery: Box::new(clone_with_replacement(&**subquery, replacement_fn)?), + negated: *negated, + }), Expr::BinaryExpr { left, right, op } => Ok(Expr::BinaryExpr { left: Box::new(clone_with_replacement(&**left, replacement_fn)?), op: *op, right: Box::new(clone_with_replacement(&**right, replacement_fn)?), }), - Expr::AnyExpr { left, right, op } => Ok(Expr::AnyExpr { + Expr::AnyExpr { + left, + right, + op, + all, + } => Ok(Expr::AnyExpr { left: Box::new(clone_with_replacement(&**left, replacement_fn)?), op: *op, right: Box::new(clone_with_replacement(&**right, replacement_fn)?), + all: *all, }), Expr::Like(Like { negated, diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 388e95857e95..e9dd80789fc6 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -1015,7 +1015,6 @@ async fn test_extract_date_part() -> Result<()> { #[tokio::test] async fn test_binary_any() -> Result<()> { - // = // int64 test_expression!("1 = ANY([1, 2])", "true"); test_expression!("3 = ANY([1, 2])", "false"); @@ -1026,17 +1025,140 @@ async fn test_binary_any() -> Result<()> { // utf8 test_expression!("'a' = ANY(['a', 'b'])", "true"); test_expression!("'c' = ANY(['a', 'b'])", "false"); + test_expression!("'c' = ANY(['a', NULL])", "NULL"); // bool test_expression!("true = ANY([true, false])", "true"); test_expression!("false = ANY([true, false])", "true"); test_expression!("false = ANY([true, true])", "false"); + // = + test_expression!("3 = ANY([1, 2])", "false"); + test_expression!("1 = ANY([1, 2])", "true"); + test_expression!("NULL = ANY([1, 2])", "NULL"); + test_expression!("'c' = ANY(['a', 'b'])", "false"); + test_expression!("'a' = ANY(['a', 'a'])", "true"); + test_expression!("'b' = ANY([NULL, 'b'])", "true"); + test_expression!("true = ANY([false, true])", "true"); + test_expression!("false = ANY([false, true])", "true"); // <> test_expression!("3 <> ANY([1, 2])", "true"); - test_expression!("1 <> ANY([1, 2])", "false"); - test_expression!("2 <> ANY([1, 2])", "false"); - test_expression!("NULL = ANY([1, 2])", "NULL"); + test_expression!("1 <> ANY([1, 2])", "true"); + test_expression!("NULL <> ANY([1, 2])", "NULL"); test_expression!("'c' <> ANY(['a', 'b'])", "true"); - test_expression!("'a' <> ANY(['a', 'b'])", "false"); + test_expression!("'a' <> ANY(['a', 'a'])", "false"); + test_expression!("'b' <> ANY([NULL, 'b'])", "NULL"); + test_expression!("true <> ANY([false, true])", "true"); + test_expression!("false <> ANY([false, true])", "true"); + // < + test_expression!("3 < ANY([1, 2])", "false"); + test_expression!("1 < ANY([1, 2])", "true"); + test_expression!("NULL < ANY([1, 2])", "NULL"); + test_expression!("'c' < ANY(['a', 'b'])", "false"); + test_expression!("'a' < ANY(['a', 'a'])", "false"); + test_expression!("'b' < ANY([NULL, 'b'])", "NULL"); + test_expression!("true < ANY([false, true])", "false"); + test_expression!("false < ANY([false, true])", "true"); + // <= + test_expression!("3 <= ANY([1, 2])", "false"); + test_expression!("1 <= ANY([1, 2])", "true"); + test_expression!("NULL <= ANY([1, 2])", "NULL"); + test_expression!("'c' <= ANY(['a', 'b'])", "false"); + test_expression!("'a' <= ANY(['a', 'a'])", "true"); + test_expression!("'b' <= ANY([NULL, 'b'])", "true"); + test_expression!("true <= ANY([false, true])", "true"); + test_expression!("false <= ANY([false, true])", "true"); + // > + test_expression!("3 > ANY([1, 2])", "true"); + test_expression!("1 > ANY([1, 2])", "false"); + test_expression!("NULL > ANY([1, 2])", "NULL"); + test_expression!("'c' > ANY(['a', 'b'])", "true"); + test_expression!("'a' > ANY(['a', 'a'])", "false"); + test_expression!("'b' > ANY([NULL, 'b'])", "NULL"); + test_expression!("true > ANY([false, true])", "true"); + test_expression!("false > ANY([false, true])", "false"); + // >= + test_expression!("3 >= ANY([1, 2])", "true"); + test_expression!("1 >= ANY([1, 2])", "true"); + test_expression!("NULL >= ANY([1, 2])", "NULL"); + test_expression!("'c' >= ANY(['a', 'b'])", "true"); + test_expression!("'a' >= ANY(['a', 'a'])", "true"); + test_expression!("'b' >= ANY([NULL, 'b'])", "true"); + test_expression!("true >= ANY([false, true])", "true"); + test_expression!("false >= ANY([false, true])", "true"); + + Ok(()) +} + +#[tokio::test] +async fn test_binary_all() -> Result<()> { + // int64 + test_expression!("1 = ALL([1, 2])", "false"); + test_expression!("3 = ALL([3, 3])", "true"); + test_expression!("NULL = ALL([1, 2])", "NULL"); + // float + test_expression!("1.0 = ALL([1.0, 2.0])", "false"); + test_expression!("3.0 = ALL([3.0, 3.0])", "true"); + // utf8 + test_expression!("'a' = ALL(['a', 'b'])", "false"); + test_expression!("'c' = ALL(['c', 'c'])", "true"); + test_expression!("'c' = ALL(['a', NULL])", "false"); + // bool + test_expression!("true = ALL([true, false])", "false"); + test_expression!("false = ALL([true, false])", "false"); + test_expression!("true = ALL([true, true])", "true"); + // = + test_expression!("3 = ALL([1, 2])", "false"); + test_expression!("1 = ALL([1, 2])", "false"); + test_expression!("NULL = ALL([1, 2])", "NULL"); + test_expression!("'c' = ALL(['a', 'b'])", "false"); + test_expression!("'a' = ALL(['a', 'a'])", "true"); + test_expression!("'b' = ALL([NULL, 'b'])", "NULL"); + test_expression!("true = ALL([false, true])", "false"); + test_expression!("false = ALL([false, true])", "false"); + // <> + test_expression!("3 <> ALL([1, 2])", "true"); + test_expression!("1 <> ALL([1, 2])", "false"); + test_expression!("NULL <> ALL([1, 2])", "NULL"); + test_expression!("'c' <> ALL(['a', 'b'])", "true"); + test_expression!("'a' <> ALL(['a', 'a'])", "false"); + test_expression!("'b' <> ALL([NULL, 'b'])", "false"); + test_expression!("true <> ALL([false, true])", "false"); + test_expression!("false <> ALL([false, true])", "false"); + // < + test_expression!("3 < ALL([1, 2])", "false"); + test_expression!("1 < ALL([1, 2])", "false"); + test_expression!("NULL < ALL([1, 2])", "NULL"); + test_expression!("'c' < ALL(['a', 'b'])", "false"); + test_expression!("'a' < ALL(['a', 'a'])", "false"); + test_expression!("'b' < ALL([NULL, 'b'])", "false"); + test_expression!("true < ALL([false, true])", "false"); + test_expression!("false < ALL([false, true])", "false"); + // <= + test_expression!("3 <= ALL([1, 2])", "false"); + test_expression!("1 <= ALL([1, 2])", "true"); + test_expression!("NULL <= ALL([1, 2])", "NULL"); + test_expression!("'c' <= ALL(['a', 'b'])", "false"); + test_expression!("'a' <= ALL(['a', 'a'])", "true"); + test_expression!("'b' <= ALL([NULL, 'b'])", "NULL"); + test_expression!("true <= ALL([false, true])", "false"); + test_expression!("false <= ALL([false, true])", "true"); + // > + test_expression!("3 > ALL([1, 2])", "true"); + test_expression!("1 > ALL([1, 2])", "false"); + test_expression!("NULL > ALL([1, 2])", "NULL"); + test_expression!("'c' > ALL(['a', 'b'])", "true"); + test_expression!("'a' > ALL(['a', 'a'])", "false"); + test_expression!("'b' > ALL([NULL, 'b'])", "false"); + test_expression!("true > ALL([false, true])", "false"); + test_expression!("false > ALL([false, true])", "false"); + // >= + test_expression!("3 >= ALL([1, 2])", "true"); + test_expression!("1 >= ALL([1, 2])", "false"); + test_expression!("NULL >= ALL([1, 2])", "NULL"); + test_expression!("'c' >= ALL(['a', 'b'])", "true"); + test_expression!("'a' >= ALL(['a', 'a'])", "true"); + test_expression!("'b' >= ALL([NULL, 'b'])", "NULL"); + test_expression!("true >= ALL([false, true])", "true"); + test_expression!("false >= ALL([false, true])", "false"); Ok(()) } diff --git a/datafusion/core/tests/sql/subquery.rs b/datafusion/core/tests/sql/subquery.rs index fea43d7d0b41..9b5d508c41d4 100644 --- a/datafusion/core/tests/sql/subquery.rs +++ b/datafusion/core/tests/sql/subquery.rs @@ -18,7 +18,7 @@ use super::*; #[tokio::test] -async fn subquery_no_from() -> Result<()> { +async fn subquery_select_no_from() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_simple_csv(&ctx).await?; @@ -39,7 +39,7 @@ async fn subquery_no_from() -> Result<()> { } #[tokio::test] -async fn subquery_with_from() -> Result<()> { +async fn subquery_select_with_from() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_simple_csv(&ctx).await?; @@ -79,3 +79,161 @@ async fn subquery_projection_pushdown() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn subquery_where_with_from() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT DISTINCT c1 FROM aggregate_simple o WHERE (SELECT c3 FROM aggregate_simple p WHERE o.c1 = p.c1 LIMIT 1) ORDER BY c1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00001 |", + "| 0.00003 |", + "| 0.00005 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn subquery_where_no_from() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = + "SELECT DISTINCT c1 FROM aggregate_simple o WHERE (SELECT NOT c3) ORDER BY c1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00002 |", + "| 0.00004 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +// TODO: plans but does not execute +#[ignore] +#[tokio::test] +async fn subquery_select_and_where_with_from() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT c1, (SELECT c1 + 1) FROM aggregate_simple o WHERE (SELECT c3 FROM aggregate_simple p WHERE o.c1 = p.c1 LIMIT 1) ORDER BY c1 LIMIT 2"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+------------------+", + "| c1 | c1 Plus Int64(1) |", + "+---------+------------------+", + "| 0.00001 | 1.00001 |", + "| 0.00003 | 1.00003 |", + "+---------+------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +// TODO: plans but does not execute +#[ignore] +#[tokio::test] +async fn subquery_select_and_where_no_from() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT c1, (SELECT c1 + 1) FROM aggregate_simple o WHERE (SELECT NOT c3) ORDER BY c1 LIMIT 2"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+------------------+", + "| c1 | c1 Plus Int64(1) |", + "+---------+------------------+", + "| 0.00002 | 1.00002 |", + "| 0.00004 | 1.00004 |", + "+---------+------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn subquery_exists() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT DISTINCT c1 FROM aggregate_simple o WHERE EXISTS(SELECT 1 FROM aggregate_simple p WHERE o.c1 * 2 = p.c1) ORDER BY c1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00001 |", + "| 0.00002 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +// TODO: subquery ignores WHERE for some reason and returns all data +#[ignore] +#[tokio::test] +async fn subquery_any() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT DISTINCT c1 FROM aggregate_simple o WHERE c1 = ANY(SELECT c1 FROM aggregate_simple p WHERE c3) ORDER BY c1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00001 |", + "| 0.00003 |", + "| 0.00005 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +// TODO: subqueries don't work with ORDER BY +#[ignore] +#[tokio::test] +async fn subquery_all() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT DISTINCT c1 FROM aggregate_simple o WHERE c1 > ALL(SELECT DISTINCT c1 FROM aggregate_simple p ORDER BY c1 LIMIT 3) ORDER BY c1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00004 |", + "| 0.00005 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 927785ce0f30..27c2efcc2ef0 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,4 +38,4 @@ path = "src/lib.rs" ahash = { version = "0.7", default-features = false } arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "096ef28dde6b1ae43ce89ba2c3a9d98295f2972e", features = ["prettyprint"] } datafusion-common = { path = "../common", version = "7.0.0" } -sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "b3b40586d4c32a218ffdfcb0462e7e216cf3d6eb" } +sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "2229652dc8fae8f45cbec344b4a1e40cf1bb69d9" } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d007780dd9a7..5c6297da9195 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -109,6 +109,8 @@ pub enum Expr { op: Operator, /// Right-hand side of the expression right: Box, + /// Whether it's an ALL expression (as opposed to ANY/SOME) + all: bool, }, /// LIKE expression Like(Like), @@ -251,6 +253,15 @@ pub enum Expr { /// Whether the expression is negated negated: bool, }, + /// IN subquery + InSubquery { + /// The expression to compare + expr: Box, + /// subquery that will produce a single column of data to compare against + subquery: Box, + /// Whether the expression is negated + negated: bool, + }, /// Represents a reference to all fields in a schema. Wildcard, /// Represents a reference to all fields in a specific schema. @@ -516,8 +527,14 @@ impl fmt::Debug for Expr { Expr::BinaryExpr { left, op, right } => { write!(f, "{:?} {} {:?}", left, op, right) } - Expr::AnyExpr { left, op, right } => { - write!(f, "{:?} {} ANY({:?})", left, op, right) + Expr::AnyExpr { + left, + op, + right, + all, + } => { + let keyword = if *all { "ALL" } else { "ANY" }; + write!(f, "{:?} {} {}({:?})", left, op, keyword, right) } Expr::Sort { expr, @@ -649,6 +666,14 @@ impl fmt::Debug for Expr { write!(f, "{:?} IN ({:?})", expr, list) } } + Expr::InSubquery { + expr, + subquery, + negated, + } => { + let negated = if *negated { "NOT " } else { "" }; + write!(f, "{:?} {}IN ({:?})", expr, negated, subquery) + } Expr::Wildcard => write!(f, "*"), Expr::QualifiedWildcard { qualifier } => write!(f, "{}.*", qualifier), Expr::GetIndexedField { ref expr, key } => { @@ -709,10 +734,16 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { let right = create_name(right, input_schema)?; Ok(format!("{} {} {}", left, op, right)) } - Expr::AnyExpr { left, op, right } => { + Expr::AnyExpr { + left, + op, + right, + all, + } => { let left = create_name(left, input_schema)?; let right = create_name(right, input_schema)?; - Ok(format!("{} {} ANY({})", left, op, right)) + let keyword = if *all { "ALL" } else { "ANY" }; + Ok(format!("{} {} {}({})", left, op, keyword, right)) } Expr::Like(Like { negated, @@ -884,6 +915,16 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Ok(format!("{} IN ({:?})", expr, list)) } } + Expr::InSubquery { + expr, + subquery, + negated, + } => { + let expr = create_name(expr, input_schema)?; + let subquery = create_name(subquery, input_schema)?; + let negated = if *negated { "NOT " } else { "" }; + Ok(format!("{} {}IN ({})", expr, negated, subquery)) + } Expr::Between { expr, negated, diff --git a/datafusion/physical-expr/src/expressions/any.rs b/datafusion/physical-expr/src/expressions/any.rs index e19e2435f782..8e516d0e3383 100644 --- a/datafusion/physical-expr/src/expressions/any.rs +++ b/datafusion/physical-expr/src/expressions/any.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Any expression +//! Any/All expression use std::any::Any; use std::sync::Arc; @@ -38,45 +38,33 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Operator}; macro_rules! compare_op_scalar { - ($LEFT: expr, $LIST_VALUES:expr, $OP:expr, $LIST_VALUES_TYPE:ty, $LIST_FROM_SCALAR: expr) => {{ - let mut builder = BooleanBuilder::new($LEFT.len()); - - if $LIST_FROM_SCALAR { - for i in 0..$LEFT.len() { - if $LEFT.is_null(i) { - builder.append_null()?; - } else { - if $LIST_VALUES.is_null(0) { - builder.append_null()?; - } else { - builder.append_value($OP( - $LEFT.value(i), - $LIST_VALUES - .value(0) - .as_any() - .downcast_ref::<$LIST_VALUES_TYPE>() - .unwrap(), - ))?; - } - } - } + ($LEFT:expr, $LIST_VALUES:expr, $OP:expr, $LIST_VALUES_TYPE:ty, $LIST_FROM_SCALAR:expr, $VALUE_FROM_SCALAR:expr) => {{ + let len = if $VALUE_FROM_SCALAR { + $LIST_VALUES.len() } else { - for i in 0..$LEFT.len() { - if $LEFT.is_null(i) { + $LEFT.len() + }; + let mut builder = BooleanBuilder::new(len); + + for i in 0..len { + let left_i = if $VALUE_FROM_SCALAR { 0 } else { i }; + let list_i = if $LIST_FROM_SCALAR { 0 } else { i }; + + if $LIST_VALUES.is_null(list_i) { + builder.append_value(false)?; + } else { + let list_values = $LIST_VALUES.value(list_i); + let list_values = list_values + .as_any() + .downcast_ref::<$LIST_VALUES_TYPE>() + .unwrap(); + if list_values.is_empty() { + builder.append_value(false)?; + } else if $LEFT.is_null(left_i) { builder.append_null()?; } else { - if $LIST_VALUES.is_null(i) { - builder.append_null()?; - } else { - builder.append_value($OP( - $LEFT.value(i), - $LIST_VALUES - .value(i) - .as_any() - .downcast_ref::<$LIST_VALUES_TYPE>() - .unwrap(), - ))?; - } + let result = $OP($LEFT.value(left_i), list_values); + builder.append_option(result)?; } } } @@ -86,142 +74,207 @@ macro_rules! compare_op_scalar { } macro_rules! make_primitive { - ($VALUES:expr, $IN_VALUES:expr, $NEGATED:expr, $TYPE:ident, $LIST_FROM_SCALAR: expr) => {{ + ($VALUES:expr, $IN_VALUES:expr, $OP:expr, $TYPE:ident, $LIST_FROM_SCALAR:expr, $VALUE_FROM_SCALAR:expr, $ALL:expr) => {{ let left = $VALUES.as_any().downcast_ref::<$TYPE>().expect(&format!( "Unable to downcast values to {}", stringify!($TYPE) )); - if $NEGATED { - Ok(ColumnarValue::Array(Arc::new(neq_primitive( - left, - $IN_VALUES, - $LIST_FROM_SCALAR, - )?))) - } else { - Ok(ColumnarValue::Array(Arc::new(eq_primitive( - left, - $IN_VALUES, - $LIST_FROM_SCALAR, - )?))) - } + Ok(ColumnarValue::Array(Arc::new(compare_primitive( + left, + $IN_VALUES, + $LIST_FROM_SCALAR, + $VALUE_FROM_SCALAR, + $OP, + $ALL, + )?))) }}; } -fn eq_primitive( +fn compare_primitive( array: &PrimitiveArray, list: &ListArray, list_from_scalar: bool, + value_from_scalar: bool, + op: Operator, + all: bool, ) -> Result { + macro_rules! comparator_primitive { + ($($OP:pat = ($FN:tt),)*) => { + match op { + $( + $OP => if all { + |x, v: &PrimitiveArray| wrap_option_primitive(v, true, v.values().iter().all(|y| &x $FN y)) + } else { + |x, v: &PrimitiveArray| wrap_option_primitive(v, false, v.values().iter().any(|y| &x $FN y)) + }, + )* + op => return unsupported_op(op), + } + }; + } + let fun = comparator_primitive!( + Operator::Eq = (==), + Operator::NotEq = (!=), + Operator::Lt = (<), + Operator::LtEq = (<=), + Operator::Gt = (>), + Operator::GtEq = (>=), + ); compare_op_scalar!( array, list, - |x, v: &PrimitiveArray| v.values().contains(&x), + fun, PrimitiveArray, - list_from_scalar + list_from_scalar, + value_from_scalar ) } -fn neq_primitive( - array: &PrimitiveArray, - list: &ListArray, - list_from_scalar: bool, -) -> Result { - compare_op_scalar!( - array, - list, - |x, v: &PrimitiveArray| !v.values().contains(&x), - PrimitiveArray, - list_from_scalar - ) +fn wrap_option_primitive( + v: &PrimitiveArray, + all: bool, + result: bool, +) -> Option { + if result != all { + return Some(!all); + } + if v.null_count() > 0 { + return None; + } + Some(all) } -fn eq_bool( +fn compare_bool( array: &BooleanArray, list: &ListArray, list_from_scalar: bool, + value_from_scalar: bool, + op: Operator, + all: bool, ) -> Result { + macro_rules! comparator_bool { + ($($OP:pat = ($FN:tt inverted $IFN:tt),)*) => { + match op { + $( + $OP => if all { + |x, v: &BooleanArray| { + for i in 0..v.len() { + if !v.is_null(i) && x $IFN v.value(i) { + return Some(false) + } + } + wrap_option_bool(v, true) + } + } else { + |x, v: &BooleanArray| { + for i in 0..v.len() { + if !v.is_null(i) && x $FN v.value(i) { + return Some(true) + } + } + wrap_option_bool(v, false) + } + } + )* + op => return unsupported_op(op), + } + }; + } + let fun = comparator_bool!( + Operator::Eq = (== inverted !=), + Operator::NotEq = (!= inverted ==), + Operator::Lt = (< inverted >=), + Operator::LtEq = (<= inverted >), + Operator::Gt = (> inverted <=), + Operator::GtEq = (>= inverted <), + ); compare_op_scalar!( array, list, - |x, v: &BooleanArray| unsafe { - for i in 0..v.len() { - if v.value_unchecked(i) == x { - return true; - } - } - - false - }, + fun, BooleanArray, - list_from_scalar + list_from_scalar, + value_from_scalar ) } -fn neq_bool( - array: &BooleanArray, - list: &ListArray, - list_from_scalar: bool, -) -> Result { - compare_op_scalar!( - array, - list, - |x, v: &BooleanArray| unsafe { - for i in 0..v.len() { - if v.value_unchecked(i) == x { - return false; - } - } - - true - }, - BooleanArray, - list_from_scalar - ) +fn wrap_option_bool(v: &BooleanArray, all: bool) -> Option { + if v.null_count() > 0 { + return None; + } + Some(all) } -fn eq_utf8( +fn compare_utf8( array: &GenericStringArray, list: &ListArray, list_from_scalar: bool, + value_from_scalar: bool, + op: Operator, + all: bool, ) -> Result { + macro_rules! comparator_utf8 { + ($($OP:pat = ($FN:tt inverted $IFN:tt),)*) => { + match op { + $( + $OP => if all { + |x, v: &GenericStringArray| { + for i in 0..v.len() { + if !v.is_null(i) && x $IFN v.value(i) { + return Some(false) + } + } + wrap_option_utf8(v, true) + } + } else { + |x, v: &GenericStringArray| { + for i in 0..v.len() { + if !v.is_null(i) && x $FN v.value(i) { + return Some(true) + } + } + wrap_option_utf8(v, false) + } + } + )* + op => return unsupported_op(op), + } + }; + } + let fun = comparator_utf8!( + Operator::Eq = (== inverted !=), + Operator::NotEq = (!= inverted ==), + Operator::Lt = (< inverted >=), + Operator::LtEq = (<= inverted >), + Operator::Gt = (> inverted <=), + Operator::GtEq = (>= inverted <), + ); compare_op_scalar!( array, list, - |x, v: &GenericStringArray| unsafe { - for i in 0..v.len() { - if v.value_unchecked(i) == x { - return true; - } - } - - false - }, + fun, GenericStringArray, - list_from_scalar + list_from_scalar, + value_from_scalar ) } -fn neq_utf8( - array: &GenericStringArray, - list: &ListArray, - list_from_scalar: bool, -) -> Result { - compare_op_scalar!( - array, - list, - |x, v: &GenericStringArray| unsafe { - for i in 0..v.len() { - if v.value_unchecked(i) == x { - return false; - } - } +fn wrap_option_utf8( + v: &GenericStringArray, + all: bool, +) -> Option { + if v.null_count() > 0 { + return None; + } + Some(all) +} - true - }, - GenericStringArray, - list_from_scalar - ) +fn unsupported_op(op: Operator) -> Result { + Err(DataFusionError::Execution(format!( + "ANY/ALL does not support operator {}", + op + ))) } /// AnyExpr @@ -230,6 +283,7 @@ pub struct AnyExpr { value: Arc, op: Operator, list: Arc, + all: bool, } impl AnyExpr { @@ -238,8 +292,14 @@ impl AnyExpr { value: Arc, op: Operator, list: Arc, + all: bool, ) -> Self { - Self { value, op, list } + Self { + value, + op, + list, + all, + } } /// Compare for specific utf8 types @@ -247,27 +307,24 @@ impl AnyExpr { &self, array: ArrayRef, list: &ListArray, - negated: bool, list_from_scalar: bool, + value_from_scalar: bool, + op: Operator, + all: bool, ) -> Result { let array = array .as_any() .downcast_ref::>() .unwrap(); - if negated { - Ok(ColumnarValue::Array(Arc::new(neq_utf8( - array, - list, - list_from_scalar, - )?))) - } else { - Ok(ColumnarValue::Array(Arc::new(eq_utf8( - array, - list, - list_from_scalar, - )?))) - } + Ok(ColumnarValue::Array(Arc::new(compare_utf8( + array, + list, + list_from_scalar, + value_from_scalar, + op, + all, + )?))) } /// Compare for specific utf8 types @@ -275,30 +332,28 @@ impl AnyExpr { &self, array: ArrayRef, list: &ListArray, - negated: bool, list_from_scalar: bool, + value_from_scalar: bool, + op: Operator, + all: bool, ) -> Result { let array = array.as_any().downcast_ref::().unwrap(); - if negated { - Ok(ColumnarValue::Array(Arc::new(neq_bool( - array, - list, - list_from_scalar, - )?))) - } else { - Ok(ColumnarValue::Array(Arc::new(eq_bool( - array, - list, - list_from_scalar, - )?))) - } + Ok(ColumnarValue::Array(Arc::new(compare_bool( + array, + list, + list_from_scalar, + value_from_scalar, + op, + all, + )?))) } } impl std::fmt::Display for AnyExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{} {} ANY({})", self.value, self.op, self.list) + let keyword = if self.all { "ALL" } else { "ANY" }; + write!(f, "{} {} {}({})", self.value, self.op, keyword, self.list) } } @@ -312,14 +367,14 @@ impl PhysicalExpr for AnyExpr { Ok(DataType::Boolean) } - fn nullable(&self, input_schema: &Schema) -> Result { - self.value.nullable(input_schema) + fn nullable(&self, _: &Schema) -> Result { + Ok(true) } fn evaluate(&self, batch: &RecordBatch) -> Result { - let value = match self.value.evaluate(batch)? { - ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array(), + let (value, value_from_scalar) = match self.value.evaluate(batch)? { + ColumnarValue::Array(array) => (array, false), + ColumnarValue::Scalar(scalar) => (scalar.to_array(), true), }; let (list, list_from_scalar) = match self.list.evaluate(batch)? { @@ -331,60 +386,152 @@ impl PhysicalExpr for AnyExpr { .downcast_ref::() .expect("Unable to downcast list to ListArray"); - let negated = match self.op { - Operator::Eq => false, - Operator::NotEq => true, - op => { - return Err(DataFusionError::NotImplemented(format!( - "Operator for ANY expression, actual: {:?}", - op - ))); - } - }; - match value.data_type() { DataType::Float16 => { - make_primitive!(value, as_list, negated, Float16Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + Float16Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::Float32 => { - make_primitive!(value, as_list, negated, Float32Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + Float32Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::Float64 => { - make_primitive!(value, as_list, negated, Float64Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + Float64Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::Int8 => { - make_primitive!(value, as_list, negated, Int8Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + Int8Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::Int16 => { - make_primitive!(value, as_list, negated, Int16Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + Int16Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::Int32 => { - make_primitive!(value, as_list, negated, Int32Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + Int32Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::Int64 => { - make_primitive!(value, as_list, negated, Int64Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + Int64Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::UInt8 => { - make_primitive!(value, as_list, negated, UInt8Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + UInt8Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::UInt16 => { - make_primitive!(value, as_list, negated, UInt16Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + UInt16Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::UInt32 => { - make_primitive!(value, as_list, negated, UInt32Array, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + UInt32Array, + list_from_scalar, + value_from_scalar, + self.all + ) } DataType::UInt64 => { - make_primitive!(value, as_list, negated, UInt64Array, list_from_scalar) - } - DataType::Boolean => { - self.compare_bool(value, as_list, negated, list_from_scalar) - } - DataType::Utf8 => { - self.compare_utf8::(value, as_list, negated, list_from_scalar) - } - DataType::LargeUtf8 => { - self.compare_utf8::(value, as_list, negated, list_from_scalar) + make_primitive!( + value, + as_list, + self.op, + UInt64Array, + list_from_scalar, + value_from_scalar, + self.all + ) } + DataType::Boolean => self.compare_bool( + value, + as_list, + list_from_scalar, + value_from_scalar, + self.op, + self.all, + ), + DataType::Utf8 => self.compare_utf8::( + value, + as_list, + list_from_scalar, + value_from_scalar, + self.op, + self.all, + ), + DataType::LargeUtf8 => self.compare_utf8::( + value, + as_list, + list_from_scalar, + value_from_scalar, + self.op, + self.all, + ), datatype => Result::Err(DataFusionError::NotImplemented(format!( "AnyExpr does not support datatype {:?}.", datatype @@ -404,7 +551,11 @@ fn any_cast( let tmp = list.data_type(input_schema)?; let list_type = match &tmp { DataType::List(f) => f.data_type(), - _ => panic!("wtf"), + _ => { + return Err(DataFusionError::NotImplemented( + "ANY/ALL supports only literal arrays or subqueries".to_string(), + )) + } }; Ok((try_cast(value, input_schema, list_type.clone())?, list)) @@ -415,10 +566,11 @@ pub fn any( value: Arc, op: Operator, list: Arc, + all: bool, input_schema: &Schema, ) -> Result> { let (l, r) = any_cast(value, &op, list, input_schema)?; - Ok(Arc::new(AnyExpr::new(l, op, r))) + Ok(Arc::new(AnyExpr::new(l, op, r, all))) } #[cfg(test)] @@ -431,8 +583,8 @@ mod tests { // applies the any expr to an input batch macro_rules! execute_any { - ($BATCH:expr, $OP:expr, $EXPECTED:expr, $COL_A:expr, $COL_B:expr, $SCHEMA:expr) => {{ - let expr = any($COL_A, $OP, $COL_B, $SCHEMA).unwrap(); + ($BATCH:expr, $OP:expr, $EXPECTED:expr, $COL_A:expr, $COL_B:expr, $ALL:expr, $SCHEMA:expr) => {{ + let expr = any($COL_A, $OP, $COL_B, $ALL, $SCHEMA).unwrap(); let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); let result = result .as_any() @@ -480,6 +632,7 @@ mod tests { vec![Some(true), Some(false), None], col_a.clone(), col_b.clone(), + false, &schema ); @@ -488,8 +641,8 @@ mod tests { // applies the any expr to an input batch and list macro_rules! execute_any_with_list { - ($BATCH:expr, $LIST:expr, $OP:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ - let expr = any($COL, $OP, $LIST, $SCHEMA).unwrap(); + ($BATCH:expr, $LIST:expr, $OP:expr, $EXPECTED:expr, $COL:expr, $ALL:expr, $SCHEMA:expr) => {{ + let expr = any($COL, $OP, $LIST, $ALL, $SCHEMA).unwrap(); let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); let result = result .as_any() @@ -532,6 +685,7 @@ mod tests { Operator::Eq, vec![Some(true), Some(false), None], col_a.clone(), + false, schema ); @@ -570,6 +724,7 @@ mod tests { Operator::Eq, vec![Some(true), Some(false), None], col_a.clone(), + false, schema ); @@ -604,6 +759,7 @@ mod tests { Operator::Eq, vec![Some(true), Some(false), None], col_a.clone(), + false, schema );