From 1f70f664355da9d2a5f11466f6b5ba59c1880efa Mon Sep 17 00:00:00 2001 From: Colin Davis Date: Mon, 18 Nov 2024 12:35:18 -0600 Subject: [PATCH] wip --- .../link_step_train_test_models.py | 3 ++- hlink/tests/model_exploration_test.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py index b6fdf28..385926b 100644 --- a/hlink/linking/model_exploration/link_step_train_test_models.py +++ b/hlink/linking/model_exploration/link_step_train_test_models.py @@ -261,6 +261,7 @@ def _run(self) -> None: thresholded_metrics_df = _load_thresholded_metrics_df_params( thresholded_metrics_df ) + _print_thresholded_metrics_df(thresholded_metrics_df) self._save_training_results(thresholded_metrics_df, self.task.spark) self._save_otd_data(otd_data, self.task.spark) @@ -744,7 +745,7 @@ def _create_thresholded_metrics_df() -> pd.DataFrame: return pd.DataFrame( columns=[ "model", - "pa rameters", + "parameters", "alpha_threshold", "threshold_ratio", "precision_test_mean", diff --git a/hlink/tests/model_exploration_test.py b/hlink/tests/model_exploration_test.py index 7ef1f92..36ee92f 100644 --- a/hlink/tests/model_exploration_test.py +++ b/hlink/tests/model_exploration_test.py @@ -100,12 +100,15 @@ def test_all( preds = spark.table("model_eval_predictions").toPandas() assert ( - preds.query("id_a == 20 and id_b == 30")["second_best_prob"].round(2).iloc[0] - >= 0.6 + preds.query("id_a == 20 and id_b == 30")["probability"].round(2).iloc[0] > 0.5 ) + + assert ( - preds.query("id_a == 20 and id_b == 30")["probability"].round(2).iloc[0] > 0.5 + preds.query("id_a == 20 and id_b == 30")["second_best_prob"].round(2).iloc[0] + >= 0.6 ) + assert preds.query("id_a == 30 and id_b == 30")["prediction"].iloc[0] == 0 assert pd.isnull( preds.query("id_a == 10 and id_b == 30")["second_best_prob"].iloc[0] @@ -365,6 +368,12 @@ def test_step_2_train_gradient_boosted_trees_spark( preds = spark.table("model_eval_predictions").toPandas() assert "probability_array" in list(preds.columns) + + #import pdb + #pdb.set_trace() + + training_results = tr.query("model == 'gradient_boosted_trees'") + print(f"XX training_results: {training_results}") # assert tr.shape == (1, 18) assert (