Skip to content

Commit

Permalink
[#179] Simplify _aggregate_per_threshold_results()
Browse files Browse the repository at this point in the history
By pulling the mean and stdev calculation code out into its own
function, we can reduce some of the duplication. And in this case
catching a StatisticsError seems simpler than checking for certain
conditions to be met before calling the statistics functions.
  • Loading branch information
riley-harper committed Dec 12, 2024
1 parent a53c120 commit d87c5de
Showing 1 changed file with 24 additions and 29 deletions.
53 changes: 24 additions & 29 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,24 @@ def _combine_by_threshold_matrix_entry(
return results


def _compute_mean_and_stdev(values: list[float]) -> tuple[float, float]:
    """
    Given a list of floats, return a tuple (mean, stdev).

    If there aren't enough values to compute the mean and/or stdev, return
    np.nan for that entry instead of raising:

    - an empty list yields (np.nan, np.nan)
    - a single-element list yields (that element, np.nan)

    The stdev is the sample standard deviation as computed by
    statistics.stdev().
    """
    try:
        mean = statistics.mean(values)
    except statistics.StatisticsError:
        # statistics.mean() rejects empty input.
        mean = np.nan

    try:
        stdev = statistics.stdev(values)
    except statistics.StatisticsError:
        # statistics.stdev() requires at least two data points.
        stdev = np.nan

    return (mean, stdev)


def _aggregate_per_threshold_results(
thresholded_metrics_df: pd.DataFrame,
prediction_results: list[ThresholdTestResult],
Expand All @@ -824,33 +842,10 @@ def _aggregate_per_threshold_results(
pr_auc_test = [r.pr_auc for r in prediction_results if not math.isnan(r.pr_auc)]
mcc_test = [r.mcc for r in prediction_results if not math.isnan(r.mcc)]

# # variance requires at least two values
precision_test_sd = (
statistics.stdev(precision_test) if len(precision_test) > 1 else np.nan
)
recall_test_sd = statistics.stdev(recall_test) if len(recall_test) > 1 else np.nan
pr_auc_test_sd = statistics.stdev(pr_auc_test) if len(pr_auc_test) > 1 else np.nan
mcc_test_sd = statistics.stdev(mcc_test) if len(mcc_test) > 1 else np.nan

# Deal with tiny test data. This should never arise in practice but if it did we ought
# to issue a warning.
if len(precision_test) < 1:
# raise RuntimeError("Not enough training data to get any valid precision values.")
precision_test_mean = np.nan
else:
precision_test_mean = (
statistics.mean(precision_test)
if len(precision_test) > 1
else precision_test[0]
)

if len(recall_test) < 1:
# raise RuntimeError("Not enough training data to get any valid recall values.")
recall_test_mean = np.nan
else:
recall_test_mean = (
statistics.mean(recall_test) if len(recall_test) > 1 else recall_test[0]
)
(precision_test_mean, precision_test_sd) = _compute_mean_and_stdev(precision_test)
(recall_test_mean, recall_test_sd) = _compute_mean_and_stdev(recall_test)
(pr_auc_test_mean, pr_auc_test_sd) = _compute_mean_and_stdev(pr_auc_test)
(mcc_test_mean, mcc_test_sd) = _compute_mean_and_stdev(mcc_test)

new_desc = pd.DataFrame(
{
Expand All @@ -862,9 +857,9 @@ def _aggregate_per_threshold_results(
"precision_test_sd": [precision_test_sd],
"recall_test_mean": [recall_test_mean],
"recall_test_sd": [recall_test_sd],
"pr_auc_test_mean": [statistics.mean(pr_auc_test)],
"pr_auc_test_mean": [pr_auc_test_mean],
"pr_auc_test_sd": [pr_auc_test_sd],
"mcc_test_mean": [statistics.mean(mcc_test)],
"mcc_test_mean": [mcc_test_mean],
"mcc_test_sd": [mcc_test_sd],
},
)
Expand Down

0 comments on commit d87c5de

Please sign in to comment.