Skip to content

Commit

Permalink
[#176] Lowercase tp/fp/fn/tn variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Dec 10, 2024
1 parent 4aad62e commit 3efbb0c
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,13 @@ def _capture_prediction_results(
predictions.createOrReplaceTempView(f"{table_prefix}predictions")

(
test_TP_count,
test_FP_count,
test_FN_count,
test_TN_count,
tp_count,
fp_count,
fn_count,
tn_count,
) = _get_confusion_matrix(predictions, dep_var)
test_precision, test_recall, test_mcc = _get_aggregate_metrics(
test_TP_count, test_FP_count, test_FN_count, test_TN_count
tp_count, fp_count, fn_count, tn_count
)

result = ThresholdTestResult(
Expand Down Expand Up @@ -690,15 +690,15 @@ def _save_training_results(
# )


def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:
def _calc_mcc(tp: int, tn: int, fp: int, fn: int) -> float:
"""
Given the counts of true positives (TP), true negatives (TN), false
positives (FP), and false negatives (FN) for a model run, compute the
Given the counts of true positives (tp), true negatives (tn), false
positives (fp), and false negatives (fn) for a model run, compute the
Matthews Correlation Coefficient (MCC).
"""
if (math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))) != 0:
mcc = ((TP * TN) - (FP * FN)) / (
math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
if (math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))) != 0:
mcc = ((tp * tn) - (fp * fn)) / (
math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
)
else:
mcc = 0
Expand Down Expand Up @@ -779,7 +779,7 @@ def _get_confusion_matrix(


def _get_aggregate_metrics(
TP_count: int, FP_count: int, FN_count: int, TN_count: int
true_positives: int, false_positives: int, false_negatives: int, true_negatives: int
) -> tuple[float, float, float]:
"""
Given the counts of true positives, false positives, false negatives, and
Expand All @@ -788,15 +788,15 @@ def _get_aggregate_metrics(
Return a tuple of (precision, recall, Matthews Correlation Coefficient).
"""
if (TP_count + FP_count) == 0:
if (true_positives + false_positives) == 0:
precision = np.nan
else:
precision = TP_count / (TP_count + FP_count)
if (TP_count + FN_count) == 0:
precision = true_positives / (true_positives + false_positives)
if (true_positives + false_negatives) == 0:
recall = np.nan
else:
recall = TP_count / (TP_count + FN_count)
mcc = _calc_mcc(TP_count, TN_count, FP_count, FN_count)
recall = true_positives / (true_positives + false_negatives)
mcc = _calc_mcc(true_positives, true_negatives, false_positives, false_negatives)
return precision, recall, mcc


Expand Down

0 comments on commit 3efbb0c

Please sign in to comment.