Skip to content

Commit

Permalink
[#167] Simplify the interface to _custom_param_grid_builder()
Browse files Browse the repository at this point in the history
We can just pass the list of model_parameters from the config file to this
function.
  • Loading branch information
riley-harper committed Nov 26, 2024
1 parent c5f5b13 commit 605369b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _get_model_parameters(self, conf: dict[str, Any]) -> list[dict[str, Any]]:

model_parameters = conf[training_conf]["model_parameters"]
if "param_grid" in conf[training_conf] and conf[training_conf]["param_grid"]:
model_parameters = _custom_param_grid_builder(training_conf, conf)
model_parameters = _custom_param_grid_builder(model_parameters)
elif model_parameters == []:
raise ValueError(
"No model parameters found. In 'training' config, either supply 'model_parameters' or 'param_grid'."
Expand Down Expand Up @@ -665,10 +665,10 @@ def _create_desc_df() -> pd.DataFrame:


def _custom_param_grid_builder(
training_conf: str, conf: dict[str, Any]
model_parameters: list[dict[str, Any]]
) -> list[dict[str, Any]]:
print("Building param grid for models")
given_parameters = conf[training_conf]["model_parameters"]
given_parameters = model_parameters
new_params = []
for run in given_parameters:
params = run.copy()
Expand Down
12 changes: 4 additions & 8 deletions hlink/tests/model_exploration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,13 @@ def test_all(
main.do_drop_all("")


def test_step_2_param_grid(main, training_conf):
"""Test matching step 2 training to see if the custom param grid builder is working"""

training_conf["training"]["model_parameters"] = [
def test_custom_param_grid_builder():
"""Test matching step 2's custom param grid builder"""
model_parameters = [
{"type": "random_forest", "maxDepth": [3, 4, 5], "numTrees": [50, 100]},
{"type": "probit", "threshold": [0.5, 0.7]},
]

param_grid = _custom_param_grid_builder("training", training_conf)
param_grid = _custom_param_grid_builder(model_parameters)

expected = [
{"maxDepth": 3, "numTrees": 50, "type": "random_forest"},
Expand All @@ -145,8 +143,6 @@ def test_step_2_param_grid(main, training_conf):
assert len(param_grid) == len(expected)
assert all(m in expected for m in param_grid)

main.do_drop_all("")


# -------------------------------------
# Tests that probably should be moved
Expand Down

0 comments on commit 605369b

Please sign in to comment.