Skip to content

Commit

Permalink
giving up for now
Browse files Browse the repository at this point in the history
  • Loading branch information
ccdavis committed Nov 16, 2024
1 parent c0397c5 commit 28c6cde
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def _run(self) -> None:
else:
threshold_ratio = False


# Collect auc values so we can pull out the highest
splits_results = []

Expand Down Expand Up @@ -142,7 +141,7 @@ def _run(self) -> None:

thresholds_plus_1 = np.append(thresholds_raw, [np.nan])
param_text = np.full(precision.shape, f"{model_type}_{params}")

if first:
prc = pd.DataFrame(
{
Expand Down Expand Up @@ -171,7 +170,7 @@ def _run(self) -> None:
"auc_mean": auc_mean,
"auc_standard_deviation": auc_std,
"model": model_type,
"params": params
"params": params,
}
print(f"PR AUC for splits on current model and params: {pr_auc_dict}")
pr_auc_info.append(pr_auc_info)
Expand All @@ -181,7 +180,6 @@ def _run(self) -> None:
[probability_metrics_df, this_model_results]
)


threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
results_dfs: dict[int, pd.DataFrame] = {}
Expand Down Expand Up @@ -219,9 +217,10 @@ def _run(self) -> None:
)

i = 0
for threshold_index, (this_alpha_threshold, this_threshold_ratio) in enumerate(
threshold_matrix, 1
):
for threshold_index, (
this_alpha_threshold,
this_threshold_ratio,
) in enumerate(threshold_matrix, 1):
logger.debug(
f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
f"{this_alpha_threshold=} and {this_threshold_ratio=}"
Expand Down Expand Up @@ -256,13 +255,13 @@ def _run(self) -> None:

for i in range(len(threshold_matrix)):
thresholded_metrics_df = _append_results(
thresholded_metrics_df, results_dfs[i], pr_auc_dict["model"], pr_auc_dict["params"]
thresholded_metrics_df, results_dfs[i], model_type, params
)

_print_thresholded_metrics_df(thresholded_metrics_df)
thresholded_metrics_df = _load_thresholded_metrics_df_params(
thresholded_metrics_df
)
_print_thresholded_metrics_df(thresholded_metrics_df)
self._save_training_results(thresholded_metrics_df, self.task.spark)
self._save_otd_data(otd_data, self.task.spark)
self.task.spark.sql("set spark.sql.shuffle.partitions=200")
Expand Down

0 comments on commit 28c6cde

Please sign in to comment.