Skip to content

Commit 3bc497c

Browse files
committed
Add tests for bernoulli and gaussians combination
1 parent d5741cd commit 3bc497c

File tree

1 file changed

+181
-5
lines changed

1 file changed

+181
-5
lines changed

datafusion/expr-common/src/statistics.rs

Lines changed: 181 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,6 @@ pub fn combine_bernoullis(
576576
left: &BernoulliDistribution,
577577
right: &BernoulliDistribution,
578578
) -> Result<BernoulliDistribution> {
579-
// TODO: Write tests for this function.
580579
let left_p = left.p_value();
581580
let right_p = right.p_value();
582581
match op {
@@ -632,7 +631,6 @@ pub fn combine_gaussians(
632631
left: &GaussianDistribution,
633632
right: &GaussianDistribution,
634633
) -> Result<Option<GaussianDistribution>> {
635-
// TODO: Write tests for this function.
636634
match op {
637635
Operator::Plus => GaussianDistribution::try_new(
638636
left.mean().add_checked(right.mean())?,
@@ -855,14 +853,15 @@ pub fn compute_variance(
855853
#[cfg(test)]
856854
mod tests {
857855
use super::{
858-
compute_mean, compute_median, compute_variance, create_bernoulli_from_comparison,
859-
new_unknown_from_binary_op, StatisticsV2, UniformDistribution,
856+
combine_bernoullis, combine_gaussians, compute_mean, compute_median,
857+
compute_variance, create_bernoulli_from_comparison, new_unknown_from_binary_op,
858+
BernoulliDistribution, GaussianDistribution, StatisticsV2, UniformDistribution,
860859
};
861860
use crate::interval_arithmetic::{apply_operator, Interval};
862861
use crate::operator::Operator;
863862

864863
use arrow::datatypes::DataType;
865-
use datafusion_common::{Result, ScalarValue};
864+
use datafusion_common::{HashSet, Result, ScalarValue};
866865

867866
// The test data in the following tests are placed as follows: (stat -> expected answer)
868867
#[test]
@@ -1306,6 +1305,143 @@ mod tests {
13061305
Ok(())
13071306
}
13081307

1308+
/// Exercises `combine_bernoullis` with `Operator::And`.
///
/// Expectations pinned by this test:
/// - two known probabilities (0.5 and 0.4) combine to their product,
/// - exactly one null operand yields a `Float64(None)` probability,
/// - two null operands yield a `ScalarValue::Null` probability.
#[test]
fn test_combine_bernoullis_and_op() -> Result<()> {
    let op = Operator::And;
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
    let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
    let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

    // Both probabilities known.
    assert_eq!(
        combine_bernoullis(&op, &left, &right)?.p_value(),
        &ScalarValue::from(0.5 * 0.4)
    );
    // Only the left probability is null.
    assert_eq!(
        combine_bernoullis(&op, &left_null, &right)?.p_value(),
        &ScalarValue::Float64(None)
    );
    // Only the right probability is null.
    assert_eq!(
        combine_bernoullis(&op, &left, &right_null)?.p_value(),
        &ScalarValue::Float64(None)
    );
    // Both probabilities are null: pair the left null with the right null
    // (previously `left_null` was passed on both sides).
    assert_eq!(
        combine_bernoullis(&op, &left_null, &right_null)?.p_value(),
        &ScalarValue::Null
    );

    Ok(())
}
1335+
1336+
/// Exercises `combine_bernoullis` with `Operator::Or`.
///
/// Expectations pinned by this test:
/// - two known probabilities (0.6 and 0.4) combine to `p + q - p * q`,
/// - exactly one null operand yields a `Float64(None)` probability,
/// - two null operands yield a `ScalarValue::Null` probability.
#[test]
fn test_combine_bernoullis_or_op() -> Result<()> {
    let op = Operator::Or;
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
    let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
    let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

    // Both probabilities known.
    assert_eq!(
        combine_bernoullis(&op, &left, &right)?.p_value(),
        &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
    );
    // Only the left probability is null.
    assert_eq!(
        combine_bernoullis(&op, &left_null, &right)?.p_value(),
        &ScalarValue::Float64(None)
    );
    // Only the right probability is null.
    assert_eq!(
        combine_bernoullis(&op, &left, &right_null)?.p_value(),
        &ScalarValue::Float64(None)
    );
    // Both probabilities are null: pair the left null with the right null
    // (previously `left_null` was passed on both sides).
    assert_eq!(
        combine_bernoullis(&op, &left_null, &right_null)?.p_value(),
        &ScalarValue::Null
    );

    Ok(())
}
1363+
1364+
/// Checks that `combine_bernoullis` returns an error for every operator
/// produced by `operator_set()` other than `And` and `Or`.
#[test]
fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
    let lhs = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
    let rhs = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;

    // Skip the two operators that Bernoulli combination is expected to handle.
    let unsupported = operator_set()
        .into_iter()
        .filter(|candidate| !matches!(candidate, Operator::And | Operator::Or));

    for op in unsupported {
        assert!(
            combine_bernoullis(&op, &lhs, &rhs).is_err(),
            "Operator {op} should not be supported for Bernoulli statistics"
        );
    }

    Ok(())
}
1381+
1382+
/// Checks `combine_gaussians` under `Operator::Plus` for one input pair:
/// the resulting mean and variance must equal the sums of the operands'
/// means and variances respectively.
#[test]
fn test_combine_gaussians_addition() -> Result<()> {
    // Small local constructor to keep the operand setup on one line each.
    let gaussian = |mean: f64, variance: f64| {
        GaussianDistribution::try_new(ScalarValue::from(mean), ScalarValue::from(variance))
    };
    let lhs = gaussian(3.0, 2.0)?;
    let rhs = gaussian(4.0, 1.0)?;

    let sum = combine_gaussians(&Operator::Plus, &lhs, &rhs)?.unwrap();

    // Mean: 3.0 + 4.0, variance: 2.0 + 1.0.
    assert_eq!(sum.mean(), &ScalarValue::from(7.0));
    assert_eq!(sum.variance(), &ScalarValue::from(3.0));
    Ok(())
}
1400+
1401+
/// Checks `combine_gaussians` under `Operator::Minus` for one input pair:
/// the resulting mean must be the difference of the operands' means, while
/// the resulting variance must be the sum of their variances.
#[test]
fn test_combine_gaussians_subtraction() -> Result<()> {
    // Small local constructor to keep the operand setup on one line each.
    let gaussian = |mean: f64, variance: f64| {
        GaussianDistribution::try_new(ScalarValue::from(mean), ScalarValue::from(variance))
    };
    let lhs = gaussian(7.0, 2.0)?;
    let rhs = gaussian(4.0, 1.0)?;

    let difference = combine_gaussians(&Operator::Minus, &lhs, &rhs)?.unwrap();

    // Mean: 7.0 - 4.0, variance: 2.0 + 1.0.
    assert_eq!(difference.mean(), &ScalarValue::from(3.0));
    assert_eq!(difference.variance(), &ScalarValue::from(3.0));

    Ok(())
}
1420+
1421+
/// Checks that `combine_gaussians` yields `None` for every operator produced
/// by `operator_set()` other than `Plus` and `Minus`.
#[test]
fn test_combine_gaussians_unsupported_ops() -> Result<()> {
    let lhs = GaussianDistribution::try_new(ScalarValue::from(7.0), ScalarValue::from(2.0))?;
    let rhs = GaussianDistribution::try_new(ScalarValue::from(4.0), ScalarValue::from(1.0))?;

    // Skip the two operators that Gaussian combination is expected to handle.
    let unsupported = operator_set()
        .into_iter()
        .filter(|candidate| !matches!(candidate, Operator::Plus | Operator::Minus));

    for op in unsupported {
        assert!(
            combine_gaussians(&op, &lhs, &rhs)?.is_none(),
            "Operator {op} should not be supported for Gaussian statistics"
        );
    }

    Ok(())
}
1444+
13091445
// Expected test results were calculated in Wolfram Mathematica, by using:
13101446
//
13111447
// *METHOD_NAME*[TransformedDistribution[
@@ -1431,4 +1567,44 @@ mod tests {
14311567

14321568
Ok(())
14331569
}
1570+
1571+
fn operator_set() -> HashSet<Operator> {
1572+
use super::Operator::*;
1573+
1574+
let all_ops = vec![
1575+
And,
1576+
Or,
1577+
Eq,
1578+
NotEq,
1579+
Gt,
1580+
GtEq,
1581+
Lt,
1582+
LtEq,
1583+
Plus,
1584+
Minus,
1585+
Multiply,
1586+
Divide,
1587+
Modulo,
1588+
IsDistinctFrom,
1589+
IsNotDistinctFrom,
1590+
RegexMatch,
1591+
RegexIMatch,
1592+
RegexNotMatch,
1593+
RegexNotIMatch,
1594+
LikeMatch,
1595+
ILikeMatch,
1596+
NotLikeMatch,
1597+
NotILikeMatch,
1598+
BitwiseAnd,
1599+
BitwiseOr,
1600+
BitwiseXor,
1601+
BitwiseShiftRight,
1602+
BitwiseShiftLeft,
1603+
StringConcat,
1604+
AtArrow,
1605+
ArrowAt,
1606+
];
1607+
1608+
all_ops.into_iter().collect()
1609+
}
14341610
}

0 commit comments

Comments
 (0)