Skip to content

Commit

Permalink
fix(bo): Switch to using id_generator
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Jan 26, 2025
1 parent 3d98089 commit bc0d3b3
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions neps/optimizers/bayesian_optimization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
import math
from collections.abc import Mapping
from dataclasses import dataclass
Expand Down Expand Up @@ -89,9 +90,7 @@ def __call__(

n_to_sample = 1 if n is None else n
n_sampled = len(trials)
config_ids = iter(
str(i + 1) for i in range(n_sampled, n_sampled + n_to_sample + 1)
)
id_generator = iter(str(i) for i in itertools.count(n_sampled + 1))

# If the amount of configs evaluated is less than the initial design
# requirement, keep drawing from initial design
Expand All @@ -103,6 +102,8 @@ def __call__(
sampled_configs: list[SampledConfig] = []

if n_evaluated < self.n_initial_design:
# For reproducibility, we need to ensure we do the same sample of all
# configs each time.
design_samples = make_initial_design(
parameters=parameters,
encoder=self.encoder,
Expand All @@ -111,13 +112,21 @@ def __call__(
seed=None, # TODO: Seeding, however we need to avoid repeating configs
sample_size=self.n_initial_design,
)

# Then take the subset we actually need
design_samples = design_samples[n_evaluated:]
for sample in design_samples:
sample.update(self.space.constants)

sampled_configs.extend(
[
SampledConfig(id=config_id, config=config)
for config_id, config in zip(config_ids, design_samples, strict=True)
for config_id, config in zip(
id_generator,
design_samples,
# NOTE: We use a generator for the ids so no need for strict
strict=False,
)
]
)

Expand Down Expand Up @@ -187,7 +196,12 @@ def __call__(
sampled_configs.extend(
[
SampledConfig(id=config_id, config=config)
for config_id, config in zip(config_ids, configs, strict=True)
for config_id, config in zip(
id_generator,
configs,
# NOTE: We use a generator for the ids so no need for strict
strict=False,
)
]
)
return sampled_configs[0] if n is None else sampled_configs

0 comments on commit bc0d3b3

Please sign in to comment.