
Commit a94250c ("wip")
ccdavis committed Nov 27, 2024
1 parent 38c1006 commit a94250c
Showing 1 changed file with 43 additions and 22 deletions:
hlink/linking/model_exploration/link_step_train_test_models.py
@@ -411,34 +411,32 @@ def _run(self) -> None:
         # Stores suspicious data
         otd_data = self._create_otd_data(id_a, id_b)

-        n_training_iterations = config[training_conf].get("n_training_iterations", 10)
-        if n_training_iterations < 2:
+        outer_fold_count = config[training_conf].get("n_training_iterations", 10)
+        inner_fold_count = 3
+
+        if outer_fold_count < 2:
             raise RuntimeError("You must use at least two training iterations.")

         seed = config[training_conf].get("seed", 2133)

-        model_evaluation_splits = self._get_splits(
-            prepped_data, id_a, n_training_iterations, seed
+        outer_folds = self._get_outer_folds(
+            prepped_data, id_a, outer_fold_count, seed
         )
-        thresholding_split = model_evaluation_splits.pop()
-
-        # Explode params into all the combinations we want to test with the current model.
-        model_parameters = self._get_model_parameters(config)
-
-        logger.info(
-            f"There are {len(model_parameters)} sets of model parameters to explore; "
-            f"each of these has {n_training_iterations} train-test splits to test on"
-        )
+        for test_data_index, thresholding_test_data in enumerate(outer_folds):
+            # Explode params into all the combinations we want to test with the current model.
+            model_parameters = self._get_model_parameters(config)
+            combined_training_data = _combine(outer_folds, ignore=test_data_index)

-        hyperparam_evaluation_results = self._evaluate_hyperparam_combinations(
-            model_parameters,
-            model_evaluation_splits,
-            dep_var,
-            id_a,
-            id_b,
-            config,
-            training_conf,
-        )
+            hyperparam_evaluation_results = self._evaluate_hyperparam_combinations(
+                model_parameters,
+                combined_training_data,
+                dep_var,
+                id_a,
+                id_b,
+                config,
+                training_conf,
+            )

         # TODO: We may want to recreate a new split or set of splits rather than reuse existing splits.
         thresholded_metrics_df, suspicious_data = self._evaluate_threshold_combinations(
@@ -464,6 +462,30 @@ def _run(self) -> None:
         self._save_otd_data(suspicious_data, self.task.spark)
         self.task.spark.sql("set spark.sql.shuffle.partitions=200")

+    def _get_outer_folds(
+        self,
+        prepped_data: pyspark.sql.DataFrame,
+        id_a: str,
+        k_folds: int,
+        seed: int,
+    ) -> list[pyspark.sql.DataFrame]:
+        # Split the distinct id_a values into k roughly equal groups, then
+        # materialize one outer fold per group by joining each group of ids
+        # back to the full prepped data.
+        weights = [1.0 / k_folds for _ in range(k_folds)]
+        fold_ids_list = (
+            prepped_data.select(id_a).distinct().randomSplit(weights, seed=seed)
+        )
+        outer_folds = [
+            prepped_data.join(fold_ids, on=id_a, how="inner")
+            for fold_ids in fold_ids_list
+        ]
+        for index, fold in enumerate(outer_folds, 1):
+            print(f"Outer fold {index}: {fold.count()} rows")
+        return outer_folds
+
     def _get_splits(
         self,
         prepped_data: pyspark.sql.DataFrame,
@@ -494,7 +516,6 @@ def _get_splits(
                 split_a = prepped_data.join(ids_a, on=id_a, how="inner")
                 split_b = prepped_data.join(ids_b, on=id_a, how="inner")
                 splits.append([split_a, split_b])
-
         else:
             print("Splitting randomly n times.")
             splits = [
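
Note: the loop added to _run calls a module-level helper _combine(outer_folds, ignore=test_data_index) that is not shown in this diff. Below is a minimal sketch of what such a helper might look like, assuming it simply unions every outer fold except the held-out test fold; the name and signature are taken from the call site, and the body is an assumption rather than the committed implementation.

    import functools

    from pyspark.sql import DataFrame


    def _combine(folds: list[DataFrame], ignore: int) -> DataFrame:
        # Assumed behavior: union every outer fold except the one held out
        # as test data for this iteration.
        training_folds = [fold for index, fold in enumerate(folds) if index != ignore]
        return functools.reduce(DataFrame.union, training_folds)

Under that assumption, each pass through the loop holds one outer fold out for thresholding and pools the remaining folds for hyperparameter evaluation, which is the usual outer loop of k-fold cross-validation.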
