Skip to content

Commit

Permalink
Messing around with refactoring model exploration
Browse files Browse the repository at this point in the history
  • Loading branch information
ccdavis committed Nov 14, 2024
1 parent d89a6da commit 5507b4b
Showing 1 changed file with 56 additions and 36 deletions.
92 changes: 56 additions & 36 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _run(self) -> None:
.cache()
)

# Stores suspicious data
otd_data = self._create_otd_data(id_a, id_b)

n_training_iterations = config[training_conf].get("n_training_iterations", 10)
Expand Down Expand Up @@ -101,6 +102,9 @@ def _run(self) -> None:
for i in range(len(threshold_matrix)):
results_dfs[i] = _create_results_df()

# Collect auc values so we can pull out the highest
splits_results = []

first = True
for split_index, (training_data, test_data) in enumerate(splits, 1):
split_start_info = f"Training and testing the model on train-test split {split_index} of {n_training_iterations}"
Expand Down Expand Up @@ -140,6 +144,13 @@ def _run(self) -> None:

pr_auc = auc(recall, precision)
print(f"The area under the precision-recall curve is {pr_auc}")
splits_results.append(
{
"auc": pr_auc,
"predictions_tmp": predictions_tmp,
"predict_train_tmp": predict_train_tmp,
}
)

if first:
prc = pd.DataFrame(
Expand All @@ -159,45 +170,54 @@ def _run(self) -> None:

first = False

i = 0
for threshold_index, (alpha_threshold, threshold_ratio) in enumerate(
threshold_matrix, 1
):
logger.debug(
f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
f"{alpha_threshold=} and {threshold_ratio=}"
)
predictions = threshold_core.predict_using_thresholds(
predictions_tmp,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
)
predict_train = threshold_core.predict_using_thresholds(
predict_train_tmp,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
)

results_dfs[i] = self._capture_results(
predictions,
predict_train,
dep_var,
model,
results_dfs[i],
otd_data,
alpha_threshold,
threshold_ratio,
pr_auc,
)
i += 1

training_data.unpersist()
test_data.unpersist()

# Pluck out the predictions_tmp / predict_train_tmp pair that belongs to the
# split with the highest precision-recall AUC.
best_pr_auc = 0.0
best_predictions_tmp = None
best_predict_train_tmp = None
for split_result in splits_results:
    # Take the first split unconditionally so that a result is still selected
    # when every AUC is 0.0; after that, only a strictly higher AUC replaces
    # the current best.
    if best_predictions_tmp is None or split_result["auc"] > best_pr_auc:
        # Update the running maximum; without this every split with a
        # positive AUC would overwrite the previous choice and the reported
        # AUC would stay at 0.0.
        best_pr_auc = split_result["auc"]
        best_predictions_tmp = split_result["predictions_tmp"]
        best_predict_train_tmp = split_result["predict_train_tmp"]

i = 0
for threshold_index, (alpha_threshold, threshold_ratio) in enumerate(
threshold_matrix, 1
):
logger.debug(
f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
f"{alpha_threshold=} and {threshold_ratio=}"
)
predictions = threshold_core.predict_using_thresholds(
best_predictions_tmp,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
)
predict_train = threshold_core.predict_using_thresholds(
best_predict_train_tmp,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
)

results_dfs[i] = self._capture_results(
predictions,
predict_train,
dep_var,
model,
results_dfs[i],
otd_data,
alpha_threshold,
threshold_ratio,
best_pr_auc,
)
i += 1

for i in range(len(threshold_matrix)):
desc_df = _append_results(desc_df, results_dfs[i], model_type, params)

Expand Down

0 comments on commit 5507b4b

Please sign in to comment.