diff --git a/hlink/linking/core/threshold.py b/hlink/linking/core/threshold.py index b0523d3..d5cd5ba 100644 --- a/hlink/linking/core/threshold.py +++ b/hlink/linking/core/threshold.py @@ -94,38 +94,44 @@ def _apply_threshold_ratio( 'In order to calculate the threshold ratio based on probabilities, you need to have a "probability" column in your data.' ) - windowSpec = Window.partitionBy(df[id_a]).orderBy( - df["probability"].desc(), df[id_b] - ) + windowSpec = Window.partitionBy(id_a).orderBy(col("probability").desc(), id_b) prob_rank = rank().over(windowSpec) - prob_lead = lead(df["probability"], 1).over(windowSpec) + prob_lead = lead("probability", 1).over(windowSpec) + + should_compute_probability_ratio = ( + col("second_best_prob").isNotNull() + & (col("second_best_prob") >= alpha_threshold) + & (col("prob_rank") == 1) + ) + # To be a match, the row must... + # 1. Have prob_rank 1, so that it's the most likely match, + # 2. Have a probability of at least alpha_threshold, + # and + # 3. Either have no ratio (since there's no second best probability of at + # least alpha_threshold), or have a ratio of more than threshold_ratio. + is_match = ( + (col("probability") >= alpha_threshold) + & (col("prob_rank") == 1) + & ((col("ratio") > threshold_ratio) | col("ratio").isNull()) + ) return ( df.select( - df["*"], + "*", prob_rank.alias("prob_rank"), prob_lead.alias("second_best_prob"), ) - .selectExpr( + .select( "*", - f""" - IF( - second_best_prob IS NOT NULL - AND second_best_prob >= {alpha_threshold} - AND prob_rank == 1, - probability / second_best_prob, - NULL) - AS ratio - """, + when( + should_compute_probability_ratio, + col("probability") / col("second_best_prob"), + ) + .otherwise(None) + .alias("ratio"), ) - .selectExpr( + .select( "*", - f""" - CAST( - probability >= {alpha_threshold} - AND prob_rank == 1 - AND (ratio > {threshold_ratio} OR ratio IS NULL) - AS INT) AS prediction - """, + is_match.cast("integer").alias("prediction"), ) .drop("prob_rank") )