Skip to content

Commit

Permalink
[#167] Add a normal distribution to randomized parameter search
Browse files (browse the repository at this point in the history)
  • Loading branch information
riley-harper committed Dec 2, 2024
1 parent 0becd32 commit 5d0ea0b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
10 changes: 8 additions & 2 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -704,13 +704,19 @@ def _choose_randomized_parameters(
# the parameter should be sampled.
elif isinstance(value, collections.abc.Mapping):
distribution = value["distribution"]
low = value["low"]
high = value["high"]

if distribution == "randint":
low = value["low"]
high = value["high"]
parameter_choices[key] = rng.randint(low, high)
elif distribution == "uniform":
low = value["low"]
high = value["high"]
parameter_choices[key] = rng.uniform(low, high)
elif distribution == "normal":
mean = value["mean"]
stdev = value["standard_deviation"]
parameter_choices[key] = rng.normalvariate(mean, stdev)
else:
raise ValueError("unknown distribution")
# All other types (including strings) are passed through unchanged.
Expand Down
9 changes: 9 additions & 0 deletions hlink/tests/model_exploration_test.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -385,6 +385,11 @@ def test_get_model_parameters_search_strategy_randomized_sample_from_distributio
"type": "decision_tree",
"maxDepth": {"distribution": "randint", "low": 1, "high": 20},
"minInfoGain": {"distribution": "uniform", "low": 0.0, "high": 100.0},
"minWeightFractionPerNode": {
"distribution": "normal",
"mean": 10.0,
"standard_deviation": 2.5,
},
}
]

Expand All @@ -396,6 +401,10 @@ def test_get_model_parameters_search_strategy_randomized_sample_from_distributio
assert parameter_choice["type"] == "decision_tree"
assert 1 <= parameter_choice["maxDepth"] <= 20
assert 0.0 <= parameter_choice["minInfoGain"] <= 100.0
# Technically a normal distribution can return any value, even ones very
# far from its mean. So we can't assert on the value returned here. But
# there definitely should be a value of some sort in the dictionary.
assert "minWeightFractionPerNode" in parameter_choice


def test_get_model_parameters_search_strategy_randomized_take_values(training_conf):
Expand Down

0 comments on commit 5d0ea0b

Please sign in to comment.