diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index d779121..4693b9a 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -652,13 +652,13 @@ def _capture_prediction_results(
     predictions.createOrReplaceTempView(f"{table_prefix}predictions")

     (
-        test_TP_count,
-        test_FP_count,
-        test_FN_count,
-        test_TN_count,
+        tp_count,
+        fp_count,
+        fn_count,
+        tn_count,
     ) = _get_confusion_matrix(predictions, dep_var)
     test_precision, test_recall, test_mcc = _get_aggregate_metrics(
-        test_TP_count, test_FP_count, test_FN_count, test_TN_count
+        tp_count, fp_count, fn_count, tn_count
     )

     result = ThresholdTestResult(
@@ -690,15 +690,15 @@ def _save_training_results(
    # )


-def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:
+def _calc_mcc(tp: int, tn: int, fp: int, fn: int) -> float:
     """
-    Given the counts of true positives (TP), true negatives (TN), false
-    positives (FP), and false negatives (FN) for a model run, compute the
+    Given the counts of true positives (tp), true negatives (tn), false
+    positives (fp), and false negatives (fn) for a model run, compute the
     Matthews Correlation Coefficient (MCC).
     """
-    if (math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))) != 0:
-        mcc = ((TP * TN) - (FP * FN)) / (
-            math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
+    if (math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))) != 0:
+        mcc = ((tp * tn) - (fp * fn)) / (
+            math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
         )
     else:
         mcc = 0
@@ -779,7 +779,7 @@ def _get_confusion_matrix(


 def _get_aggregate_metrics(
-    TP_count: int, FP_count: int, FN_count: int, TN_count: int
+    true_positives: int, false_positives: int, false_negatives: int, true_negatives: int
 ) -> tuple[float, float, float]:
     """
     Given the counts of true positives, false positives, false negatives, and
@@ -788,15 +788,15 @@ def _get_aggregate_metrics(

     Return a tuple of (precision, recall, Matthews Correlation Coefficient).
     """
-    if (TP_count + FP_count) == 0:
+    if (true_positives + false_positives) == 0:
         precision = np.nan
     else:
-        precision = TP_count / (TP_count + FP_count)
-    if (TP_count + FN_count) == 0:
+        precision = true_positives / (true_positives + false_positives)
+    if (true_positives + false_negatives) == 0:
         recall = np.nan
     else:
-        recall = TP_count / (TP_count + FN_count)
-    mcc = _calc_mcc(TP_count, TN_count, FP_count, FN_count)
+        recall = true_positives / (true_positives + false_negatives)
+    mcc = _calc_mcc(true_positives, true_negatives, false_positives, false_negatives)

     return precision, recall, mcc
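
For reviewers, a minimal standalone sketch of the metric math these renamed functions implement. The calc_mcc helper and the sample counts below are illustrative only and are not part of the patch:

import math

import numpy as np


def calc_mcc(tp: int, tn: int, fp: int, fn: int) -> float:
    # MCC = (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn)),
    # defined as 0 when the denominator is 0, matching _calc_mcc above.
    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denominator == 0:
        return 0.0
    return ((tp * tn) - (fp * fn)) / denominator


# Hypothetical confusion-matrix counts for one model run.
tp, fp, fn, tn = 90, 10, 5, 895

# Precision and recall fall back to NaN when their denominators are zero,
# as in _get_aggregate_metrics.
precision = tp / (tp + fp) if (tp + fp) > 0 else np.nan  # 0.9
recall = tp / (tp + fn) if (tp + fn) > 0 else np.nan  # ~0.947
print(precision, recall, calc_mcc(tp, tn, fp, fn))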