From c0397c598a01f4d5e111474493a66aee9b80a720 Mon Sep 17 00:00:00 2001
From: Colin Davis
Date: Fri, 15 Nov 2024 18:25:11 -0600
Subject: [PATCH] Renaming for clarity

---
 .../link_step_train_test_models.py           | 44 ++++++++++---------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 3d98abe..7896142 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -77,6 +77,7 @@ def _run(self) -> None:
         )
 
         probability_metrics_df = _create_probability_metrics_df()
+        pr_auc_info = []
         for run_index, run in enumerate(model_parameters, 1):
             run_start_info = f"Starting run {run_index} of {len(model_parameters)} with these parameters: {run}"
             print(run_start_info)
@@ -98,13 +99,7 @@ def _run(self) -> None:
             else:
                 threshold_ratio = False
 
-            threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
-            logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
-
-            results_dfs: dict[int, pd.DataFrame] = {}
-            for i in range(len(threshold_matrix)):
-                results_dfs[i] = _create_results_df()
-
+            # Collect auc values so we can pull out the highest
             splits_results = []
 
@@ -141,14 +136,13 @@ def _run(self) -> None:
                     test_pred["probability"].round(2),
                     pos_label=1,
                 )
-
-                thresholds_plus_1 = np.append(thresholds_raw, [np.nan])
-                param_text = np.full(precision.shape, f"{model_type}_{params}")
-
                 pr_auc = auc(recall, precision)
                 print(f"The area under the precision-recall curve is {pr_auc}")
                 splits_results.append(pr_auc)
 
+                thresholds_plus_1 = np.append(thresholds_raw, [np.nan])
+                param_text = np.full(precision.shape, f"{model_type}_{params}")
+
                 if first:
                     prc = pd.DataFrame(
                         {
@@ -177,15 +171,23 @@ def _run(self) -> None:
                 "auc_mean": auc_mean,
                 "auc_standard_deviation": auc_std,
                 "model": model_type,
-                "params": params,
+                "params": params
             }
             print(f"PR AUC for splits on current model and params: {pr_auc_dict}")
+            pr_auc_info.append(pr_auc_dict)
             this_model_results = pd.DataFrame(pr_auc_dict)
 
             # I'm not sure what this dataframe is for
             probability_metrics_df = pd.concat(
                 [probability_metrics_df, this_model_results]
             )
+
+            threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
+            logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
+            results_dfs: dict[int, pd.DataFrame] = {}
+            for i in range(len(threshold_matrix)):
+                results_dfs[i] = _create_results_df()
+
             # TODO check if we should make a different split, like starting from a different seed?
             # or just not re-using one we used in making the PR_AUC mean value?
             splits_for_thresholding_eval = splits[0]
@@ -217,24 +219,24 @@ def _run(self) -> None:
             )
             i = 0
-            for threshold_index, (alpha_threshold, threshold_ratio) in enumerate(
+            for threshold_index, (this_alpha_threshold, this_threshold_ratio) in enumerate(
                 threshold_matrix, 1
             ):
                 logger.debug(
                     f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
-                    f"{alpha_threshold=} and {threshold_ratio=}"
+                    f"{this_alpha_threshold=} and {this_threshold_ratio=}"
                 )
                 predictions = threshold_core.predict_using_thresholds(
                     thresholding_predictions,
-                    alpha_threshold,
-                    threshold_ratio,
+                    this_alpha_threshold,
+                    this_threshold_ratio,
                     config[training_conf],
                     config["id_column"],
                 )
                 predict_train = threshold_core.predict_using_thresholds(
                     thresholding_predict_train,
-                    alpha_threshold,
-                    threshold_ratio,
+                    this_alpha_threshold,
+                    this_threshold_ratio,
                     config[training_conf],
                     config["id_column"],
                 )
@@ -246,15 +248,15 @@ def _run(self) -> None:
                     thresholding_model,
                     results_dfs[i],
                     otd_data,
-                    alpha_threshold,
-                    threshold_ratio,
+                    this_alpha_threshold,
+                    this_threshold_ratio,
                     pr_auc_dict["auc_mean"],
                 )
                 i += 1
 
             for i in range(len(threshold_matrix)):
                 thresholded_metrics_df = _append_results(
-                    thresholded_metrics_df, results_dfs[i], model_type, params
+                    thresholded_metrics_df, results_dfs[i], pr_auc_dict["model"], pr_auc_dict["params"]
                 )
 
         _print_thresholded_metrics_df(thresholded_metrics_df)
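
Note on the reordering above: each run now computes one PR AUC per split
first, summarizes the splits in pr_auc_dict, and only then builds the
threshold matrix, so the best-scoring run can be pulled out after all runs
finish. A minimal standalone sketch of that pattern follows; the
y_true/y_scores arrays, the example model name and params, and the final
max() selection step are illustrative assumptions, not code from hlink.

    import numpy as np
    from sklearn.metrics import auc, precision_recall_curve

    # Stand-in labels and predicted probabilities for a single split.
    y_true = np.array([0, 0, 1, 1, 1, 0, 1])
    y_scores = np.array([0.1, 0.4, 0.35, 0.8, 0.9, 0.2, 0.75])

    # Same computation the patch performs per split: PR curve, then the
    # area under it (recall on the x axis, precision on the y axis).
    precision, recall, _thresholds = precision_recall_curve(
        y_true, y_scores.round(2), pos_label=1
    )
    pr_auc = auc(recall, precision)

    # One summary dict per model/params run, mirroring pr_auc_dict.
    splits_results = [pr_auc]  # hlink appends one value per split
    pr_auc_dict = {
        "auc_mean": np.mean(splits_results),
        "auc_standard_deviation": np.std(splits_results),
        "model": "probit",          # illustrative model name
        "params": {"maxIter": 10},  # illustrative hyperparameters
    }

    # pr_auc_info collects these dicts so the run with the highest mean
    # AUC can be selected once every model/params combination has run.
    pr_auc_info = [pr_auc_dict]
    best_run = max(pr_auc_info, key=lambda d: d["auc_mean"])
    print(best_run)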