@@ -576,7 +576,6 @@ pub fn combine_bernoullis(
576576 left : & BernoulliDistribution ,
577577 right : & BernoulliDistribution ,
578578) -> Result < BernoulliDistribution > {
579- // TODO: Write tests for this function.
580579 let left_p = left. p_value ( ) ;
581580 let right_p = right. p_value ( ) ;
582581 match op {
@@ -632,7 +631,6 @@ pub fn combine_gaussians(
632631 left : & GaussianDistribution ,
633632 right : & GaussianDistribution ,
634633) -> Result < Option < GaussianDistribution > > {
635- // TODO: Write tests for this function.
636634 match op {
637635 Operator :: Plus => GaussianDistribution :: try_new (
638636 left. mean ( ) . add_checked ( right. mean ( ) ) ?,
@@ -855,14 +853,15 @@ pub fn compute_variance(
855853#[ cfg( test) ]
856854mod tests {
857855 use super :: {
858- compute_mean, compute_median, compute_variance, create_bernoulli_from_comparison,
859- new_unknown_from_binary_op, StatisticsV2 , UniformDistribution ,
856+ combine_bernoullis, combine_gaussians, compute_mean, compute_median,
857+ compute_variance, create_bernoulli_from_comparison, new_unknown_from_binary_op,
858+ BernoulliDistribution , GaussianDistribution , StatisticsV2 , UniformDistribution ,
860859 } ;
861860 use crate :: interval_arithmetic:: { apply_operator, Interval } ;
862861 use crate :: operator:: Operator ;
863862
864863 use arrow:: datatypes:: DataType ;
865- use datafusion_common:: { Result , ScalarValue } ;
864+ use datafusion_common:: { HashSet , Result , ScalarValue } ;
866865
867866 // The test data in the following tests are placed as follows: (stat -> expected answer)
868867 #[ test]
@@ -1306,6 +1305,143 @@ mod tests {
13061305 Ok ( ( ) )
13071306 }
13081307
1308+ #[ test]
1309+ fn test_combine_bernoullis_and_op ( ) -> Result < ( ) > {
1310+ let op = Operator :: And ;
1311+ let left = BernoulliDistribution :: try_new ( ScalarValue :: from ( 0.5 ) ) ?;
1312+ let right = BernoulliDistribution :: try_new ( ScalarValue :: from ( 0.4 ) ) ?;
1313+ let left_null = BernoulliDistribution :: try_new ( ScalarValue :: Null ) ?;
1314+ let right_null = BernoulliDistribution :: try_new ( ScalarValue :: Null ) ?;
1315+
1316+ assert_eq ! (
1317+ combine_bernoullis( & op, & left, & right) ?. p_value( ) ,
1318+ & ScalarValue :: from( 0.5 * 0.4 )
1319+ ) ;
1320+ assert_eq ! (
1321+ combine_bernoullis( & op, & left_null, & right) ?. p_value( ) ,
1322+ & ScalarValue :: Float64 ( None )
1323+ ) ;
1324+ assert_eq ! (
1325+ combine_bernoullis( & op, & left, & right_null) ?. p_value( ) ,
1326+ & ScalarValue :: Float64 ( None )
1327+ ) ;
1328+ assert_eq ! (
1329+ combine_bernoullis( & op, & left_null, & left_null) ?. p_value( ) ,
1330+ & ScalarValue :: Null
1331+ ) ;
1332+
1333+ Ok ( ( ) )
1334+ }
1335+
1336+ #[ test]
1337+ fn test_combine_bernoullis_or_op ( ) -> Result < ( ) > {
1338+ let op = Operator :: Or ;
1339+ let left = BernoulliDistribution :: try_new ( ScalarValue :: from ( 0.6 ) ) ?;
1340+ let right = BernoulliDistribution :: try_new ( ScalarValue :: from ( 0.4 ) ) ?;
1341+ let left_null = BernoulliDistribution :: try_new ( ScalarValue :: Null ) ?;
1342+ let right_null = BernoulliDistribution :: try_new ( ScalarValue :: Null ) ?;
1343+
1344+ assert_eq ! (
1345+ combine_bernoullis( & op, & left, & right) ?. p_value( ) ,
1346+ & ScalarValue :: from( 0.6 + 0.4 - ( 0.6 * 0.4 ) )
1347+ ) ;
1348+ assert_eq ! (
1349+ combine_bernoullis( & op, & left_null, & right) ?. p_value( ) ,
1350+ & ScalarValue :: Float64 ( None )
1351+ ) ;
1352+ assert_eq ! (
1353+ combine_bernoullis( & op, & left, & right_null) ?. p_value( ) ,
1354+ & ScalarValue :: Float64 ( None )
1355+ ) ;
1356+ assert_eq ! (
1357+ combine_bernoullis( & op, & left_null, & left_null) ?. p_value( ) ,
1358+ & ScalarValue :: Null
1359+ ) ;
1360+
1361+ Ok ( ( ) )
1362+ }
1363+
1364+ #[ test]
1365+ fn test_combine_bernoullis_unsupported_ops ( ) -> Result < ( ) > {
1366+ let mut operator_set = operator_set ( ) ;
1367+ operator_set. remove ( & Operator :: And ) ;
1368+ operator_set. remove ( & Operator :: Or ) ;
1369+
1370+ let left = BernoulliDistribution :: try_new ( ScalarValue :: from ( 0.6 ) ) ?;
1371+ let right = BernoulliDistribution :: try_new ( ScalarValue :: from ( 0.4 ) ) ?;
1372+ for op in operator_set {
1373+ assert ! (
1374+ combine_bernoullis( & op, & left, & right) . is_err( ) ,
1375+ "Operator {op} should not be supported for Bernoulli statistics"
1376+ ) ;
1377+ }
1378+
1379+ Ok ( ( ) )
1380+ }
1381+
1382+ #[ test]
1383+ fn test_combine_gaussians_addition ( ) -> Result < ( ) > {
1384+ let op = Operator :: Plus ;
1385+ let left = GaussianDistribution :: try_new (
1386+ ScalarValue :: from ( 3.0 ) ,
1387+ ScalarValue :: from ( 2.0 ) ,
1388+ ) ?;
1389+ let right = GaussianDistribution :: try_new (
1390+ ScalarValue :: from ( 4.0 ) ,
1391+ ScalarValue :: from ( 1.0 ) ,
1392+ ) ?;
1393+
1394+ let result = combine_gaussians ( & op, & left, & right) ?. unwrap ( ) ;
1395+
1396+ assert_eq ! ( result. mean( ) , & ScalarValue :: from( 7.0 ) ) ; // 3.0 + 4.0
1397+ assert_eq ! ( result. variance( ) , & ScalarValue :: from( 3.0 ) ) ; // 2.0 + 1.0
1398+ Ok ( ( ) )
1399+ }
1400+
1401+ #[ test]
1402+ fn test_combine_gaussians_subtraction ( ) -> Result < ( ) > {
1403+ let op = Operator :: Minus ;
1404+ let left = GaussianDistribution :: try_new (
1405+ ScalarValue :: from ( 7.0 ) ,
1406+ ScalarValue :: from ( 2.0 ) ,
1407+ ) ?;
1408+ let right = GaussianDistribution :: try_new (
1409+ ScalarValue :: from ( 4.0 ) ,
1410+ ScalarValue :: from ( 1.0 ) ,
1411+ ) ?;
1412+
1413+ let result = combine_gaussians ( & op, & left, & right) ?. unwrap ( ) ;
1414+
1415+ assert_eq ! ( result. mean( ) , & ScalarValue :: from( 3.0 ) ) ; // 7.0 - 4.0
1416+ assert_eq ! ( result. variance( ) , & ScalarValue :: from( 3.0 ) ) ; // 2.0 + 1.0
1417+
1418+ Ok ( ( ) )
1419+ }
1420+
1421+ #[ test]
1422+ fn test_combine_gaussians_unsupported_ops ( ) -> Result < ( ) > {
1423+ let mut operator_set = operator_set ( ) ;
1424+ operator_set. remove ( & Operator :: Plus ) ;
1425+ operator_set. remove ( & Operator :: Minus ) ;
1426+
1427+ let left = GaussianDistribution :: try_new (
1428+ ScalarValue :: from ( 7.0 ) ,
1429+ ScalarValue :: from ( 2.0 ) ,
1430+ ) ?;
1431+ let right = GaussianDistribution :: try_new (
1432+ ScalarValue :: from ( 4.0 ) ,
1433+ ScalarValue :: from ( 1.0 ) ,
1434+ ) ?;
1435+ for op in operator_set {
1436+ assert ! (
1437+ combine_gaussians( & op, & left, & right) ?. is_none( ) ,
1438+ "Operator {op} should not be supported for Gaussian statistics"
1439+ ) ;
1440+ }
1441+
1442+ Ok ( ( ) )
1443+ }
1444+
13091445 // Expected test results were calculated in Wolfram Mathematica, by using:
13101446 //
13111447 // *METHOD_NAME*[TransformedDistribution[
@@ -1431,4 +1567,44 @@ mod tests {
14311567
14321568 Ok ( ( ) )
14331569 }
1570+
1571+ fn operator_set ( ) -> HashSet < Operator > {
1572+ use super :: Operator :: * ;
1573+
1574+ let all_ops = vec ! [
1575+ And ,
1576+ Or ,
1577+ Eq ,
1578+ NotEq ,
1579+ Gt ,
1580+ GtEq ,
1581+ Lt ,
1582+ LtEq ,
1583+ Plus ,
1584+ Minus ,
1585+ Multiply ,
1586+ Divide ,
1587+ Modulo ,
1588+ IsDistinctFrom ,
1589+ IsNotDistinctFrom ,
1590+ RegexMatch ,
1591+ RegexIMatch ,
1592+ RegexNotMatch ,
1593+ RegexNotIMatch ,
1594+ LikeMatch ,
1595+ ILikeMatch ,
1596+ NotLikeMatch ,
1597+ NotILikeMatch ,
1598+ BitwiseAnd ,
1599+ BitwiseOr ,
1600+ BitwiseXor ,
1601+ BitwiseShiftRight ,
1602+ BitwiseShiftLeft ,
1603+ StringConcat ,
1604+ AtArrow ,
1605+ ArrowAt ,
1606+ ] ;
1607+
1608+ all_ops. into_iter ( ) . collect ( )
1609+ }
14341610}
0 commit comments