diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 78f7f21..b3859f6 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -141,14 +141,18 @@ def make_threshold_matrix(self) -> list[list[float]]:
 # Both training and test results can be captured in this type
 @dataclass(kw_only=True)
 class ThresholdTestResult:
+    model_id: str
+    alpha_threshold: float
+    threshold_ratio: float
+    true_pos: int
+    true_neg: int
+    false_pos: int
+    false_neg: int
     precision: float
     recall: float
     mcc: float
     f_measure: float
     pr_auc: float
-    model_id: str
-    alpha_threshold: float
-    threshold_ratio: float
 
 
 class LinkStepTrainTestModels(LinkStep):
@@ -654,25 +658,29 @@ def _capture_prediction_results(
     predictions.createOrReplaceTempView(f"{table_prefix}predictions")
 
     (
-        tp_count,
-        fp_count,
-        fn_count,
-        tn_count,
+        true_pos,
+        false_pos,
+        false_neg,
+        true_neg,
     ) = _get_confusion_matrix(predictions, dep_var)
-    precision = metrics_core.precision(tp_count, fp_count)
-    recall = metrics_core.recall(tp_count, fn_count)
-    mcc = metrics_core.mcc(tp_count, tn_count, fp_count, fn_count)
-    f_measure = metrics_core.f_measure(tp_count, fp_count, fn_count)
+    precision = metrics_core.precision(true_pos, false_pos)
+    recall = metrics_core.recall(true_pos, false_neg)
+    mcc = metrics_core.mcc(true_pos, true_neg, false_pos, false_neg)
+    f_measure = metrics_core.f_measure(true_pos, false_pos, false_neg)
 
     result = ThresholdTestResult(
+        model_id=model,
+        alpha_threshold=alpha_threshold,
+        threshold_ratio=threshold_ratio,
+        true_pos=true_pos,
+        true_neg=true_neg,
+        false_pos=false_pos,
+        false_neg=false_neg,
         precision=precision,
         recall=recall,
         mcc=mcc,
         f_measure=f_measure,
         pr_auc=pr_auc,
-        model_id=model,
-        alpha_threshold=alpha_threshold,
-        threshold_ratio=threshold_ratio,
     )
     return result
 
@@ -746,24 +754,23 @@ def _get_confusion_matrix(
     confusion matrix is the count of true positives, false positives, false
     negatives, and true negatives for the predictions.
 
-    Return a tuple (true_positives, false_positives, false_negatives,
-    true_negatives).
+    Return a tuple (true_pos, false_pos, false_neg, true_neg).
     """
     prediction_col = col("prediction")
    label_col = col(dep_var)
 
     confusion_matrix = predictions.select(
-        count_if((label_col == 1) & (prediction_col == 1)).alias("true_positives"),
-        count_if((label_col == 0) & (prediction_col == 1)).alias("false_positives"),
-        count_if((label_col == 1) & (prediction_col == 0)).alias("false_negatives"),
-        count_if((label_col == 0) & (prediction_col == 0)).alias("true_negatives"),
+        count_if((label_col == 1) & (prediction_col == 1)).alias("true_pos"),
+        count_if((label_col == 0) & (prediction_col == 1)).alias("false_pos"),
+        count_if((label_col == 1) & (prediction_col == 0)).alias("false_neg"),
+        count_if((label_col == 0) & (prediction_col == 0)).alias("true_neg"),
     )
     [confusion_row] = confusion_matrix.collect()
     return (
-        confusion_row.true_positives,
-        confusion_row.false_positives,
-        confusion_row.false_negatives,
-        confusion_row.true_negatives,
+        confusion_row.true_pos,
+        confusion_row.false_pos,
+        confusion_row.false_neg,
+        confusion_row.true_neg,
    )
 
 
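Illustrative sketch, not part of the patch above: the standalone snippet below shows how the renamed count_if aliases and the reordered, keyword-only ThresholdTestResult fields fit together. It assumes pyspark >= 3.5 (for pyspark.sql.functions.count_if) and Python >= 3.10 (kw_only dataclasses), and it computes the simple metrics inline rather than through hlink's metrics_core helpers. The "match" column, sample rows, model id, thresholds, and pr_auc value are invented for the example.

# Sketch only: reconstructs the confusion-matrix query and result record
# from the patch with toy data; metric formulas are written inline here,
# whereas hlink uses its metrics_core module for them.
import math
from dataclasses import dataclass

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count_if


@dataclass(kw_only=True)
class ThresholdTestResult:
    model_id: str
    alpha_threshold: float
    threshold_ratio: float
    true_pos: int
    true_neg: int
    false_pos: int
    false_neg: int
    precision: float
    recall: float
    mcc: float
    f_measure: float
    pr_auc: float


spark = SparkSession.builder.master("local[1]").getOrCreate()

# Toy predictions: "match" is the dependent variable, "prediction" the model output.
predictions = spark.createDataFrame(
    [(1, 1), (1, 1), (1, 0), (0, 1), (0, 0)],
    ["match", "prediction"],
)

label_col = col("match")
prediction_col = col("prediction")

# Single-pass confusion matrix, same shape as _get_confusion_matrix in the patch.
[row] = predictions.select(
    count_if((label_col == 1) & (prediction_col == 1)).alias("true_pos"),
    count_if((label_col == 0) & (prediction_col == 1)).alias("false_pos"),
    count_if((label_col == 1) & (prediction_col == 0)).alias("false_neg"),
    count_if((label_col == 0) & (prediction_col == 0)).alias("true_neg"),
).collect()

precision = row.true_pos / (row.true_pos + row.false_pos)
recall = row.true_pos / (row.true_pos + row.false_neg)
f_measure = 2 * precision * recall / (precision + recall)
mcc = (row.true_pos * row.true_neg - row.false_pos * row.false_neg) / math.sqrt(
    (row.true_pos + row.false_pos)
    * (row.true_pos + row.false_neg)
    * (row.true_neg + row.false_pos)
    * (row.true_neg + row.false_neg)
)

result = ThresholdTestResult(
    model_id="logistic_regression",   # invented for the example
    alpha_threshold=0.5,              # invented for the example
    threshold_ratio=1.3,              # invented for the example
    true_pos=row.true_pos,
    true_neg=row.true_neg,
    false_pos=row.false_pos,
    false_neg=row.false_neg,
    precision=precision,
    recall=recall,
    mcc=mcc,
    f_measure=f_measure,
    pr_auc=0.8,  # stand-in; hlink gets this from its model evaluation step
)
print(result)

Keeping the confusion-matrix counts on the record alongside the derived metrics means downstream reporting can recompute or cross-check precision, recall, MCC, and F-measure without rerunning the Spark query.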