Skip to content

Commit

Permalink
[#176] Add a unit test for _get_aggregate_metrics()
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Dec 10, 2024
1 parent c43b57d commit 4aad62e
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions hlink/tests/model_exploration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
_custom_param_grid_builder,
_get_model_parameters,
_get_confusion_matrix,
_get_aggregate_metrics,
)


Expand Down Expand Up @@ -1015,3 +1016,20 @@ def test_get_confusion_matrix(spark: SparkSession) -> None:
assert false_positives == 3
assert false_negatives == 2
assert true_negatives == 1


def test_get_aggregate_metrics() -> None:
true_positives = 3112
false_positives = 205
false_negatives = 1134
true_negatives = 33259

precision, recall, mcc = _get_aggregate_metrics(
true_positives, false_positives, false_negatives, true_negatives
)

assert (
abs(precision - 0.9381972) < 0.0001
), "expected precision to be near 0.9381972"
assert abs(recall - 0.7329251) < 0.0001, "expected recall to be near 0.7329251"
assert abs(mcc - 0.8111208) < 0.0001, "expected MCC to be near 0.8111208"

0 comments on commit 4aad62e

Please sign in to comment.