diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 8e391b8..a5e0273 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -60,6 +60,7 @@ def _run(self) -> None:
             .cache()
         )
 
+        # Stores suspicious data
         otd_data = self._create_otd_data(id_a, id_b)
 
         n_training_iterations = config[training_conf].get("n_training_iterations", 10)
@@ -101,6 +102,9 @@ def _run(self) -> None:
             for i in range(len(threshold_matrix)):
                 results_dfs[i] = _create_results_df()
 
+            # Collect auc values so we can pull out the highest
+            splits_results = []
+
             first = True
             for split_index, (training_data, test_data) in enumerate(splits, 1):
                 split_start_info = f"Training and testing the model on train-test split {split_index} of {n_training_iterations}"
@@ -140,6 +144,13 @@ def _run(self) -> None:
 
                 pr_auc = auc(recall, precision)
                 print(f"The area under the precision-recall curve is {pr_auc}")
+                splits_results.append(
+                    {
+                        "auc": pr_auc,
+                        "predictions_tmp": predictions_tmp,
+                        "predict_train_tmp": predict_train_tmp,
+                    }
+                )
 
                 if first:
                     prc = pd.DataFrame(
@@ -159,45 +170,58 @@ def _run(self) -> None:
 
                     first = False
 
-                i = 0
-                for threshold_index, (alpha_threshold, threshold_ratio) in enumerate(
-                    threshold_matrix, 1
-                ):
-                    logger.debug(
-                        f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
-                        f"{alpha_threshold=} and {threshold_ratio=}"
-                    )
-                    predictions = threshold_core.predict_using_thresholds(
-                        predictions_tmp,
-                        alpha_threshold,
-                        threshold_ratio,
-                        config[training_conf],
-                        config["id_column"],
-                    )
-                    predict_train = threshold_core.predict_using_thresholds(
-                        predict_train_tmp,
-                        alpha_threshold,
-                        threshold_ratio,
-                        config[training_conf],
-                        config["id_column"],
-                    )
-
-                    results_dfs[i] = self._capture_results(
-                        predictions,
-                        predict_train,
-                        dep_var,
-                        model,
-                        results_dfs[i],
-                        otd_data,
-                        alpha_threshold,
-                        threshold_ratio,
-                        pr_auc,
-                    )
-                    i += 1
-
                 training_data.unpersist()
                 test_data.unpersist()
 
+            # pluck out predictions_tmp, predict_train_tmp associated with highest pr_auc
+            best_pr_auc = 0.0
+            best_predictions_tmp = None
+            best_predict_train_tmp = None
+            for a in splits_results:
+                # Always accept the first split so an AUC of 0.0 still selects data
+                if best_predictions_tmp is None or a["auc"] > best_pr_auc:
+                    best_pr_auc = a["auc"]
+                    best_predictions_tmp = a["predictions_tmp"]
+                    best_predict_train_tmp = a["predict_train_tmp"]
+
+            # NOTE(review): `model` below is still the one fitted on the last split,
+            # not necessarily the split selected above -- confirm this is intended.
+            i = 0
+            for threshold_index, (alpha_threshold, threshold_ratio) in enumerate(
+                threshold_matrix, 1
+            ):
+                logger.debug(
+                    f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
+                    f"{alpha_threshold=} and {threshold_ratio=}"
+                )
+                predictions = threshold_core.predict_using_thresholds(
+                    best_predictions_tmp,
+                    alpha_threshold,
+                    threshold_ratio,
+                    config[training_conf],
+                    config["id_column"],
+                )
+                predict_train = threshold_core.predict_using_thresholds(
+                    best_predict_train_tmp,
+                    alpha_threshold,
+                    threshold_ratio,
+                    config[training_conf],
+                    config["id_column"],
+                )
+
+                results_dfs[i] = self._capture_results(
+                    predictions,
+                    predict_train,
+                    dep_var,
+                    model,
+                    results_dfs[i],
+                    otd_data,
+                    alpha_threshold,
+                    threshold_ratio,
+                    best_pr_auc,
+                )
+                i += 1
+
             for i in range(len(threshold_matrix)):
                 desc_df = _append_results(desc_df, results_dfs[i], model_type, params)
 