Skip to content

Commit

Permalink
[#167] Just pass the training section of the config to _get_model_parameters()
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Nov 26, 2024
1 parent 7d48380 commit bc0bf7d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
10 changes: 4 additions & 6 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _run(self) -> None:

splits = self._get_splits(prepped_data, id_a, n_training_iterations, seed)

model_parameters = _get_model_parameters(training_conf, config)
model_parameters = _get_model_parameters(config[training_conf])

logger.info(
f"There are {len(model_parameters)} sets of model parameters to explore; "
Expand Down Expand Up @@ -684,11 +684,9 @@ def _custom_param_grid_builder(
return new_params


def _get_model_parameters(
training_conf: str, conf: dict[str, Any]
) -> list[dict[str, Any]]:
model_parameters = conf[training_conf]["model_parameters"]
if "param_grid" in conf[training_conf] and conf[training_conf]["param_grid"]:
def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any]]:
model_parameters = training_config["model_parameters"]
if "param_grid" in training_config and training_config["param_grid"]:
model_parameters = _custom_param_grid_builder(model_parameters)
elif model_parameters == []:
raise ValueError(
Expand Down
6 changes: 3 additions & 3 deletions hlink/tests/model_exploration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_get_model_parameters_no_param_grid_attribute(training_conf):
]
assert "param_grid" not in training_conf["training"]

model_parameters = _get_model_parameters("training", training_conf)
model_parameters = _get_model_parameters(training_conf["training"])

assert model_parameters == [
{"type": "random_forest", "maxDepth": 3, "numTrees": 50},
Expand All @@ -174,7 +174,7 @@ def test_get_model_parameters_param_grid_false(training_conf):
]
training_conf["training"]["param_grid"] = False

model_parameters = _get_model_parameters("training", training_conf)
model_parameters = _get_model_parameters(training_conf["training"])

assert model_parameters == [
{"type": "logistic_regression", "threshold": 0.3, "threshold_ratio": 1.4},
Expand All @@ -196,7 +196,7 @@ def test_get_model_parameters_param_grid_true(training_conf):
]
training_conf["training"]["param_grid"] = True

model_parameters = _get_model_parameters("training", training_conf)
model_parameters = _get_model_parameters(training_conf["training"])
# 3 settings for maxDepth * 2 settings for numTrees = 6 total settings
assert len(model_parameters) == 6

Expand Down

0 comments on commit bc0bf7d

Please sign in to comment.