diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 10fa963..709e2cc 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -411,34 +411,32 @@ def _run(self) -> None:
         # Stores suspicious data
         otd_data = self._create_otd_data(id_a, id_b)

-        n_training_iterations = config[training_conf].get("n_training_iterations", 10)
-        if n_training_iterations < 2:
+        outer_fold_count = config[training_conf].get("n_training_iterations", 10)
+        inner_fold_count = 3
+
+        if outer_fold_count < 2:
             raise RuntimeError("You must use at least two training iterations.")

         seed = config[training_conf].get("seed", 2133)

-        model_evaluation_splits = self._get_splits(
-            prepped_data, id_a, n_training_iterations, seed
+        outer_folds = self._get_outer_folds(
+            prepped_data, id_a, outer_fold_count, seed
         )
-        thresholding_split = model_evaluation_splits.pop()
-
-        # Explode params into all the combinations we want to test with the current model.
-        model_parameters = self._get_model_parameters(config)
-        logger.info(
-            f"There are {len(model_parameters)} sets of model parameters to explore; "
-            f"each of these has {n_training_iterations} train-test splits to test on"
-        )
+        for test_data_index, thresholding_test_data in enumerate(outer_folds):
+            # Explode params into all the combinations we want to test with the current model.
+            model_parameters = self._get_model_parameters(config)
+            combined_training_data = _combine(outer_folds, ignore=test_data_index)


-        hyperparam_evaluation_results = self._evaluate_hyperparam_combinations(
-            model_parameters,
-            model_evaluation_splits,
-            dep_var,
-            id_a,
-            id_b,
-            config,
-            training_conf,
-        )
+            hyperparam_evaluation_results = self._evaluate_hyperparam_combinations(
+                model_parameters,
+                combined_training_data,
+                dep_var,
+                id_a,
+                id_b,
+                config,
+                training_conf,
+            )

         # TODO: We may want to recreate a new split or set of splits rather than reuse existing splits.
         thresholded_metrics_df, suspicious_data = self._evaluate_threshold_combinations(
@@ -464,6 +462,30 @@ def _run(self) -> None:
         self._save_otd_data(suspicious_data, self.task.spark)
         self.task.spark.sql("set spark.sql.shuffle.partitions=200")

+    def _get_outer_folds(
+        self,
+        prepped_data: pyspark.sql.DataFrame,
+        id_a: str,
+        k_folds: int,
+        seed: int,
+    ) -> list[pyspark.sql.DataFrame]:
+        # Split the distinct id_a values into k disjoint, roughly equal folds.
+        weights = [1.0 / k_folds for _ in range(k_folds)]
+        fold_ids_list = (
+            prepped_data.select(id_a).distinct().randomSplit(weights, seed=seed)
+        )
+
+        outer_folds = [
+            prepped_data.join(fold_ids, on=id_a, how="inner")
+            for fold_ids in fold_ids_list
+        ]
+
+        # Report fold sizes so unbalanced folds are easy to spot.
+        for index, fold in enumerate(outer_folds, 1):
+            print(f"Outer fold {index}: {fold.count()} rows")
+
+        return outer_folds
+
     def _get_splits(
         self,
         prepped_data: pyspark.sql.DataFrame,
@@ -494,7 +516,6 @@ def _get_splits(
                 split_a = prepped_data.join(ids_a, on=id_a, how="inner")
                 split_b = prepped_data.join(ids_b, on=id_a, how="inner")
                 splits.append([split_a, split_b])
-
         else:
             print("Splitting randomly n times.")
             splits = [
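
Note on _combine: the outer loop in _run calls _combine(outer_folds, ignore=test_data_index), but that helper is not part of this diff. Below is a minimal sketch of what it could look like, assuming it simply unions every outer fold except the held-out test fold to form the training set for that iteration. The function name and the ignore keyword come from the call site above; the body is an assumption, not hlink's actual implementation.

from functools import reduce

from pyspark.sql import DataFrame


def _combine(folds: list[DataFrame], ignore: int) -> DataFrame:
    # Sketch only: assumes every fold shares the same schema, so
    # DataFrame.unionAll (which matches columns by position) is safe.
    kept = [fold for index, fold in enumerate(folds) if index != ignore]
    return reduce(DataFrame.unionAll, kept)

With a helper like this, each pass of the outer loop trains hyperparameter combinations on k - 1 folds and holds one fold out for thresholding, which is the standard outer loop of nested cross-validation.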