Use all splits on thresholding
Colin Davis committed Nov 15, 2024
1 parent c9576e8 commit 319129f
Showing 1 changed file with 142 additions and 77 deletions.
219 changes: 142 additions & 77 deletions hlink/linking/model_exploration/link_step_train_test_models.py
@@ -23,6 +23,68 @@

from hlink.linking.link_step import LinkStep

# This is a refactor to make the train-test model evaluation process faster.
"""
Current algorithm:

1. Prepare the train-test data.
2. Split the data into n pairs of training and test data. In our tests, n == 10.
3. For every model type and each combination of hyper-parameters:
       for train, test in n splits:
           train the model with the training data
           test the trained model using the test data
           capture the probability of correct predictions on each split
       Score the model based on some function of the collected probabilities (like the mean).
       Store the score with the model type and hyper-parameters that produced it.
4. Select the best-performing model type + hyper-parameter set based on the associated score.
5. With the best-scoring model type and hyper-parameters:
       for each of the n training and test data splits:
           for each threshold setting combination:
               train the model with the associated hyper-parameters
               predict the matches on the test data using the trained model
               evaluate the predictions and capture the threshold combination that made them
6. Print the results of the threshold evaluations.
p = number of hyper-parameter combinations
s = number of splits
t = threshold matrix size (x * y)

complexity = s * p + s * t -> O(n^2)

Since the thresholds are now tested on every split, the threshold term is s * t
rather than a single t. It's hard to generalize the number of passes over the
data: t may be quite large or quite small, s will probably be around 10, and p
can also vary a lot, from 2 or 3 up to 100.
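For example, with s = 10, p = 50, and t = 16, this algorithm makes about
10 * 50 + 10 * 16 = 660 passes over the data, while the original algorithm
below makes about 10 * 50 * 16 = 8,000.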

Original algorithm:

1. Prepare the train-test data.
2. Split the data into n pairs of training and test data. In our tests, n == 10.
3. For every model type and each combination of hyper-parameters:
       for train, test in n splits:
           train the model with the training data
           test the trained model using the test data
           capture the probability of correct predictions on each split
4. With the best-scoring model type and hyper-parameters:
       for each threshold setting combination:
           train the model with the associated hyper-parameters
           predict the matches on the test data using the trained model
           evaluate the predictions and capture the threshold combination and
           hyper-parameters that made them
5. Print the results of the threshold evaluations.

complexity = p * s * t -> O(n^3)
"""



logger = logging.getLogger(__name__)


@@ -232,90 +294,93 @@ def _evaluate_threshold_combinations(

    # TODO check if we should make a different split, like starting from a different seed?
    # or just not re-using one we used in making the PR_AUC mean value?
    # splits_for_thresholding_eval = splits[0]
    # thresholding_training_data = splits_for_thresholding_eval[0].cache()
    # thresholding_test_data = splits_for_thresholding_eval[1].cache()
    for split_index, (
        thresholding_training_data,
        thresholding_test_data,
    ) in enumerate(splits, 1):
        cached_training_data = thresholding_training_data.cache()
        cached_test_data = thresholding_test_data.cache()
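        # Both halves of the split are cached because the model fit and every
        # threshold evaluation below reuse them; caching keeps Spark from
        # recomputing the underlying plan on each pass.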

        threshold_matrix = best_results.make_threshold_matrix()

        logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
        results_dfs: dict[int, pd.DataFrame] = {}
        for i in range(len(threshold_matrix)):
            results_dfs[i] = _create_results_df()

        thresholding_classifier, thresholding_post_transformer = (
            classifier_core.choose_classifier(
                best_results.model_type, best_results.hyperparams, dep_var
            )
        )
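        # The best model type is trained once per split with the winning
        # hyper-parameters; the per-threshold loop below reuses this single
        # fitted model rather than refitting for each threshold entry.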
        thresholding_model = thresholding_classifier.fit(cached_training_data)

        thresholding_predictions = _get_probability_and_select_pred_columns(
            cached_test_data,
            thresholding_model,
            thresholding_post_transformer,
            id_a,
            id_b,
            dep_var,
        )
        thresholding_predict_train = _get_probability_and_select_pred_columns(
            cached_training_data,
            thresholding_model,
            thresholding_post_transformer,
            id_a,
            id_b,
            dep_var,
        )
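        # Probabilities are computed once per split, for both the test and the
        # training data; each threshold entry below only re-thresholds these
        # precomputed probabilities, which is what makes the matrix sweep cheap.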

        i = 0
        for threshold_index, (
            this_alpha_threshold,
            this_threshold_ratio,
        ) in enumerate(threshold_matrix, 1):
            diag = (
                f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
                f"{this_alpha_threshold=} and {this_threshold_ratio=}"
            )
            logger.debug(diag)
            print(diag)
            predictions = threshold_core.predict_using_thresholds(
                thresholding_predictions,
                this_alpha_threshold,
                this_threshold_ratio,
                config[training_conf],
                config["id_column"],
            )
            predict_train = threshold_core.predict_using_thresholds(
                thresholding_predict_train,
                this_alpha_threshold,
                this_threshold_ratio,
                config[training_conf],
                config["id_column"],
            )
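            # _capture_results records the evaluation results for this
            # (alpha threshold, threshold ratio) pair on the test and training
            # predictions, keyed by the pair's position in the threshold matrix.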

            results_dfs[i] = self._capture_results(
                predictions,
                predict_train,
                dep_var,
                thresholding_model,
                results_dfs[i],
                suspicious_data,
                this_alpha_threshold,
                this_threshold_ratio,
                best_results.score,
            )
            i += 1

        cached_test_data.unpersist()
        cached_training_data.unpersist()

        for i in range(len(threshold_matrix)):
            thresholded_metrics_df = _append_results(
                thresholded_metrics_df,
                results_dfs[i],
                best_results.model_type,
                best_results.hyperparams,
            )

    return thresholded_metrics_df, suspicious_data

