From bc0bf7d6d254c60f580ba3c192ac93a96b449660 Mon Sep 17 00:00:00 2001 From: rileyh Date: Tue, 26 Nov 2024 15:23:47 -0600 Subject: [PATCH] [#167] Just pass the training section of the config to _get_model_parameters() --- .../model_exploration/link_step_train_test_models.py | 10 ++++------ hlink/tests/model_exploration_test.py | 6 +++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py index 9ef97ee..47c0a8d 100644 --- a/hlink/linking/model_exploration/link_step_train_test_models.py +++ b/hlink/linking/model_exploration/link_step_train_test_models.py @@ -67,7 +67,7 @@ def _run(self) -> None: splits = self._get_splits(prepped_data, id_a, n_training_iterations, seed) - model_parameters = _get_model_parameters(training_conf, config) + model_parameters = _get_model_parameters(config[training_conf]) logger.info( f"There are {len(model_parameters)} sets of model parameters to explore; " @@ -684,11 +684,9 @@ def _custom_param_grid_builder( return new_params -def _get_model_parameters( - training_conf: str, conf: dict[str, Any] -) -> list[dict[str, Any]]: - model_parameters = conf[training_conf]["model_parameters"] - if "param_grid" in conf[training_conf] and conf[training_conf]["param_grid"]: +def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any]]: + model_parameters = training_config["model_parameters"] + if "param_grid" in training_config and training_config["param_grid"]: model_parameters = _custom_param_grid_builder(model_parameters) elif model_parameters == []: raise ValueError( diff --git a/hlink/tests/model_exploration_test.py b/hlink/tests/model_exploration_test.py index e349500..facd03b 100644 --- a/hlink/tests/model_exploration_test.py +++ b/hlink/tests/model_exploration_test.py @@ -156,7 +156,7 @@ def test_get_model_parameters_no_param_grid_attribute(training_conf): ] assert "param_grid" not in training_conf["training"] - model_parameters = _get_model_parameters("training", training_conf) + model_parameters = _get_model_parameters(training_conf["training"]) assert model_parameters == [ {"type": "random_forest", "maxDepth": 3, "numTrees": 50}, @@ -174,7 +174,7 @@ def test_get_model_parameters_param_grid_false(training_conf): ] training_conf["training"]["param_grid"] = False - model_parameters = _get_model_parameters("training", training_conf) + model_parameters = _get_model_parameters(training_conf["training"]) assert model_parameters == [ {"type": "logistic_regression", "threshold": 0.3, "threshold_ratio": 1.4}, @@ -196,7 +196,7 @@ def test_get_model_parameters_param_grid_true(training_conf): ] training_conf["training"]["param_grid"] = True - model_parameters = _get_model_parameters("training", training_conf) + model_parameters = _get_model_parameters(training_conf["training"]) # 3 settings for maxDepth * 2 settings for numTrees = 6 total settings assert len(model_parameters) == 6