diff --git a/hlink/linking/core/model_metrics.py b/hlink/linking/core/model_metrics.py
index 95b5ef8..d75a9b3 100644
--- a/hlink/linking/core/model_metrics.py
+++ b/hlink/linking/core/model_metrics.py
@@ -35,7 +35,7 @@ def mcc(true_pos: int, true_neg: int, false_pos: int, false_neg: int) -> float:
             )
         )
     else:
-        mcc = 0
+        mcc = math.nan
 
     return mcc
 
diff --git a/hlink/tests/core/model_metrics_test.py b/hlink/tests/core/model_metrics_test.py
index 56df30c..41b70b4 100644
--- a/hlink/tests/core/model_metrics_test.py
+++ b/hlink/tests/core/model_metrics_test.py
@@ -6,6 +6,7 @@
 
 from hypothesis import assume, given
 import hypothesis.strategies as st
+import pytest
 
 from hlink.linking.core.model_metrics import f_measure, mcc, precision, recall
 
@@ -71,6 +72,21 @@ def test_mcc_example() -> None:
     assert abs(mcc_score - 0.8111208) < 0.0001, "expected MCC to be near 0.8111208"
 
 
+@pytest.mark.parametrize(
+    "true_pos,true_neg,false_pos,false_neg",
+    [(0, 0, 0, 0), (0, 1, 0, 1), (0, 1, 1, 0), (1, 0, 0, 1), (1, 0, 1, 0)],
+)
+def test_mcc_denom_zero(
+    true_pos: int, true_neg: int, false_pos: int, false_neg: int
+) -> None:
+    """
+    If the denominator of MCC is 0, it's not well-defined, and it returns NaN. This
+    can happen in a variety of situations if at least 2 of the inputs are 0.
+    """
+    mcc_score = mcc(true_pos, true_neg, false_pos, false_neg)
+    assert math.isnan(mcc_score)
+
+
 def test_precision_example() -> None:
     true_pos = 3112
     false_pos = 205
diff --git a/hlink/tests/hh_model_exploration_test.py b/hlink/tests/hh_model_exploration_test.py
index baa4d33..0e80026 100644
--- a/hlink/tests/hh_model_exploration_test.py
+++ b/hlink/tests/hh_model_exploration_test.py
@@ -54,10 +54,10 @@ def test_all_hh_mod_ev(
         "parameters",
         "alpha_threshold",
         "threshold_ratio",
-        "precision_test_mean",
-        "recall_test_mean",
-        "mcc_test_mean",
-        "pr_auc_test_mean",
+        "precision_mean",
+        "recall_mean",
+        "mcc_mean",
+        "pr_auc_mean",
     ]
     # TODO we should expect to get most of these columns once the results reporting is finished.
 
@@ -67,13 +67,13 @@ def test_all_hh_mod_ev(
         "alpha_threshold",
         "threshold_ratio",
         # "precision_test_mean",
-        "precision_test_sd",
-        "recall_test_mean",
-        "recall_test_sd",
-        "mcc_test_sd",
-        "mcc_test_mean",
-        "pr_auc_test_mean",
-        "pr_auc_test_sd",
+        "precision_sd",
+        "recall_mean",
+        "recall_sd",
+        "mcc_sd",
+        "mcc_mean",
+        "pr_auc_mean",
+        "pr_auc_sd",
         "maxDepth",
         "numTrees",
     ]
@@ -83,19 +83,15 @@
 
     assert (
         0.6
-        < tr.query("model == 'logistic_regression'")["precision_test_mean"].iloc[0]
+        < tr.query("model == 'logistic_regression'")["precision_mean"].iloc[0]
         <= 1.0
     )
     assert tr.query("model == 'logistic_regression'")["alpha_threshold"].iloc[0] == 0.5
     assert (
-        0.7
-        < tr.query("model == 'logistic_regression'")["pr_auc_test_mean"].iloc[0]
-        <= 1.0
+        0.7 < tr.query("model == 'logistic_regression'")["pr_auc_mean"].iloc[0] <= 1.0
     )
     assert (
-        0.9
-        < tr.query("model == 'logistic_regression'")["recall_test_mean"].iloc[0]
-        <= 1.0
+        0.9 < tr.query("model == 'logistic_regression'")["recall_mean"].iloc[0] <= 1.0
     )
 
     preds = spark.table("hh_model_eval_predictions").toPandas()
diff --git a/hlink/tests/model_exploration_test.py b/hlink/tests/model_exploration_test.py
index 38ab80a..cc5db41 100644
--- a/hlink/tests/model_exploration_test.py
+++ b/hlink/tests/model_exploration_test.py
@@ -759,7 +759,7 @@ def test_step_2_train_decision_tree_spark(
 
     print(f"Decision tree results: {tr}")
 
-    assert tr.shape == (1, 15)
+    assert tr.shape == (1, 14)
     # assert tr.query("model == 'decision_tree'")["precision_mean"].iloc[0] > 0
     assert tr.query("model == 'decision_tree'")["maxDepth"].iloc[0] == 3
     assert tr.query("model == 'decision_tree'")["minInstancesPerNode"].iloc[0] == 1
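
Note (not part of the patch): the Matthews Correlation Coefficient is
MCC = (TP*TN - FP*FN) / sqrt((TP+FP) * (TP+FN) * (TN+FP) * (TN+FN)), so the
denominator is zero exactly when one of the four sums is zero, which requires
at least two of the four counts to be zero. The sketch below is a minimal
standalone re-implementation of that formula to show the NaN behavior the
patch introduces; it mirrors the logic of hlink.linking.core.model_metrics.mcc
but is illustrative, not the library code verbatim, and mcc_sketch is a
hypothetical name.

import math


def mcc_sketch(true_pos: int, true_neg: int, false_pos: int, false_neg: int) -> float:
    # Denominator of MCC: sqrt((TP+FP) * (TP+FN) * (TN+FP) * (TN+FN)).
    denominator = math.sqrt(
        (true_pos + false_pos)
        * (true_pos + false_neg)
        * (true_neg + false_pos)
        * (true_neg + false_neg)
    )
    if denominator == 0:
        # MCC is undefined here; returning NaN (as the patched function does)
        # propagates "undefined" instead of silently reporting a score of 0.
        return math.nan
    return (true_pos * true_neg - false_pos * false_neg) / denominator


print(mcc_sketch(10, 20, 3, 4))  # all sums nonzero: a well-defined score
print(mcc_sketch(0, 1, 0, 1))    # TP + FP == 0 zeroes a factor: nan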