@@ -23,11 +23,19 @@ use std::collections::HashMap;
23
23
use std:: sync:: Arc ;
24
24
25
25
use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef , TimeUnit } ;
26
+ use arrow_schema:: { Fields , SchemaBuilder } ;
26
27
use datafusion_common:: config:: ConfigOptions ;
27
- use datafusion_common:: { plan_err, Result } ;
28
- use datafusion_expr:: { AggregateUDF , LogicalPlan , ScalarUDF , TableSource , WindowUDF } ;
28
+ use datafusion_common:: tree_node:: { TransformedResult , TreeNode } ;
29
+ use datafusion_common:: { plan_err, DFSchema , Result , ScalarValue } ;
30
+ use datafusion_expr:: interval_arithmetic:: { Interval , NullableInterval } ;
31
+ use datafusion_expr:: {
32
+ col, lit, AggregateUDF , BinaryExpr , Expr , ExprSchemable , LogicalPlan , Operator ,
33
+ ScalarUDF , TableSource , WindowUDF ,
34
+ } ;
35
+ use datafusion_functions:: core:: expr_ext:: FieldAccessor ;
29
36
use datafusion_optimizer:: analyzer:: Analyzer ;
30
37
use datafusion_optimizer:: optimizer:: Optimizer ;
38
+ use datafusion_optimizer:: simplify_expressions:: GuaranteeRewriter ;
31
39
use datafusion_optimizer:: { OptimizerConfig , OptimizerContext } ;
32
40
use datafusion_sql:: planner:: { ContextProvider , SqlToRel } ;
33
41
use datafusion_sql:: sqlparser:: ast:: Statement ;
@@ -233,3 +241,120 @@ impl TableSource for MyTableSource {
233
241
self . schema . clone ( )
234
242
}
235
243
}
244
+
245
/// Checks the nullability reported for a field access into a struct column:
/// the `child` field is declared non-nullable, but its enclosing `parent`
/// struct column is nullable, and the test requires `parent.child` to be
/// reported as nullable.
#[test]
fn test_nested_schema_nullability() {
    // A non-nullable leaf wrapped inside a nullable struct column.
    let child_field = Field::new("child", DataType::Int64, false);
    let parent_field = Field::new(
        "parent",
        DataType::Struct(Fields::from(vec![child_field])),
        true,
    );
    let foo_field = Field::new("foo", DataType::Int32, true);

    let mut schema_builder = SchemaBuilder::new();
    schema_builder.push(foo_field);
    schema_builder.push(parent_field);
    let arrow_schema = Arc::new(schema_builder.finish());

    // One qualifier entry is supplied per field, in the order the fields were pushed.
    let df_schema = DFSchema::from_field_specific_qualified_schema(
        vec![Some("table_name".into()), None],
        &arrow_schema,
    )
    .unwrap();

    let child_access = col("parent").field("child");
    let is_nullable = child_access.nullable(&df_schema).unwrap();
    assert!(is_nullable);
}
269
+
270
/// Exercises `GuaranteeRewriter` with two guarantees, one on the plain column
/// `x` and one on the field access `s.y`, both stating that the value is not
/// null and lies in `[1, 3]`.
///
/// The first group of expressions must each be rewritten to a boolean
/// literal; the second group must come back from the rewriter unchanged.
#[test]
fn test_inequalities_non_null_bounded() {
    // Both guarantees use the same interval: not null, within [1, 3].
    let not_null_one_to_three = || NullableInterval::NotNull {
        values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(),
    };
    let guarantees = vec![
        // x ∈ [1, 3] (not null)
        (col("x"), not_null_one_to_three()),
        // s.y ∈ [1, 3] (not null)
        (col("s").field("y"), not_null_one_to_three()),
    ];

    let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

    // Builds `x IS DISTINCT FROM <rhs>`.
    let x_is_distinct_from = |rhs: Expr| {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(col("x")),
            op: Operator::IsDistinctFrom,
            right: Box::new(rhs),
        })
    };

    // Each entry pairs an expression with the boolean literal it must simplify to.
    let simplified_cases = [
        (col("x").lt(lit(0)), false),
        (col("s").field("y").lt(lit(0)), false),
        (col("x").lt_eq(lit(3)), true),
        (col("x").gt(lit(3)), false),
        (col("x").gt(lit(0)), true),
        (col("x").eq(lit(0)), false),
        (col("x").not_eq(lit(0)), true),
        (col("x").between(lit(0), lit(5)), true),
        (col("x").between(lit(5), lit(10)), false),
        (col("x").not_between(lit(0), lit(5)), false),
        (col("x").not_between(lit(5), lit(10)), true),
        (x_is_distinct_from(lit(ScalarValue::Null)), true),
        (x_is_distinct_from(lit(5)), true),
    ];
    validate_simplified_cases(&mut rewriter, &simplified_cases);

    // These expressions are expected to pass through the rewriter untouched.
    let unchanged_cases = [
        col("x").gt(lit(2)),
        col("x").lt_eq(lit(2)),
        col("x").eq(lit(2)),
        col("x").not_eq(lit(2)),
        col("x").between(lit(3), lit(5)),
        col("x").not_between(lit(3), lit(10)),
    ];
    validate_unchanged_cases(&mut rewriter, &unchanged_cases);
}
335
+
336
+ fn validate_simplified_cases < T > ( rewriter : & mut GuaranteeRewriter , cases : & [ ( Expr , T ) ] )
337
+ where
338
+ ScalarValue : From < T > ,
339
+ T : Clone ,
340
+ {
341
+ for ( expr, expected_value) in cases {
342
+ let output = expr. clone ( ) . rewrite ( rewriter) . data ( ) . unwrap ( ) ;
343
+ let expected = lit ( ScalarValue :: from ( expected_value. clone ( ) ) ) ;
344
+ assert_eq ! (
345
+ output, expected,
346
+ "{} simplified to {}, but expected {}" ,
347
+ expr, output, expected
348
+ ) ;
349
+ }
350
+ }
351
+ fn validate_unchanged_cases ( rewriter : & mut GuaranteeRewriter , cases : & [ Expr ] ) {
352
+ for expr in cases {
353
+ let output = expr. clone ( ) . rewrite ( rewriter) . data ( ) . unwrap ( ) ;
354
+ assert_eq ! (
355
+ & output, expr,
356
+ "{} was simplified to {}, but expected it to be unchanged" ,
357
+ expr, output
358
+ ) ;
359
+ }
360
+ }
0 commit comments