
Commit

refactoring
ccdavis committed Nov 19, 2024
1 parent 1f70f66 commit 8e5415f
Showing 2 changed files with 123 additions and 108 deletions.
222 changes: 119 additions & 103 deletions hlink/linking/model_exploration/link_step_train_test_models.py
@@ -43,6 +43,117 @@ def __init__(self, task) -> None:
],
)

# Takes a list of PR AUC (precision-recall area under the curve) values and the scoring strategy to use
def _score_train_test_results(
self, areas: list[float], score_strategy: str = "mean"
) -> float:
if score_strategy == "mean":
return statistics.mean(areas)
else:
raise RuntimeError(f"strategy {score_strategy} not implemented.")

def _train_model(
self, training_data, test_data, model_type, params, dep_var, id_a, id_b
) -> float:
classifier, post_transformer = classifier_core.choose_classifier(
model_type, params, dep_var
)

logger.debug("Training the model on the training data split")
start_train_time = perf_counter()
model = classifier.fit(training_data)
end_train_time = perf_counter()
logger.debug(
f"Successfully trained the model in {end_train_time - start_train_time:.2f}s"
)
predictions_tmp = _get_probability_and_select_pred_columns(
test_data, model, post_transformer, id_a, id_b, dep_var
)
predict_train_tmp = _get_probability_and_select_pred_columns(
training_data, model, post_transformer, id_a, id_b, dep_var
)

test_pred = predictions_tmp.toPandas()
precision, recall, thresholds_raw = precision_recall_curve(
test_pred[f"{dep_var}"],
test_pred["probability"].round(2),
pos_label=1,
)
pr_auc = auc(recall, precision)
print(f"The area under the precision-recall curve is {pr_auc}")
return pr_auc
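
For reference, a minimal standalone sketch of the PR AUC computation that _train_model performs on the test split, assuming a pandas frame with a binary label column (called match here) and a probability column; the toy data and column names are illustrative only:

import pandas as pd
from sklearn.metrics import auc, precision_recall_curve

# Toy predictions standing in for predictions_tmp.toPandas(); values are made up.
test_pred = pd.DataFrame(
    {
        "match": [1, 0, 1, 1, 0, 0],
        "probability": [0.91, 0.42, 0.77, 0.65, 0.30, 0.55],
    }
)
precision, recall, thresholds_raw = precision_recall_curve(
    test_pred["match"], test_pred["probability"].round(2), pos_label=1
)
pr_auc = auc(recall, precision)  # integrate precision over recall
print(f"The area under the precision-recall curve is {pr_auc}")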

# Returns a list of PR AUC values, one for each train-test split run through the model with the given params
def _collect_train_test_splits(
self, splits, model_type, params, dep_var, id_a, id_b
) -> list[float]:
# Collect auc values so we can pull out the highest
splits_results = []
for split_index, (training_data, test_data) in enumerate(splits, 1):
split_start_info = f"Training and testing the model on train-test split {split_index} of {len(splits)}"
print(split_start_info)
logger.debug(split_start_info)
prauc = self._train_model(
training_data, test_data, model_type, params, dep_var, id_a, id_b
)
splits_results.append(prauc)
return splits_results

# Returns a list of dicts like {"score": 0.5, "params": {...}, "threshold": 0.8, "threshold_ratio": 3.3}
# This connects a score to each hyper-parameter combination and to the thresholds listed with it in the config.
def _evaluate_hyperparam_combinations(
self, splits, model_parameters, dep_var, id_a, id_b, config, training_conf
) -> list[dict[str, Any]]:
results = []
for index, params_combo in enumerate(model_parameters, 1):
eval_start_info = f"Starting run {index} of {len(model_parameters)} with these parameters: {params_combo}"
print(eval_start_info)
logger.info(eval_start_info)
params = params_combo.copy()

# The model type and thresholds are mixed in with the hyper-parameters; we only need the model type at this stage,
# and the threshold info needs to be stripped out before fitting.
model_type = params.pop("type")
threshold, threshold_ratio = self._get_thresholds(
params, config, training_conf
)
params.pop("threshold", None)
params.pop("threshold_ratio", None)

pr_auc_values = self._collect_train_test_splits(
splits, model_type, params, dep_var, id_a, id_b
)
score = self._score_train_test_results(pr_auc_values, "mean")
results.append(
{
"score": score,
"params": params,
"threshold": threshold,
"threshold_ratio": threshold_ratio,
}
)

return results
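
To make the returned structure concrete, here is a small sketch of how a caller could pick the best hyper-parameter combination out of this list of {"score", "params", "threshold", "threshold_ratio"} dicts; the selection step and the toy values are assumptions about downstream use, not code in this commit:

# Assumed shape of the list returned by _evaluate_hyperparam_combinations.
param_evaluation_results = [
    {"score": 0.82, "params": {"maxDepth": 5}, "threshold": 0.8, "threshold_ratio": 1.3},
    {"score": 0.87, "params": {"maxDepth": 7}, "threshold": 0.8, "threshold_ratio": 1.3},
]

# Pick the combination with the highest mean PR AUC across splits.
best = max(param_evaluation_results, key=lambda result: result["score"])
print(best["params"], best["threshold"], best["threshold_ratio"])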

def _get_thresholds(
self, model_parameters, config, training_conf
) -> tuple[Any, Any]:
alpha_threshold = model_parameters.get(
"threshold", config[training_conf].get("threshold", 0.8)
)
if (
config[training_conf].get("decision", False)
== "drop_duplicate_with_threshold_ratio"
):
threshold_ratio = model_parameters.get(
"threshold_ratio",
threshold_core.get_threshold_ratio(config[training_conf], model_parameters),
)
else:
threshold_ratio = False

return alpha_threshold, threshold_ratio
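
A short sketch of the threshold resolution this helper performs, using a toy config dict; because the internals of threshold_core.get_threshold_ratio are not shown in this diff, the sketch assumes it falls back to a configured threshold_ratio value:

# Toy stand-ins for config[training_conf] and a model's hyper-parameter dict.
training_section = {
    "threshold": 0.8,
    "decision": "drop_duplicate_with_threshold_ratio",
    "threshold_ratio": 1.3,
}
model_parameters = {"threshold": 0.9}  # overrides the config-level threshold

alpha_threshold = model_parameters.get("threshold", training_section.get("threshold", 0.8))
if training_section.get("decision", False) == "drop_duplicate_with_threshold_ratio":
    threshold_ratio = model_parameters.get("threshold_ratio", training_section.get("threshold_ratio"))
else:
    threshold_ratio = False

print(alpha_threshold, threshold_ratio)  # 0.9 1.3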

def _run(self) -> None:
training_conf = str(self.task.training_conf)
table_prefix = self.task.table_prefix
@@ -69,116 +180,21 @@ def _run(self) -> None:

splits = self._get_splits(prepped_data, id_a, n_training_iterations, seed)

# Explode params into all the combinations we want to test with the current model.
model_parameters = self._get_model_parameters(config)

logger.info(
f"There are {len(model_parameters)} sets of model parameters to explore; "
f"each of these has {n_training_iterations} train-test splits to test on"
)
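
For intuition, a sketch of what "exploding" a parameter grid into combinations can look like; hlink's _get_model_parameters reads the real grid from the config, so the grid contents and the itertools approach below are assumptions for illustration only:

from itertools import product

# Toy grid: every combination of these values becomes one params dict to evaluate.
grid = {"maxDepth": [5, 7], "numTrees": [50, 100]}
model_parameters = [
    dict(zip(grid.keys(), values), type="random_forest", threshold=0.8)
    for values in product(*grid.values())
]
print(len(model_parameters))  # 4 combinations, each tagged with a model type and threshold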

probability_metrics_df = _create_probability_metrics_df()
pr_auc_info = []
for run_index, run in enumerate(model_parameters, 1):
run_start_info = f"Starting run {run_index} of {len(model_parameters)} with these parameters: {run}"
print(run_start_info)
logger.info(run_start_info)
params = run.copy()
model_type = params.pop("type")

alpha_threshold = params.pop(
"threshold", config[training_conf].get("threshold", 0.8)
)
if (
config[training_conf].get("decision", False)
== "drop_duplicate_with_threshold_ratio"
):
threshold_ratio = params.pop(
"threshold_ratio",
threshold_core.get_threshold_ratio(config[training_conf], params),
)
else:
threshold_ratio = False

# Collect auc values so we can pull out the highest
splits_results = []

first = True
for split_index, (training_data, test_data) in enumerate(splits, 1):
split_start_info = f"Training and testing the model on train-test split {split_index} of {n_training_iterations}"
print(split_start_info)
logger.debug(split_start_info)
training_data.cache()
test_data.cache()

classifier, post_transformer = classifier_core.choose_classifier(
model_type, params, dep_var
)
param_evaluation_results = self._evaluate_hyperparam_combinations(
splits, model_parameters, dep_var, id_a, id_b, config, training_conf
)

logger.debug("Training the model on the training data split")
start_train_time = perf_counter()
model = classifier.fit(training_data)
end_train_time = perf_counter()
logger.debug(
f"Successfully trained the model in {end_train_time - start_train_time:.2f}s"
)

predictions_tmp = _get_probability_and_select_pred_columns(
test_data, model, post_transformer, id_a, id_b, dep_var
).cache()
predict_train_tmp = _get_probability_and_select_pred_columns(
training_data, model, post_transformer, id_a, id_b, dep_var
).cache()

test_pred = predictions_tmp.toPandas()
precision, recall, thresholds_raw = precision_recall_curve(
test_pred[f"{dep_var}"],
test_pred["probability"].round(2),
pos_label=1,
)
pr_auc = auc(recall, precision)
print(f"The area under the precision-recall curve is {pr_auc}")
splits_results.append(pr_auc)

thresholds_plus_1 = np.append(thresholds_raw, [np.nan])
param_text = np.full(precision.shape, f"{model_type}_{params}")

if first:
prc = pd.DataFrame(
{
"params": param_text,
"precision": precision,
"recall": recall,
"threshold_gt_eq": thresholds_plus_1,
}
)
self.task.spark.createDataFrame(prc).write.mode(
"overwrite"
).saveAsTable(
f"{self.task.table_prefix}precision_recall_curve_"
+ re.sub("[^A-Za-z0-9]", "_", f"{model_type}{params}")
)

first = False

training_data.unpersist()
test_data.unpersist()

# Aggregate pr auc mean and standard deviation across splits
auc_mean = statistics.mean(splits_results)
auc_std = statistics.stdev(splits_results)
pr_auc_dict = {
"auc_mean": auc_mean,
"auc_standard_deviation": auc_std,
"model": model_type,
"params": params,
}
print(f"PR AUC for splits on current model and params: {pr_auc_dict}")
pr_auc_info.append(pr_auc_dict)
this_model_results = pd.DataFrame(pr_auc_dict)
# I'm not sure what this dataframe is for
probability_metrics_df = pd.concat(
[probability_metrics_df, this_model_results]
)
for evaluation in param_evaluation_results:
alpha_threshold = evaluation["threshold"]
threshold_ratio = evaluation["threshold_ratio"]

threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
@@ -261,7 +277,7 @@ def _run(self) -> None:
thresholded_metrics_df = _load_thresholded_metrics_df_params(
thresholded_metrics_df
)

_print_thresholded_metrics_df(thresholded_metrics_df)
self._save_training_results(thresholded_metrics_df, self.task.spark)
self._save_otd_data(otd_data, self.task.spark)
9 changes: 4 additions & 5 deletions hlink/tests/model_exploration_test.py
@@ -103,12 +103,11 @@ def test_all(
preds.query("id_a == 20 and id_b == 30")["probability"].round(2).iloc[0] > 0.5
)


assert (
preds.query("id_a == 20 and id_b == 30")["second_best_prob"].round(2).iloc[0]
>= 0.6
)

assert preds.query("id_a == 30 and id_b == 30")["prediction"].iloc[0] == 0
assert pd.isnull(
preds.query("id_a == 10 and id_b == 30")["second_best_prob"].iloc[0]
@@ -368,9 +367,9 @@ def test_step_2_train_gradient_boosted_trees_spark(
preds = spark.table("model_eval_predictions").toPandas()

assert "probability_array" in list(preds.columns)
#import pdb
#pdb.set_trace()

# import pdb
# pdb.set_trace()

training_results = tr.query("model == 'gradient_boosted_trees'")
print(f"XX training_results: {training_results}")
