diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 7896142..b6fdf28 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -99,7 +99,6 @@ def _run(self) -> None:
         else:
             threshold_ratio = False
 
-        # Collect auc values so we can pull out the highest
         splits_results = []
 
 
@@ -142,7 +141,7 @@ def _run(self) -> None:
 
                 thresholds_plus_1 = np.append(thresholds_raw, [np.nan])
                 param_text = np.full(precision.shape, f"{model_type}_{params}")
-
+
                 if first:
                     prc = pd.DataFrame(
                         {
@@ -171,7 +170,7 @@ def _run(self) -> None:
                 "auc_mean": auc_mean,
                 "auc_standard_deviation": auc_std,
                 "model": model_type,
-                "params": params
+                "params": params,
             }
             print(f"PR AUC for splits on current model and params: {pr_auc_dict}")
-            pr_auc_info.append(pr_auc_info)
+            pr_auc_info.append(pr_auc_dict)
@@ -181,7 +180,6 @@ def _run(self) -> None:
                 [probability_metrics_df, this_model_results]
             )
 
-
         threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
         logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
         results_dfs: dict[int, pd.DataFrame] = {}
@@ -219,9 +217,10 @@ def _run(self) -> None:
             )
             i = 0
-            for threshold_index, (this_alpha_threshold, this_threshold_ratio) in enumerate(
-                threshold_matrix, 1
-            ):
+            for threshold_index, (
+                this_alpha_threshold,
+                this_threshold_ratio,
+            ) in enumerate(threshold_matrix, 1):
                 logger.debug(
                     f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
                     f"{this_alpha_threshold=} and {this_threshold_ratio=}"
                 )
@@ -256,13 +255,13 @@ def _run(self) -> None:
 
         for i in range(len(threshold_matrix)):
             thresholded_metrics_df = _append_results(
-                thresholded_metrics_df, results_dfs[i], pr_auc_dict["model"], pr_auc_dict["params"]
+                thresholded_metrics_df, results_dfs[i], model_type, params
             )
 
-        _print_thresholded_metrics_df(thresholded_metrics_df)
         thresholded_metrics_df = _load_thresholded_metrics_df_params(
             thresholded_metrics_df
         )
+        _print_thresholded_metrics_df(thresholded_metrics_df)
         self._save_training_results(thresholded_metrics_df, self.task.spark)
         self._save_otd_data(otd_data, self.task.spark)
         self.task.spark.sql("set spark.sql.shuffle.partitions=200")
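
For reviewers who want to see the accumulation pattern the last two hunks converge on in isolation, below is a minimal, self-contained sketch, not hlink's actual implementation. The simplified _append_results helper and the sample metric values are hypothetical stand-ins; in the real module, _append_results, results_dfs, and pr_auc_dict are built from Spark cross-validation output.

# Minimal sketch (hypothetical helper and values, not hlink's real code) of
# the per-model accumulation pattern: one PR AUC summary dict per model goes
# into pr_auc_info, and each threshold-matrix entry contributes one row of
# thresholded metrics tagged directly with model_type and params.

import pandas as pd


def _append_results(metrics_df, results_df, model_type, params):
    # Tag this threshold's results with the model identifiers, then
    # concatenate onto the running metrics frame.
    tagged = results_df.assign(model=model_type, params=str(params))
    return pd.concat([metrics_df, tagged], ignore_index=True)


pr_auc_info = []
thresholded_metrics_df = pd.DataFrame()

for model_type, params in [("logistic_regression", {"threshold": 0.5})]:
    pr_auc_dict = {
        "auc_mean": 0.9,  # hypothetical values for illustration
        "auc_standard_deviation": 0.01,
        "model": model_type,
        "params": params,
    }
    print(f"PR AUC for splits on current model and params: {pr_auc_dict}")
    pr_auc_info.append(pr_auc_dict)  # append the dict, not the list itself

    # One results frame per threshold-matrix entry, keyed by index.
    results_dfs = {0: pd.DataFrame({"precision": [0.8], "recall": [0.7]})}
    for i in range(len(results_dfs)):
        thresholded_metrics_df = _append_results(
            thresholded_metrics_df, results_dfs[i], model_type, params
        )

print(thresholded_metrics_df)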