diff --git a/hlink/linking/core/model_metrics.py b/hlink/linking/core/model_metrics.py
index 7222c55..3352cb2 100644
--- a/hlink/linking/core/model_metrics.py
+++ b/hlink/linking/core/model_metrics.py
@@ -18,3 +18,21 @@ def mcc(tp: int, tn: int, fp: int, fn: int) -> float:
     else:
         mcc = 0
     return mcc
+
+
+def precision(tp: int, fp: int) -> float:
+    if (tp + fp) == 0:
+        precision = np.nan
+    else:
+        precision = tp / (tp + fp)
+
+    return precision
+
+
+def recall(tp: int, fn: int) -> float:
+    if (tp + fn) == 0:
+        recall = np.nan
+    else:
+        recall = tp / (tp + fn)
+
+    return recall
diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 4498ed1..c3477d2 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -774,14 +774,8 @@ def _get_aggregate_metrics(
 
     Return a tuple of (precision, recall, Matthews Correlation Coefficient).
     """
-    if (true_positives + false_positives) == 0:
-        precision = np.nan
-    else:
-        precision = true_positives / (true_positives + false_positives)
-    if (true_positives + false_negatives) == 0:
-        recall = np.nan
-    else:
-        recall = true_positives / (true_positives + false_negatives)
+    precision = metrics_core.precision(true_positives, false_positives)
+    recall = metrics_core.recall(true_positives, false_negatives)
     mcc = metrics_core.mcc(
         true_positives, true_negatives, false_positives, false_negatives
     )