Skip to content

Commit

Permalink
correctly save suspicious data
Browse files Browse the repository at this point in the history
  • Loading branch information
ccdavis committed Nov 19, 2024
1 parent 1f2bd49 commit 21cac61
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 24 deletions.
19 changes: 9 additions & 10 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,14 @@ def _choose_best_training_results(self, evals: list[ModelEval]) -> ModelEval:
def _evaluate_threshold_combinations(
self,
hyperparam_evaluation_results: list[ModelEval],
suspicious_data: Any,
splits: list[list[pyspark.sql.DataFrame]],
dep_var: str,
id_a: str,
id_b: str,
) -> dict[str, Any]:
) -> tuple[dict[str, Any], Any]:
training_conf = str(self.task.training_conf)
config = self.task.link_run.config

# Stores suspicious data
otd_data = self._create_otd_data(id_a, id_b)
config = self.task.link_run.config

thresholded_metrics_df = _create_thresholded_metrics_df()

Expand Down Expand Up @@ -299,7 +297,7 @@ def _evaluate_threshold_combinations(
dep_var,
thresholding_model,
results_dfs[i],
otd_data,
suspicious_data,
this_alpha_threshold,
this_threshold_ratio,
best_results.score,
Expand All @@ -316,7 +314,7 @@ def _evaluate_threshold_combinations(
best_results.hyperparams,
)

return thresholded_metrics_df
return thresholded_metrics_df, suspicious_data

def _run(self) -> None:
training_conf = str(self.task.training_conf)
Expand Down Expand Up @@ -356,8 +354,8 @@ def _run(self) -> None:
model_parameters, splits, dep_var, id_a, id_b, config, training_conf
)

thresholded_metrics_df = self._evaluate_threshold_combinations(
hyperparam_evaluation_results, splits, dep_var, id_a, id_b
thresholded_metrics_df, suspicious_data = self._evaluate_threshold_combinations(
hyperparam_evaluation_results, otd_data, splits, dep_var, id_a, id_b
)

thresholded_metrics_df = _load_thresholded_metrics_df_params(
Expand All @@ -366,7 +364,7 @@ def _run(self) -> None:

_print_thresholded_metrics_df(thresholded_metrics_df)
self._save_training_results(thresholded_metrics_df, self.task.spark)
self._save_otd_data(otd_data, self.task.spark)
self._save_otd_data(suspicious_data, self.task.spark)
self.task.spark.sql("set spark.sql.shuffle.partitions=200")

def _get_splits(
Expand Down Expand Up @@ -538,6 +536,7 @@ def _save_otd_data(
table_prefix = self.task.table_prefix

if otd_data is None:
print("OTD suspicious data is None, not saving.")
return
id_a = otd_data["id_a"]
id_b = otd_data["id_b"]
Expand Down
16 changes: 2 additions & 14 deletions hlink/tests/model_exploration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,6 @@ def test_all(
model_exploration.run_step(1)
model_exploration.run_step(2)

prc = spark.table("model_eval_precision_recall_curve_probit__").toPandas()
assert all(
elem in list(prc.columns)
for elem in ["params", "precision", "recall", "threshold_gt_eq"]
)
prc_rf = spark.table(
"model_eval_precision_recall_curve_random_forest__maxdepth___5_0___numtrees___75_0_"
).toPandas()
assert all(
elem in list(prc_rf.columns)
for elem in ["params", "precision", "recall", "threshold_gt_eq"]
)

tr = spark.table("model_eval_training_results").toPandas()

assert tr.__len__() == 3
Expand Down Expand Up @@ -372,6 +359,7 @@ def test_step_2_train_gradient_boosted_trees_spark(
# pdb.set_trace()

training_results = tr.query("model == 'gradient_boosted_trees'")

print(f"XX training_results: {training_results}")

# assert tr.shape == (1, 18)
Expand All @@ -388,7 +376,7 @@ def test_step_2_train_gradient_boosted_trees_spark(
main.do_drop_all("")


def test_step_2_interact_categorial_vars(
def test_step_2_interact_categorical_vars(
spark, training_conf, model_exploration, state_dist_path, training_data_path
):
"""Test matching step 2 training to see if the OneHotEncoding is working"""
Expand Down

0 comments on commit 21cac61

Please sign in to comment.