From 7817ed586f50d2116f6866cba0943e49f11fdd68 Mon Sep 17 00:00:00 2001
From: rileyh
Date: Wed, 11 Dec 2024 14:22:55 -0600
Subject: [PATCH] [#179] Factor away _get_aggregate_metrics()

This function is now simple enough that we can just inline it in the
one place where it's called.
---
 hlink/linking/core/model_metrics.py    |  2 ++
 .../link_step_train_test_models.py     | 30 ++++------------
 hlink/tests/core/model_metrics_test.py | 36 +++++++++++++++++++
 hlink/tests/model_exploration_test.py  | 18 ----------
 4 files changed, 44 insertions(+), 42 deletions(-)
 create mode 100644 hlink/tests/core/model_metrics_test.py

diff --git a/hlink/linking/core/model_metrics.py b/hlink/linking/core/model_metrics.py
index 3352cb2..cbbda1a 100644
--- a/hlink/linking/core/model_metrics.py
+++ b/hlink/linking/core/model_metrics.py
@@ -4,6 +4,8 @@
 # https://github.com/ipums/hlink
 
 import math
+import numpy as np
+
 
 def mcc(tp: int, tn: int, fp: int, fn: int) -> float:
     """
diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index c3477d2..d00b7c4 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -658,14 +658,14 @@ def _capture_prediction_results(
         fn_count,
         tn_count,
     ) = _get_confusion_matrix(predictions, dep_var)
-    test_precision, test_recall, test_mcc = _get_aggregate_metrics(
-        tp_count, fp_count, fn_count, tn_count
-    )
+    precision = metrics_core.precision(tp_count, fp_count)
+    recall = metrics_core.recall(tp_count, fn_count)
+    mcc = metrics_core.mcc(tp_count, tn_count, fp_count, fn_count)
 
     result = ThresholdTestResult(
-        precision=test_precision,
-        recall=test_recall,
-        mcc=test_mcc,
+        precision=precision,
+        recall=recall,
+        mcc=mcc,
         pr_auc=pr_auc,
         model_id=model,
         alpha_threshold=alpha_threshold,
@@ -764,24 +764,6 @@ def _get_confusion_matrix(
     )
 
 
-def _get_aggregate_metrics(
-    true_positives: int, false_positives: int, false_negatives: int, true_negatives: int
-) -> tuple[float, float, float]:
-    """
-    Given the counts of true positives, false positives, false negatives, and
-    true negatives for a model run, compute several metrics to evaluate the
-    model's quality.
-
-    Return a tuple of (precision, recall, Matthews Correlation Coefficient).
-    """
-    precision = metrics_core.precision(true_positives, false_positives)
-    recall = metrics_core.recall(true_positives, false_negatives)
-    mcc = metrics_core.mcc(
-        true_positives, true_negatives, false_positives, false_negatives
-    )
-    return precision, recall, mcc
-
-
 # The outer list entries hold results from each outer fold, the inner list has a ThresholdTestResult per threshold
 # matrix entry. We need to get data for each threshold entry together. Basically we need to invert the data.
 def _combine_by_threshold_matrix_entry(
diff --git a/hlink/tests/core/model_metrics_test.py b/hlink/tests/core/model_metrics_test.py
new file mode 100644
index 0000000..c8d046d
--- /dev/null
+++ b/hlink/tests/core/model_metrics_test.py
@@ -0,0 +1,36 @@
+# This file is part of the ISRDI's hlink.
+# For copyright and licensing information, see the NOTICE and LICENSE files
+# in this project's top-level directory, and also on-line at:
+# https://github.com/ipums/hlink
+
+from hlink.linking.core.model_metrics import mcc, precision, recall
+
+
+def test_mcc_example() -> None:
+    tp = 3112
+    fp = 205
+    fn = 1134
+    tn = 33259
+
+    mcc_score = mcc(tp, tn, fp, fn)
+    assert abs(mcc_score - 0.8111208) < 0.0001, "expected MCC to be near 0.8111208"
+
+
+def test_precision_example() -> None:
+    tp = 3112
+    fp = 205
+
+    precision_score = precision(tp, fp)
+    assert (
+        abs(precision_score - 0.9381972) < 0.0001
+    ), "expected precision to be near 0.9381972"
+
+
+def test_recall_example() -> None:
+    tp = 3112
+    fn = 1134
+
+    recall_score = recall(tp, fn)
+    assert (
+        abs(recall_score - 0.7329251) < 0.0001
+    ), "expected recall to be near 0.7329251"
diff --git a/hlink/tests/model_exploration_test.py b/hlink/tests/model_exploration_test.py
index 7222dbb..7414ef4 100644
--- a/hlink/tests/model_exploration_test.py
+++ b/hlink/tests/model_exploration_test.py
@@ -13,7 +13,6 @@
     _custom_param_grid_builder,
     _get_model_parameters,
     _get_confusion_matrix,
-    _get_aggregate_metrics,
 )
 
 
@@ -1016,20 +1015,3 @@ def test_get_confusion_matrix(spark: SparkSession) -> None:
     assert false_positives == 3
     assert false_negatives == 2
     assert true_negatives == 1
-
-
-def test_get_aggregate_metrics() -> None:
-    true_positives = 3112
-    false_positives = 205
-    false_negatives = 1134
-    true_negatives = 33259
-
-    precision, recall, mcc = _get_aggregate_metrics(
-        true_positives, false_positives, false_negatives, true_negatives
-    )
-
-    assert (
-        abs(precision - 0.9381972) < 0.0001
-    ), "expected precision to be near 0.9381972"
-    assert abs(recall - 0.7329251) < 0.0001, "expected recall to be near 0.7329251"
-    assert abs(mcc - 0.8111208) < 0.0001, "expected MCC to be near 0.8111208"
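
The expected values in model_metrics_test.py can be cross-checked by
hand. Below is a minimal sketch assuming the standard textbook
definitions of precision, recall, and the Matthews Correlation
Coefficient; the project's actual implementations live in
hlink/linking/core/model_metrics.py and may differ in edge-case
handling (e.g. zero denominators):

    import math

    # Confusion-matrix counts shared by the new tests.
    tp, fp, fn, tn = 3112, 205, 1134, 33259

    precision = tp / (tp + fp)  # 3112 / 3317 ~= 0.9381972
    recall = tp / (tp + fn)     # 3112 / 4246 ~= 0.7329251

    # Matthews Correlation Coefficient.
    mcc = (tp * tn - fp * fn) / math.sqrt(
        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    )                           # ~= 0.8111208

    print(precision, recall, mcc)

All three results fall within the 0.0001 tolerances asserted by the
new tests.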