From 319129fb3a1fcbbb6952d814ae67be8ae94fa4d3 Mon Sep 17 00:00:00 2001 From: Colin Davis Date: Fri, 15 Nov 2024 17:09:13 -0600 Subject: [PATCH] Use all splits on thresholding --- .../link_step_train_test_models.py | 219 ++++++++++++------ 1 file changed, 142 insertions(+), 77 deletions(-) diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py index 4e61479..da6507a 100644 --- a/hlink/linking/model_exploration/link_step_train_test_models.py +++ b/hlink/linking/model_exploration/link_step_train_test_models.py @@ -23,6 +23,68 @@ from hlink.linking.link_step import LinkStep +# This is a refactor to make the train-test model process faster. +""" + +Current algorithm: + +1. Prepare test-train data +2. split data into n pairs of training and test data. In our tests n == 10. +3. for every model type, for each combination of hyper-parameters + for train, test in n splits: + train the model with the training data + test the trained model using the test data + capture the probability of correct predictions on each split + Score the model based on some function of the collected probabilities (like 'mean') + Store the score with the model type and hyper-parameters that produced the score + +4. Select the best performing model type + hyper-parameter set based on the associated score. +5. With the best scoring parameters and model: + Obtain a single training data and test data split + for each threshold setting combination: + Train the model type with the associated hyper-parameters + Predict the matches on the test data using the trained model + Evaluate the predictions and capture the threshold combination that made it. +6. Print the results of the threshold evaluations + +p = hyper-parameter combinations +s = number of splits +t = threshold matrix size (x * y) + +complexity = s * p + t -> O(n^2) + +We may end up needing to test the thresholds on multiple splits: + + s * p + s * t + +It's hard to generalize the number of passes on the data since 't' may be pretty large or not at all. 's' will probably be 10 or so and 'p' also can vary a lot from 2 or 3 to 100. + + +Original Algorithm: + + +1. Prepare test-train data +2. split data into n pairs of training and test data. In our tests n == 10. +3. for every model type, for each combination of hyper-parameters + for train, test in n splits: + train the model with the training data + test the trained model using the test data + capture the probability of correct predictions on each split + + 4. With the best scoring parameters and model: + for each threshold setting combination: + Train the model type with the associated hyper-parameters + Predict the matches on the test data using the trained model + Evaluate the predictions and capture the threshold combination and hyper-parameters that made it. +6. Print the results of the threshold evaluations + +complexity = p * s * t -> O(n^3) + + +""" + + + logger = logging.getLogger(__name__) @@ -232,90 +294,93 @@ def _evaluate_threshold_combinations( # TODO check if we should make a different split, like starting from a different seed? # or just not re-using one we used in making the PR_AUC mean value? - splits_for_thresholding_eval = splits[0] - thresholding_training_data = splits_for_thresholding_eval[0].cache() - thresholding_test_data = splits_for_thresholding_eval[1].cache() - - threshold_matrix = best_results.make_threshold_matrix() - - logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries") - results_dfs: dict[int, pd.DataFrame] = {} - for i in range(len(threshold_matrix)): - results_dfs[i] = _create_results_df() - - thresholding_classifier, thresholding_post_transformer = ( - classifier_core.choose_classifier( - best_results.model_type, best_results.hyperparams, dep_var + #splits_for_thresholding_eval = splits[0] + #thresholding_training_data = splits_for_thresholding_eval[0].cache() + #thresholding_test_data = splits_for_thresholding_eval[1].cache() + for split_index, (thresholding_training_data, thresholding_test_data) in enumerate(splits, 1): + cached_training_data = thresholding_training_data.cache() + cached_test_data = thresholding_test_data.cache() + + threshold_matrix = best_results.make_threshold_matrix() + + logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries") + results_dfs: dict[int, pd.DataFrame] = {} + for i in range(len(threshold_matrix)): + results_dfs[i] = _create_results_df() + + thresholding_classifier, thresholding_post_transformer = ( + classifier_core.choose_classifier( + best_results.model_type, best_results.hyperparams, dep_var + ) ) - ) - thresholding_model = thresholding_classifier.fit(thresholding_training_data) - - thresholding_predictions = _get_probability_and_select_pred_columns( - thresholding_test_data, - thresholding_model, - thresholding_post_transformer, - id_a, - id_b, - dep_var, - ) - thresholding_predict_train = _get_probability_and_select_pred_columns( - thresholding_training_data, - thresholding_model, - thresholding_post_transformer, - id_a, - id_b, - dep_var, - ) + thresholding_model = thresholding_classifier.fit(cached_training_data) - i = 0 - for threshold_index, ( - this_alpha_threshold, - this_threshold_ratio, - ) in enumerate(threshold_matrix, 1): - - diag = ( - f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: " - f"{this_alpha_threshold=} and {this_threshold_ratio=}" - ) - logger.debug(diag) - print(diag) - predictions = threshold_core.predict_using_thresholds( - thresholding_predictions, - this_alpha_threshold, - this_threshold_ratio, - config[training_conf], - config["id_column"], + thresholding_predictions = _get_probability_and_select_pred_columns( + cached_test_data, + thresholding_model, + thresholding_post_transformer, + id_a, + id_b, + dep_var, ) - predict_train = threshold_core.predict_using_thresholds( - thresholding_predict_train, - this_alpha_threshold, - this_threshold_ratio, - config[training_conf], - config["id_column"], + thresholding_predict_train = _get_probability_and_select_pred_columns( + cached_training_data, + thresholding_model, + thresholding_post_transformer, + id_a, + id_b, + dep_var, ) - results_dfs[i] = self._capture_results( - predictions, - predict_train, - dep_var, - thresholding_model, - results_dfs[i], - suspicious_data, + i = 0 + for threshold_index, ( this_alpha_threshold, this_threshold_ratio, - best_results.score, - ) - i += 1 - thresholding_test_data.unpersist() - thresholding_training_data.unpersist() - - for i in range(len(threshold_matrix)): - thresholded_metrics_df = _append_results( - thresholded_metrics_df, - results_dfs[i], - best_results.model_type, - best_results.hyperparams, - ) + ) in enumerate(threshold_matrix, 1): + + diag = ( + f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: " + f"{this_alpha_threshold=} and {this_threshold_ratio=}" + ) + logger.debug(diag) + print(diag) + predictions = threshold_core.predict_using_thresholds( + thresholding_predictions, + this_alpha_threshold, + this_threshold_ratio, + config[training_conf], + config["id_column"], + ) + predict_train = threshold_core.predict_using_thresholds( + thresholding_predict_train, + this_alpha_threshold, + this_threshold_ratio, + config[training_conf], + config["id_column"], + ) + + results_dfs[i] = self._capture_results( + predictions, + predict_train, + dep_var, + thresholding_model, + results_dfs[i], + suspicious_data, + this_alpha_threshold, + this_threshold_ratio, + best_results.score, + ) + i += 1 + thresholding_test_data.unpersist() + thresholding_training_data.unpersist() + + for i in range(len(threshold_matrix)): + thresholded_metrics_df = _append_results( + thresholded_metrics_df, + results_dfs[i], + best_results.model_type, + best_results.hyperparams, + ) return thresholded_metrics_df, suspicious_data