diff --git a/hlink/linking/core/threshold.py b/hlink/linking/core/threshold.py index 789afd3..b0f57a0 100644 --- a/hlink/linking/core/threshold.py +++ b/hlink/linking/core/threshold.py @@ -81,7 +81,7 @@ def predict_using_thresholds( def _apply_alpha_threshold(pred_df: DataFrame, alpha_threshold: float) -> DataFrame: return pred_df.selectExpr( "*", - f"case when probability >= {alpha_threshold} then 1 else 0 end as prediction", + f"CASE WHEN probability >= {alpha_threshold} THEN 1 ELSE 0 END AS prediction", ) @@ -95,39 +95,39 @@ def _apply_threshold_ratio( raise NameError( 'In order to calculate the threshold ratio based on probabilities, you need to have a "probability" column in your data.' ) - else: - windowSpec = Window.partitionBy(df[f"{id_a}"]).orderBy( - df["probability"].desc(), df[f"{id_b}"] + + windowSpec = Window.partitionBy(df[id_a]).orderBy( + df["probability"].desc(), df[id_b] + ) + prob_rank = rank().over(windowSpec) + prob_lead = lead(df["probability"], 1).over(windowSpec) + return ( + df.select( + df["*"], + prob_rank.alias("prob_rank"), + prob_lead.alias("second_best_prob"), ) - prob_rank = rank().over(windowSpec) - prob_lead = lead(df["probability"], 1).over(windowSpec) - return ( - df.select( - df["*"], - prob_rank.alias("prob_rank"), - prob_lead.alias("second_best_prob"), - ) - .selectExpr( - "*", - f""" - IF( - second_best_prob IS NOT NULL - AND second_best_prob >= {alpha_threshold} - AND prob_rank == 1, - probability / second_best_prob, - NULL) - as ratio - """, - ) - .selectExpr( - "*", - f""" - CAST( - probability >= {alpha_threshold} - AND prob_rank == 1 - AND (ratio > {threshold_ratio} OR ratio is NULL) - as INT) as prediction - """, - ) - .drop("prob_rank") + .selectExpr( + "*", + f""" + IF( + second_best_prob IS NOT NULL + AND second_best_prob >= {alpha_threshold} + AND prob_rank == 1, + probability / second_best_prob, + NULL) + AS ratio + """, ) + .selectExpr( + "*", + f""" + CAST( + probability >= {alpha_threshold} + AND prob_rank == 1 + AND (ratio > {threshold_ratio} OR ratio IS NULL) + AS INT) AS prediction + """, + ) + .drop("prob_rank") + )