Skip to content

Commit

Permalink
[#167] Pull _custom_param_grid_builder() out of the LinkStepTrainTestModels class
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Nov 26, 2024
1 parent 72cda30 commit c5f5b13
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 33 deletions.
63 changes: 33 additions & 30 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,35 +244,6 @@ def _get_splits(

return splits

def _custom_param_grid_builder(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
    """Expand each "model_parameters" entry into one dict per combination.

    The entries are read from ``conf[<self.task.training_conf>]["model_parameters"]``.
    For every entry, "type", "threshold" and "threshold_ratio" are set aside and
    the remaining hyper-parameters are combined with a Cartesian product, so
    ``{"maxDepth": [3, 4], "numTrees": [50, 100]}`` yields four dicts.

    A hyper-parameter given as a plain scalar or a string (for example
    ``"numTrees": 50`` or ``"family": "binomial"``) counts as a single option
    instead of raising ``TypeError`` or being split into characters.

    Args:
        conf: The configuration dictionary containing the training section.

    Returns:
        A flat list of parameter dicts, in entry order and then product order.
        Each dict carries the entry's "type", plus "threshold" and
        "threshold_ratio" (passed through unexploded) when they are truthy.

    Raises:
        KeyError: If the training section, "model_parameters", or an entry's
            "type" key is missing.
    """
    print("Building param grid for models")
    given_parameters = conf[f"{self.task.training_conf}"]["model_parameters"]

    def _as_options(value: Any) -> Any:
        # Strings are iterable, but a string hyper-parameter is one option,
        # not one option per character.
        if isinstance(value, (str, bytes)):
            return [value]
        try:
            iter(value)
        except TypeError:
            # Non-iterable scalars (int, float, bool, ...) are one option.
            return [value]
        return value

    new_params = []
    for run in given_parameters:
        # Work on a shallow copy so the caller's config entry keeps its keys.
        params = run.copy()
        model_type = params.pop("type")

        # Thresholds are not part of the grid: they are removed before the
        # product and re-attached unchanged to every exploded dict. Falsy
        # values (missing, 0, empty list) are left off the output.
        threshold = params.pop("threshold", False)
        threshold_ratio = params.pop("threshold_ratio", False)

        keys = list(params)
        option_lists = [_as_options(value) for value in params.values()]

        # With no hyper-parameters left, product() yields one empty tuple, so
        # the entry still produces exactly one dict.
        params_exploded = [
            dict(zip(keys, combination))
            for combination in itertools.product(*option_lists)
        ]

        for subdict in params_exploded:
            subdict["type"] = model_type
            if threshold:
                subdict["threshold"] = threshold
            if threshold_ratio:
                subdict["threshold_ratio"] = threshold_ratio

        new_params.extend(params_exploded)
    return new_params

def _capture_results(
self,
predictions: pyspark.sql.DataFrame,
Expand Down Expand Up @@ -332,7 +303,7 @@ def _get_model_parameters(self, conf: dict[str, Any]) -> list[dict[str, Any]]:

model_parameters = conf[training_conf]["model_parameters"]
if "param_grid" in conf[training_conf] and conf[training_conf]["param_grid"]:
model_parameters = self._custom_param_grid_builder(conf)
model_parameters = _custom_param_grid_builder(training_conf, conf)
elif model_parameters == []:
raise ValueError(
"No model parameters found. In 'training' config, either supply 'model_parameters' or 'param_grid'."
Expand Down Expand Up @@ -691,3 +662,35 @@ def _create_desc_df() -> pd.DataFrame:
"mcc_train_sd",
]
)


def _custom_param_grid_builder(
    training_conf: str, conf: dict[str, Any]
) -> list[dict[str, Any]]:
    """Expand each "model_parameters" entry into one dict per combination.

    The entries are read from ``conf[training_conf]["model_parameters"]``. For
    every entry, "type", "threshold" and "threshold_ratio" are set aside and the
    remaining hyper-parameters are combined with a Cartesian product, so
    ``{"maxDepth": [3, 4], "numTrees": [50, 100]}`` yields four dicts.

    A hyper-parameter given as a plain scalar or a string (for example
    ``"numTrees": 50`` or ``"family": "binomial"``) counts as a single option
    instead of raising ``TypeError`` or being split into characters.

    Args:
        training_conf: Key of the training section inside ``conf``.
        conf: The configuration dictionary containing that section.

    Returns:
        A flat list of parameter dicts, in entry order and then product order.
        Each dict carries the entry's "type", plus "threshold" and
        "threshold_ratio" (passed through unexploded) when they are truthy.

    Raises:
        KeyError: If the training section, "model_parameters", or an entry's
            "type" key is missing.
    """
    print("Building param grid for models")
    given_parameters = conf[training_conf]["model_parameters"]

    def _as_options(value: Any) -> Any:
        # Strings are iterable, but a string hyper-parameter is one option,
        # not one option per character.
        if isinstance(value, (str, bytes)):
            return [value]
        try:
            iter(value)
        except TypeError:
            # Non-iterable scalars (int, float, bool, ...) are one option.
            return [value]
        return value

    new_params = []
    for run in given_parameters:
        # Work on a shallow copy so the caller's config entry keeps its keys.
        params = run.copy()
        model_type = params.pop("type")

        # Thresholds are not part of the grid: they are removed before the
        # product and re-attached unchanged to every exploded dict. Falsy
        # values (missing, 0, empty list) are left off the output.
        threshold = params.pop("threshold", False)
        threshold_ratio = params.pop("threshold_ratio", False)

        keys = list(params)
        option_lists = [_as_options(value) for value in params.values()]

        # With no hyper-parameters left, product() yields one empty tuple, so
        # the entry still produces exactly one dict.
        params_exploded = [
            dict(zip(keys, combination))
            for combination in itertools.product(*option_lists)
        ]

        for subdict in params_exploded:
            subdict["type"] = model_type
            if threshold:
                subdict["threshold"] = threshold
            if threshold_ratio:
                subdict["threshold_ratio"] = threshold_ratio

        new_params.extend(params_exploded)
    return new_params
6 changes: 3 additions & 3 deletions hlink/tests/model_exploration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import hlink.linking.core.threshold as threshold_core
from hlink.linking.model_exploration.link_step_train_test_models import (
LinkStepTrainTestModels,
_custom_param_grid_builder,
)


Expand Down Expand Up @@ -121,16 +122,15 @@ def test_all(
main.do_drop_all("")


def test_step_2_param_grid(spark, main, training_conf, model_exploration, fake_self):
def test_step_2_param_grid(main, training_conf):
"""Test matching step 2 training to see if the custom param grid builder is working"""

training_conf["training"]["model_parameters"] = [
{"type": "random_forest", "maxDepth": [3, 4, 5], "numTrees": [50, 100]},
{"type": "probit", "threshold": [0.5, 0.7]},
]

link_step = LinkStepTrainTestModels(model_exploration)
param_grid = link_step._custom_param_grid_builder(training_conf)
param_grid = _custom_param_grid_builder("training", training_conf)

expected = [
{"maxDepth": 3, "numTrees": 50, "type": "random_forest"},
Expand Down

0 comments on commit c5f5b13

Please sign in to comment.