[#179] Put the raw confusion matrix counts in the ThresholdTestResults
riley-harper committed Dec 12, 2024
1 parent 7f0c48c commit a53c120
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions hlink/linking/model_exploration/link_step_train_test_models.py
@@ -141,14 +141,18 @@ def make_threshold_matrix(self) -> list[list[float]]:
 # Both training and test results can be captured in this type
 @dataclass(kw_only=True)
 class ThresholdTestResult:
+    model_id: str
+    alpha_threshold: float
+    threshold_ratio: float
+    true_pos: int
+    true_neg: int
+    false_pos: int
+    false_neg: int
     precision: float
     recall: float
     mcc: float
     f_measure: float
     pr_auc: float
-    model_id: str
-    alpha_threshold: float
-    threshold_ratio: float


 class LinkStepTrainTestModels(LinkStep):
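A minimal usage sketch, not part of this commit: ThresholdTestResult is declared with kw_only=True, so every field must be passed by keyword. The raw counts below are invented for illustration, and the derived metrics are computed from them so the record stays internally consistent; in hlink itself they come from metrics_core, as shown in the next hunk.

from hlink.linking.model_exploration.link_step_train_test_models import (
    ThresholdTestResult,
)

# Hypothetical example values; not taken from any real model run.
true_pos, true_neg, false_pos, false_neg = 90, 950, 10, 25
precision = true_pos / (true_pos + false_pos)  # 0.90
recall = true_pos / (true_pos + false_neg)     # ~0.78

result = ThresholdTestResult(
    model_id="logistic_regression",
    alpha_threshold=0.5,
    threshold_ratio=1.3,
    true_pos=true_pos,
    true_neg=true_neg,
    false_pos=false_pos,
    false_neg=false_neg,
    precision=precision,
    recall=recall,
    mcc=0.82,       # would normally come from metrics_core.mcc(...)
    f_measure=2 * precision * recall / (precision + recall),
    pr_auc=0.88,    # invented; area under the precision-recall curve
)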
@@ -654,25 +658,29 @@ def _capture_prediction_results(
     predictions.createOrReplaceTempView(f"{table_prefix}predictions")

     (
-        tp_count,
-        fp_count,
-        fn_count,
-        tn_count,
+        true_pos,
+        false_pos,
+        false_neg,
+        true_neg,
     ) = _get_confusion_matrix(predictions, dep_var)
-    precision = metrics_core.precision(tp_count, fp_count)
-    recall = metrics_core.recall(tp_count, fn_count)
-    mcc = metrics_core.mcc(tp_count, tn_count, fp_count, fn_count)
-    f_measure = metrics_core.f_measure(tp_count, fp_count, fn_count)
+    precision = metrics_core.precision(true_pos, false_pos)
+    recall = metrics_core.recall(true_pos, false_neg)
+    mcc = metrics_core.mcc(true_pos, true_neg, false_pos, false_neg)
+    f_measure = metrics_core.f_measure(true_pos, false_pos, false_neg)

     result = ThresholdTestResult(
+        model_id=model,
+        alpha_threshold=alpha_threshold,
+        threshold_ratio=threshold_ratio,
+        true_pos=true_pos,
+        true_neg=true_neg,
+        false_pos=false_pos,
+        false_neg=false_neg,
         precision=precision,
         recall=recall,
         mcc=mcc,
         f_measure=f_measure,
         pr_auc=pr_auc,
-        model_id=model,
-        alpha_threshold=alpha_threshold,
-        threshold_ratio=threshold_ratio,
     )

     return result
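The rename from tp_count/fp_count/fn_count/tn_count to true_pos/false_pos/false_neg/true_neg keeps the local variable names aligned with the new dataclass fields. The metrics_core implementations are not shown in this diff; for reference, here is a sketch of the standard definitions these calls presumably correspond to, with argument order matching the calls above (zero-division handling omitted).

import math

def precision(true_pos: int, false_pos: int) -> float:
    # Fraction of predicted positives that are real positives.
    return true_pos / (true_pos + false_pos)

def recall(true_pos: int, false_neg: int) -> float:
    # Fraction of real positives that were predicted positive.
    return true_pos / (true_pos + false_neg)

def mcc(true_pos: int, true_neg: int, false_pos: int, false_neg: int) -> float:
    # Matthews correlation coefficient over the full confusion matrix.
    numerator = true_pos * true_neg - false_pos * false_neg
    denominator = math.sqrt(
        (true_pos + false_pos)
        * (true_pos + false_neg)
        * (true_neg + false_pos)
        * (true_neg + false_neg)
    )
    return numerator / denominator

def f_measure(true_pos: int, false_pos: int, false_neg: int) -> float:
    # F1 score, the harmonic mean of precision and recall.
    return 2 * true_pos / (2 * true_pos + false_pos + false_neg)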
@@ -746,24 +754,23 @@ def _get_confusion_matrix(
     confusion matrix is the count of true positives, false positives, false
     negatives, and true negatives for the predictions.

-    Return a tuple (true_positives, false_positives, false_negatives,
-    true_negatives).
+    Return a tuple (true_pos, false_pos, false_neg, true_neg).
     """
     prediction_col = col("prediction")
     label_col = col(dep_var)

     confusion_matrix = predictions.select(
-        count_if((label_col == 1) & (prediction_col == 1)).alias("true_positives"),
-        count_if((label_col == 0) & (prediction_col == 1)).alias("false_positives"),
-        count_if((label_col == 1) & (prediction_col == 0)).alias("false_negatives"),
-        count_if((label_col == 0) & (prediction_col == 0)).alias("true_negatives"),
+        count_if((label_col == 1) & (prediction_col == 1)).alias("true_pos"),
+        count_if((label_col == 0) & (prediction_col == 1)).alias("false_pos"),
+        count_if((label_col == 1) & (prediction_col == 0)).alias("false_neg"),
+        count_if((label_col == 0) & (prediction_col == 0)).alias("true_neg"),
     )
     [confusion_row] = confusion_matrix.collect()
     return (
-        confusion_row.true_positives,
-        confusion_row.false_positives,
-        confusion_row.false_negatives,
-        confusion_row.true_negatives,
+        confusion_row.true_pos,
+        confusion_row.false_pos,
+        confusion_row.false_neg,
+        confusion_row.true_neg,
     )
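A toy sketch, not from the commit, exercising _get_confusion_matrix on a small in-memory DataFrame. The function reads the "prediction" column directly and treats dep_var as the label column; the column name "match" here is an assumption for illustration. Note that pyspark.sql.functions.count_if is only available in newer PySpark releases (Spark 3.5+).

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
rows = [
    {"match": 1, "prediction": 1},  # true positive
    {"match": 0, "prediction": 1},  # false positive
    {"match": 1, "prediction": 0},  # false negative
    {"match": 0, "prediction": 0},  # true negative
]
predictions = spark.createDataFrame(rows)

# Returned in the documented order: (true_pos, false_pos, false_neg, true_neg)
true_pos, false_pos, false_neg, true_neg = _get_confusion_matrix(predictions, "match")
assert (true_pos, false_pos, false_neg, true_neg) == (1, 1, 1, 1)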


