Commit efa67f7: "Tests pass"
ccdavis committed Nov 25, 2024 (1 parent: 3b22f14)
Showing 2 changed files with 112 additions and 92 deletions.
196 changes: 109 additions & 87 deletions hlink/linking/model_exploration/link_step_train_test_models.py
@@ -281,7 +281,7 @@ def _evaluate_threshold_combinations(
self,
hyperparam_evaluation_results: list[ModelEval],
suspicious_data: Any,
splits: list[list[pyspark.sql.DataFrame]],
split: list[pyspark.sql.DataFrame],
dep_var: str,
id_a: str,
id_b: str,
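
This hunk is the commit's interface change: _evaluate_threshold_combinations no longer receives every train/test split, only the one split reserved for thresholding. A minimal sketch of the before-and-after parameter shapes, with alias names that are illustrative rather than hlink's own:

    # Old vs. new parameter shape; the alias names are illustrative only.
    import pyspark.sql

    # Old: one [training_data, test_data] pair per training iteration.
    OldSplitsParam = list[list[pyspark.sql.DataFrame]]

    # New: exactly one [training_data, test_data] pair.
    NewSplitParam = list[pyspark.sql.DataFrame]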
@@ -297,101 +297,96 @@ def _evaluate_threshold_combinations(

print(f"\n======== Best Model and Parameters ========\n")
print(f"\t{best_results}\n")
print("=============================================\n]\n")
print("=============================================\n\n")

# TODO check if we should make a different split, like starting from a different seed?
# or just not re-using one we used in making the PR_AUC mean value?
# splits_for_thresholding_eval = splits[0]
# thresholding_training_data = splits_for_thresholding_eval[0].cache()
# thresholding_test_data = splits_for_thresholding_eval[1].cache()
threshold_matrix = best_results.make_threshold_matrix()
logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
print(
f"Testing the best model + parameters against all {len(threshold_matrix)} threshold combinations."
f"\nTesting the best model + parameters against all {len(threshold_matrix)} threshold combinations.\n"
)
results_dfs: dict[int, pd.DataFrame] = {}
for i in range(len(threshold_matrix)):
results_dfs[i] = _create_results_df()

for split_index, (
thresholding_training_data,
thresholding_test_data,
) in enumerate(splits, 1):
cached_training_data = thresholding_training_data.cache()
cached_test_data = thresholding_test_data.cache()

thresholding_classifier, thresholding_post_transformer = (
classifier_core.choose_classifier(
best_results.model_type, best_results.hyperparams, dep_var
)
thresholding_training_data = split[0]
thresholding_test_data = split[1]

cached_training_data = thresholding_training_data.cache()
cached_test_data = thresholding_test_data.cache()

thresholding_classifier, thresholding_post_transformer = (
classifier_core.choose_classifier(
best_results.model_type, best_results.hyperparams, dep_var
)
thresholding_model = thresholding_classifier.fit(cached_training_data)
)
thresholding_model = thresholding_classifier.fit(cached_training_data)

thresholding_predictions = _get_probability_and_select_pred_columns(
cached_test_data,
thresholding_model,
thresholding_post_transformer,
id_a,
id_b,
dep_var,
)
thresholding_predict_train = _get_probability_and_select_pred_columns(
cached_training_data,
thresholding_model,
thresholding_post_transformer,
id_a,
id_b,
dep_var,
)

thresholding_predictions = _get_probability_and_select_pred_columns(
cached_test_data,
thresholding_model,
thresholding_post_transformer,
id_a,
id_b,
dep_var,
i = 0
for threshold_index, (
this_alpha_threshold,
this_threshold_ratio,
) in enumerate(threshold_matrix, 1):

diag = (
f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
f"{this_alpha_threshold=} and {this_threshold_ratio=}"
)
thresholding_predict_train = _get_probability_and_select_pred_columns(
cached_training_data,
thresholding_model,
thresholding_post_transformer,
id_a,
id_b,
dep_var,
logger.debug(diag)
predictions = threshold_core.predict_using_thresholds(
thresholding_predictions,
this_alpha_threshold,
this_threshold_ratio,
config[training_conf],
config["id_column"],
)
predict_train = threshold_core.predict_using_thresholds(
thresholding_predict_train,
this_alpha_threshold,
this_threshold_ratio,
config[training_conf],
config["id_column"],
)

i = 0
for threshold_index, (
results_dfs[i] = self._capture_results(
predictions,
predict_train,
dep_var,
thresholding_model,
results_dfs[i],
suspicious_data,
this_alpha_threshold,
this_threshold_ratio,
) in enumerate(threshold_matrix, 1):

diag = (
f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
f"{this_alpha_threshold=} and {this_threshold_ratio=}"
)
logger.debug(diag)
predictions = threshold_core.predict_using_thresholds(
thresholding_predictions,
this_alpha_threshold,
this_threshold_ratio,
config[training_conf],
config["id_column"],
)
predict_train = threshold_core.predict_using_thresholds(
thresholding_predict_train,
this_alpha_threshold,
this_threshold_ratio,
config[training_conf],
config["id_column"],
)

results_dfs[i] = self._capture_results(
predictions,
predict_train,
dep_var,
thresholding_model,
results_dfs[i],
suspicious_data,
this_alpha_threshold,
this_threshold_ratio,
best_results.score,
)
i += 1
thresholding_test_data.unpersist()
thresholding_training_data.unpersist()

for i in range(len(threshold_matrix)):
thresholded_metrics_df = _append_results(
thresholded_metrics_df,
results_dfs[i],
best_results.model_type,
best_results.hyperparams,
)
best_results.score,
)

# for i in range(len(threshold_matrix)):
thresholded_metrics_df = _append_results(
thresholded_metrics_df,
results_dfs[i],
best_results.model_type,
best_results.hyperparams,
)
i += 1

thresholding_test_data.unpersist()
thresholding_training_data.unpersist()

return thresholded_metrics_df, suspicious_data
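
Taken together, this hunk removes the outer loop over splits: the best model is fit once on the reserved split, each (alpha_threshold, this_threshold_ratio) entry in the threshold matrix is evaluated against that single split, and _append_results now runs inside the threshold loop rather than in a separate pass afterwards. A condensed sketch of the new flow, with the hlink helpers passed in as stand-in callables so the sketch stays self-contained:

    # Condensed, hypothetical sketch of the new single-split flow.
    # fit_model, predict_with_thresholds, and capture stand in for the
    # hlink helpers shown in the diff; they are parameters here, not
    # real hlink functions.
    def evaluate_thresholds_sketch(
        split, threshold_matrix, fit_model, predict_with_thresholds, capture
    ):
        training_data, test_data = split[0].cache(), split[1].cache()
        model = fit_model(training_data)  # fit once, not once per split

        results = []
        for alpha_threshold, threshold_ratio in threshold_matrix:
            predictions = predict_with_thresholds(
                model, test_data, alpha_threshold, threshold_ratio
            )
            results.append(capture(predictions, alpha_threshold, threshold_ratio))

        test_data.unpersist()
        training_data.unpersist()
        return results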

@@ -417,10 +412,15 @@ def _run(self) -> None:
otd_data = self._create_otd_data(id_a, id_b)

n_training_iterations = config[training_conf].get("n_training_iterations", 10)
if n_training_iterations < 2:
raise RuntimeError("You must use at least two training iterations.")

seed = config[training_conf].get("seed", 2133)

splits = self._get_splits(prepped_data, id_a, n_training_iterations, seed)
model_evaluation_splits = self._get_splits(
prepped_data, id_a, n_training_iterations, seed
)
thresholding_split = model_evaluation_splits.pop()

# Explode params into all the combinations we want to test with the current model.
model_parameters = self._get_model_parameters(config)
@@ -431,22 +431,35 @@
)

hyperparam_evaluation_results = self._evaluate_hyperparam_combinations(
model_parameters, splits, dep_var, id_a, id_b, config, training_conf
model_parameters,
model_evaluation_splits,
dep_var,
id_a,
id_b,
config,
training_conf,
)

# TODO: We may want to recreate a new split or set of splits rather than reuse existing splits.
thresholded_metrics_df, suspicious_data = self._evaluate_threshold_combinations(
hyperparam_evaluation_results, otd_data, splits, dep_var, id_a, id_b
hyperparam_evaluation_results,
otd_data,
thresholding_split,
dep_var,
id_a,
id_b,
)

# TODO: thresholded_metrics_df has one row per split currently and we may want to
# crunch that set down to get the mean or median of some measures across all the splits.
# thresholded_metrics_df has one row per threshold combination.
thresholded_metrics_df = _load_thresholded_metrics_df_params(
thresholded_metrics_df
)

print("*** Final thresholded metrics ***")
_print_thresholded_metrics_df(thresholded_metrics_df)

_print_thresholded_metrics_df(
thresholded_metrics_df.sort_values(by="mcc_test_mean", ascending=False)
)
self._save_training_results(thresholded_metrics_df, self.task.spark)
self._save_otd_data(suspicious_data, self.task.spark)
self.task.spark.sql("set spark.sql.shuffle.partitions=200")
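
Two related changes in _run support the new design. The guard raising RuntimeError enforces at least two training iterations, which guarantees that after one split is popped for thresholding, at least one split remains for hyperparameter evaluation. And pop() removes the reserved split from the list, so the two evaluation phases see disjoint splits. A minimal sketch of the reservation pattern, with strings standing in for the real [train, test] DataFrame pairs:

    # Sketch of the split-reservation pattern; strings stand in for
    # PySpark DataFrames.
    model_evaluation_splits = [["train1", "test1"], ["train2", "test2"], ["train3", "test3"]]

    # pop() removes and returns the last split, so the remaining list no
    # longer contains the split used for thresholding.
    thresholding_split = model_evaluation_splits.pop()

    assert thresholding_split == ["train3", "test3"]
    assert len(model_evaluation_splits) == 2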
@@ -464,6 +477,7 @@ def _get_splits(
itself a list of two DataFrames which are the splits of prepped_data.
The split DataFrames are roughly equal in size.
"""
print(f"Splitting prepped data that starts with {prepped_data.count()} rows.")
if self.task.link_run.config[f"{self.task.training_conf}"].get(
"split_by_id_a", False
):
@@ -486,6 +500,14 @@
for i in range(n_training_iterations)
]

print(f"There are {len(splits)} splits.")
for index, s in enumerate(splits, 1):
training_data = s[0]
test_data = s[1]

print(
f"Split {index}: training rows {training_data.count()} test rows: {test_data.count()}"
)
return splits
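
The new prints in _get_splits report how many splits were produced and the row counts on each side; since Spark evaluates lazily, each count() here triggers a job. The surrounding context shows one split generated per training iteration. A hedged sketch of how such per-iteration 50/50 random splits can be produced; randomSplit is a real PySpark DataFrame method, but the weights and seed scheme below are assumptions, not necessarily what hlink does:

    # Hypothetical sketch: one 50/50 random split per training iteration.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    prepped_data = spark.range(100)  # stand-in for the real prepped data

    n_training_iterations = 3
    seed = 2133
    splits = [
        prepped_data.randomSplit([0.5, 0.5], seed=seed + i)
        for i in range(n_training_iterations)
    ]
    for index, (training_data, test_data) in enumerate(splits, 1):
        # Each count() triggers a Spark job, mirroring the diff's prints.
        print(f"Split {index}: training rows {training_data.count()} test rows: {test_data.count()}")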

def _custom_param_grid_builder(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
@@ -884,7 +906,7 @@ def _append_results(
thresholded_metrics_df = pd.concat(
[thresholded_metrics_df, new_desc], ignore_index=True
)
# _print_thresholded_metrics_df(thresholded_metrics_df)

return thresholded_metrics_df
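
_append_results grows the metrics frame one summary row at a time with pd.concat. A minimal sketch of that accumulation pattern; the column names and values below are invented, not hlink's actual schema:

    # Illustrative only: column names and values are made up.
    import pandas as pd

    thresholded_metrics_df = pd.DataFrame(
        [{"model": "probit", "alpha_threshold": 0.5, "mcc_test_mean": 0.31}]
    )
    new_desc = pd.DataFrame(
        [{"model": "random_forest", "alpha_threshold": 0.8, "mcc_test_mean": 0.42}]
    )
    thresholded_metrics_df = pd.concat([thresholded_metrics_df, new_desc], ignore_index=True)
    print(thresholded_metrics_df)  # two rows, one per _append_results call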


8 changes: 3 additions & 5 deletions hlink/tests/model_exploration_test.py
@@ -75,9 +75,7 @@ def test_all(
tr = spark.table("model_eval_training_results").toPandas()
print(f"Test all results: {tr}")

# We need 8 rows because there are 4 splits and we test each combination of thresholds against
# each split -- in this case there are only 2 threshold combinations.
assert tr.__len__() == 8
assert tr.__len__() == 2
assert tr.query("threshold_ratio == 1.01")["precision_test_mean"].iloc[0] >= 0.5
assert tr.query("threshold_ratio == 1.3")["alpha_threshold"].iloc[0] == 0.8
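
The expected row count drops from 8 to 2 because results are no longer recorded per split: with a single reserved thresholding split, the training-results table has one row per threshold combination. The arithmetic, using the two combinations implied by the assertions above (the exact pairing into a matrix is an assumption):

    # One row per threshold combination, no longer multiplied by splits.
    n_thresholding_splits = 1
    threshold_combinations = [(0.8, 1.01), (0.8, 1.3)]  # (alpha_threshold, threshold_ratio)
    assert n_thresholding_splits * len(threshold_combinations) == 2
    # Before this commit: 4 splits * 2 combinations == 8 rows.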

@@ -370,11 +368,11 @@ def test_step_2_train_gradient_boosted_trees_spark(

training_results = tr.query("model == 'gradient_boosted_trees'")

print(f"XX training_results: {training_results}")
# print(f"XX training_results: {training_results}")

# assert tr.shape == (1, 18)
assert (
tr.query("model == 'gradient_boosted_trees'")["precision_test_mean"].iloc[1] > 0
tr.query("model == 'gradient_boosted_trees'")["precision_test_mean"].iloc[0] > 0
)
assert tr.query("model == 'gradient_boosted_trees'")["maxDepth"].iloc[0] == 5
assert (

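The switch from iloc[1] to iloc[0] in the gradient-boosted-trees test follows from the same shrinkage: with one row per model and threshold combination, position 1 no longer exists in the filtered frame. A tiny illustration with invented values:

    # Invented one-row frame; .iloc[1] would raise IndexError here.
    import pandas as pd

    tr = pd.DataFrame([{"model": "gradient_boosted_trees", "precision_test_mean": 0.9}])
    sel = tr.query("model == 'gradient_boosted_trees'")["precision_test_mean"]
    assert sel.iloc[0] > 0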