diff --git a/hlink/tests/model_exploration_test.py b/hlink/tests/model_exploration_test.py index 7414ef4..7222dbb 100644 --- a/hlink/tests/model_exploration_test.py +++ b/hlink/tests/model_exploration_test.py @@ -13,6 +13,7 @@ _custom_param_grid_builder, _get_model_parameters, _get_confusion_matrix, + _get_aggregate_metrics, ) @@ -1015,3 +1016,20 @@ def test_get_confusion_matrix(spark: SparkSession) -> None: assert false_positives == 3 assert false_negatives == 2 assert true_negatives == 1 + + +def test_get_aggregate_metrics() -> None: + true_positives = 3112 + false_positives = 205 + false_negatives = 1134 + true_negatives = 33259 + + precision, recall, mcc = _get_aggregate_metrics( + true_positives, false_positives, false_negatives, true_negatives + ) + + assert ( + abs(precision - 0.9381972) < 0.0001 + ), "expected precision to be near 0.9381972" + assert abs(recall - 0.7329251) < 0.0001, "expected recall to be near 0.7329251" + assert abs(mcc - 0.8111208) < 0.0001, "expected MCC to be near 0.8111208"