diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 6025998..d779121 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -21,7 +21,7 @@ from pyspark.ml import Model, Transformer
 import pyspark.sql
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import count, mean
+from pyspark.sql.functions import col, count, count_if, mean
 from functools import reduce
 
 import hlink.linking.core.threshold as threshold_core
 import hlink.linking.core.classifier as classifier_core
@@ -752,27 +752,30 @@ def _get_confusion_matrix(
     predictions: pyspark.sql.DataFrame,
     dep_var: str,
 ) -> tuple[int, int, int, int]:
-    TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
-    TP_count = TP.count()
-
-    FP = predictions.filter((predictions[dep_var] == 0) & (predictions.prediction == 1))
-    FP_count = FP.count()
-
-    # print(
-    #     f"Confusion matrix -- true positives and false positives: TP {TP_count} FP {FP_count}"
-    # )
-
-    FN = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 0))
-    FN_count = FN.count()
-
-    TN = predictions.filter((predictions[dep_var] == 0) & (predictions.prediction == 0))
-    TN_count = TN.count()
-
-    # print(
-    #     f"Confusion matrix -- true negatives and false negatives: FN {FN_count} TN {TN_count}"
-    # )
+    """
+    Compute the confusion matrix for the given DataFrame of predictions. The
+    confusion matrix is the count of true positives, false positives, false
+    negatives, and true negatives for the predictions.
 
-    return TP_count, FP_count, FN_count, TN_count
+    Return a tuple (true_positives, false_positives, false_negatives,
+    true_negatives).
+    """
+    prediction_col = col("prediction")
+    label_col = col(dep_var)
+
+    confusion_matrix = predictions.select(
+        count_if((label_col == 1) & (prediction_col == 1)).alias("true_positives"),
+        count_if((label_col == 0) & (prediction_col == 1)).alias("false_positives"),
+        count_if((label_col == 1) & (prediction_col == 0)).alias("false_negatives"),
+        count_if((label_col == 0) & (prediction_col == 0)).alias("true_negatives"),
+    )
+    [confusion_row] = confusion_matrix.collect()
+    return (
+        confusion_row.true_positives,
+        confusion_row.false_positives,
+        confusion_row.false_negatives,
+        confusion_row.true_negatives,
+    )
 
 
 def _get_aggregate_metrics(
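
For reference, below is a minimal, self-contained sketch of the single-pass confusion-matrix aggregation this change adopts. It assumes PySpark 3.5 or newer, where `pyspark.sql.functions.count_if` is available; the local SparkSession, the toy data, and the `match` label column name are illustrative only and not taken from hlink.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count_if

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Toy predictions: a binary label column ("match") and a model prediction column.
predictions = spark.createDataFrame(
    [(1, 1.0), (1, 0.0), (0, 1.0), (0, 0.0), (1, 1.0)],
    ["match", "prediction"],
)

label_col = col("match")
prediction_col = col("prediction")

# One select() with four count_if aggregates scans the data once and returns a
# single row holding all four confusion-matrix cells.
[row] = predictions.select(
    count_if((label_col == 1) & (prediction_col == 1)).alias("true_positives"),
    count_if((label_col == 0) & (prediction_col == 1)).alias("false_positives"),
    count_if((label_col == 1) & (prediction_col == 0)).alias("false_negatives"),
    count_if((label_col == 0) & (prediction_col == 0)).alias("true_negatives"),
).collect()

print(row.true_positives, row.false_positives, row.false_negatives, row.true_negatives)
# Expected output for the toy data above: 2 1 1 1
```

Compared with the previous approach of four separate `filter(...).count()` calls, each of which is its own Spark action, the single `select(...).collect()` computes all four counts in one job over the predictions DataFrame.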