Skip to content

Commit

Permalink
[#167] Pull the edge case logic for "type" out of _choose_randomized_…
Browse files Browse the repository at this point in the history
…parameters()
  • Loading branch information
riley-harper committed Nov 27, 2024
1 parent 907818e commit 65cb5ff
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,10 +695,8 @@ def _choose_randomized_parameters(model_parameters: dict[str, Any]) -> dict[str,
parameter_choices = dict()

for key, value in model_parameters.items():
if key == "type":
parameter_choices[key] = value
# If it's a Sequence (usually list), choose one of the values at random.
elif isinstance(value, collections.abc.Sequence):
if isinstance(value, collections.abc.Sequence):
parameter_choices[key] = random.choice(value)
# If it's a Mapping (usually dict), it defines a distribution from which
# the parameter should be sampled.
Expand Down Expand Up @@ -757,7 +755,14 @@ def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any
return_parameters = []
for _ in range(num_samples):
parameter_spec = random.choice(model_parameters)
randomized = _choose_randomized_parameters(parameter_spec)
model_type = parameter_spec["type"]
sample_parameters = dict(
(key, value)
for (key, value) in parameter_spec.items()
if key != "type"
)
randomized = _choose_randomized_parameters(sample_parameters)
randomized["type"] = model_type
return_parameters.append(randomized)

return return_parameters
Expand Down

0 comments on commit 65cb5ff

Please sign in to comment.