From ad6ce10ecef2fc9bca587bf37fe37f805d2ad139 Mon Sep 17 00:00:00 2001
From: rileyh
Date: Thu, 5 Dec 2024 19:05:14 +0000
Subject: [PATCH] [#174] Pass just decision into predict_using_thresholds()
 instead of the whole training config

This makes it clear which part of the config predict_using_thresholds() is
using and makes it easier to call. It also means that
predict_using_thresholds() does not need to know about the structure of the
config.
---
 hlink/linking/core/threshold.py                       | 8 ++++----
 hlink/linking/matching/link_step_score.py             | 3 ++-
 .../model_exploration/link_step_train_test_models.py  | 5 +++--
 hlink/tests/core/threshold_test.py                    | 5 ++---
 hlink/tests/matching_scoring_test.py                  | 2 +-
 5 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/hlink/linking/core/threshold.py b/hlink/linking/core/threshold.py
index 720b559..789afd3 100644
--- a/hlink/linking/core/threshold.py
+++ b/hlink/linking/core/threshold.py
@@ -40,8 +40,8 @@ def predict_using_thresholds(
     pred_df: DataFrame,
     alpha_threshold: float,
     threshold_ratio: float,
-    training_conf: dict[str, Any],
     id_col: str,
+    decision: str | None,
 ) -> DataFrame:
     """Adds a prediction column to the given pred_df by applying thresholds.
 
@@ -57,17 +57,17 @@ def predict_using_thresholds(
         to the "a" record's next best probability value. Only used with the
         "drop_duplicate_with_threshold_ratio" configuration value.
-    training_conf: dictionary
-        the training config section
     id_col: string
         the id column
+    decision: str | None
+        how to apply the thresholds
 
     Returns
     -------
     A Spark DataFrame containing the "prediction" column as well as other
     intermediate columns generated to create the prediction.
     """
     use_threshold_ratio = (
-        training_conf.get("decision", "") == "drop_duplicate_with_threshold_ratio"
+        decision is not None and decision == "drop_duplicate_with_threshold_ratio"
     )
 
     if use_threshold_ratio:
diff --git a/hlink/linking/matching/link_step_score.py b/hlink/linking/matching/link_step_score.py
index b4d192e..12b5da3 100644
--- a/hlink/linking/matching/link_step_score.py
+++ b/hlink/linking/matching/link_step_score.py
@@ -96,12 +96,13 @@ def _run(self):
         threshold_ratio = threshold_core.get_threshold_ratio(
             config[training_conf], chosen_model_params, default=1.3
         )
+        decision = config[training_conf].get("decision")
         predictions = threshold_core.predict_using_thresholds(
             score_tmp,
             alpha_threshold,
             threshold_ratio,
-            config[training_conf],
             config["id_column"],
+            decision,
         )
         predictions.write.mode("overwrite").saveAsTable(f"{table_prefix}predictions")
         pmp = self.task.spark.table(f"{table_prefix}potential_matches_pipeline")
diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 1486c53..a05c3ed 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -411,20 +411,21 @@ def _evaluate_threshold_combinations(
                     f"{this_alpha_threshold=} and {this_threshold_ratio=}"
                 )
                 logger.debug(diag)
+                decision = training_settings.get("decision")
                 start_predict_time = perf_counter()
                 predictions = threshold_core.predict_using_thresholds(
                     thresholding_predictions,
                     this_alpha_threshold,
                     this_threshold_ratio,
-                    training_settings,
                     id_column,
+                    decision,
                 )
                 predict_train = threshold_core.predict_using_thresholds(
                     thresholding_predict_train,
                     this_alpha_threshold,
                     this_threshold_ratio,
-                    training_settings,
                     id_column,
+                    decision,
                 )
                 end_predict_time = perf_counter()
 
diff --git a/hlink/tests/core/threshold_test.py b/hlink/tests/core/threshold_test.py
index 3bb0272..b477b09 100644
--- a/hlink/tests/core/threshold_test.py
+++ b/hlink/tests/core/threshold_test.py
@@ -26,7 +26,7 @@ def test_predict_using_thresholds_default_decision(spark: SparkSession) -> None:
 
     # We are using the default decision, so threshold_ratio will be ignored
     predictions = predict_using_thresholds(
-        df, alpha_threshold=0.6, threshold_ratio=0.0, training_conf={}, id_col="id"
+        df, alpha_threshold=0.6, threshold_ratio=0.0, id_col="id", decision=None
     )
 
     output_rows = (
@@ -64,13 +64,12 @@ def test_predict_using_thresholds_drop_duplicates_decision(spark: SparkSession)
         (3, "E", 0.8),
     ]
     df = spark.createDataFrame(input_rows, schema=["id_a", "id_b", "probability"])
-    training_conf = {"decision": "drop_duplicate_with_threshold_ratio"}
     predictions = predict_using_thresholds(
         df,
         alpha_threshold=0.5,
         threshold_ratio=2.0,
-        training_conf=training_conf,
         id_col="id",
+        decision="drop_duplicate_with_threshold_ratio",
     )
 
     output_rows = (
diff --git a/hlink/tests/matching_scoring_test.py b/hlink/tests/matching_scoring_test.py
index 613e1f6..191663c 100755
--- a/hlink/tests/matching_scoring_test.py
+++ b/hlink/tests/matching_scoring_test.py
@@ -51,8 +51,8 @@ def test_step_2_alpha_beta_thresholds(
         score_tmp,
         alpha_threshold,
         threshold_ratio,
-        matching_conf["training"],
         matching_conf["id_column"],
+        matching_conf["training"].get("decision"),
     )
     predictions.write.mode("overwrite").saveAsTable("predictions")
 
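
A minimal sketch of how a caller might use the new signature after this patch.
The keyword names, the id_col/decision handling, and the location of "decision"
under the "training" config section come from the files changed above; the
import path, DataFrame contents, and threshold values are illustrative
assumptions, not part of the patch.

    from pyspark.sql import SparkSession

    from hlink.linking.core.threshold import predict_using_thresholds

    spark = SparkSession.builder.getOrCreate()

    # Potential matches scored with a "probability" column, as in the tests above.
    df = spark.createDataFrame(
        [(0, "A", 0.9), (1, "B", 0.4)], schema=["id_a", "id_b", "probability"]
    )

    # Hypothetical config; "decision" is optional and lives under "training".
    config = {
        "id_column": "id",
        "training": {"decision": "drop_duplicate_with_threshold_ratio"},
    }

    predictions = predict_using_thresholds(
        df,
        alpha_threshold=0.5,
        threshold_ratio=1.3,
        id_col=config["id_column"],
        # A missing "decision" key yields None, which selects the default decision.
        decision=config["training"].get("decision"),
    )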