better output for tracking progress of train-test
Colin Davis committed Dec 2, 2024
1 parent 761e38f commit 3e0cb90
Showing 1 changed file with 15 additions and 1 deletion.
hlink/linking/model_exploration/link_step_train_test_models.py (15 additions, 1 deletion)
@@ -97,6 +97,9 @@ class ModelEval:
     threshold: float | list[float]
     threshold_ratio: float | list[float] | bool
 
+    def print(self):
+        return f"{self.model_type} {self.score} params: {self.hyperparams}"
+
     def make_threshold_matrix(self) -> list[list[float]]:
         return _calc_threshold_matrix(self.threshold, self.threshold_ratio)
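
Despite its name, the new `print` method returns a one-line summary string rather than writing to stdout; callers interpolate it into their own output. A minimal sketch of the behavior with hypothetical values: only the `threshold` and `threshold_ratio` annotations appear in the hunk above, so the types of `model_type`, `score`, and `hyperparams` are assumptions here.

    from dataclasses import dataclass

    @dataclass
    class ModelEval:
        model_type: str   # assumed type
        score: float      # assumed type
        hyperparams: dict # assumed type
        threshold: float | list[float]
        threshold_ratio: float | list[float] | bool

        def print(self):
            # Returns the summary string; the caller decides when to emit it.
            return f"{self.model_type} {self.score} params: {self.hyperparams}"

    m = ModelEval("random_forest", 0.83, {"maxDepth": 5}, 0.8, False)
    print(m.print())  # random_forest 0.83 params: {'maxDepth': 5}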

@@ -204,6 +207,7 @@ def _evaluate_hyperparam_combinations(
     config,
     training_conf,
 ) -> list[ModelEval]:
+    print("Begin evaluating all selected hyperparameters.")
     results = []
     for index, params_combo in enumerate(all_model_parameter_combos, 1):
         eval_start_info = f"Starting run {index} of {len(all_model_parameter_combos)} with these parameters: {params_combo}"
@@ -239,6 +243,7 @@ def _evaluate_hyperparam_combinations(
             threshold=threshold,
             threshold_ratio=threshold_ratio,
         )
+        print(f"{index}: {model_eval.print()}")
        results.append(model_eval)
    return results

@@ -457,6 +462,10 @@ def _run(self) -> None:
             training_conf,
         )
 
+        print(
+            f"Take the best hyper-parameter set from {len(hyperparam_evaluation_results)} results and test every threshold combination against it..."
+        )
+
         thresholded_metrics_df, suspicious_data = (
             self._evaluate_threshold_combinations(
                 hyperparam_evaluation_results,
@@ -491,12 +500,17 @@ def _split_into_folds(
     def _combine_folds(
         self, folds: list[pyspark.sql.DataFrame], ignore=None
     ) -> pyspark.sql.DataFrame:
+
         folds_to_combine = []
         for fold_number, fold in enumerate(folds, 0):
             if fold_number != ignore:
                 folds_to_combine.append(fold)
 
-        return reduce(DataFrame.unionAll, folds_to_combine)
+        combined = reduce(DataFrame.unionAll, folds_to_combine).cache()
+        print(
+            f"Combine non-test outer folds into {combined.count()} training data records."
+        )
+        return combined
 
     def _get_outer_folds(
         self, prepped_data: pyspark.sql.DataFrame, id_a: str, k_folds: int, seed: int
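
The rewritten `_combine_folds` pairs `cache()` with `count()`: `count()` is a Spark action that forces the union to materialize, and `cache()` keeps that materialized result so later passes over the training data do not recompute the union. A minimal sketch of the pattern, using hypothetical single-row folds rather than the project's real fold data:

    from functools import reduce

    from pyspark.sql import DataFrame, SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Three tiny stand-ins for the real outer folds.
    folds = [spark.createDataFrame([(i,)], ["id_a"]) for i in range(3)]

    ignore = 1  # index of the held-out test fold
    folds_to_combine = [f for n, f in enumerate(folds) if n != ignore]

    # cache() before count() so the union is computed once and reused.
    combined = reduce(DataFrame.unionAll, folds_to_combine).cache()
    print(f"Combine non-test outer folds into {combined.count()} training data records.")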
