Skip to content

Commit

Permalink
fixing optimizer for tl
Browse files Browse the repository at this point in the history
  • Loading branch information
pbalapra committed Feb 13, 2022
1 parent ce59878 commit 5eed19b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
8 changes: 7 additions & 1 deletion skopt/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def __init__(
del args["self"]
self.specs = {"args": args, "function": "Optimizer"}
self.rng = check_random_state(random_state)
print(tl_sdv)
self.tl_sdv = tl_sdv

# Configure acquisition function

Expand Down Expand Up @@ -380,12 +382,16 @@ def copy(self, random_state=None):
acq_func_kwargs=self.acq_func_kwargs,
acq_optimizer_kwargs=self.acq_optimizer_kwargs,
random_state=random_state,
tl_sdv=self.tl_sdv,
tl_sdv=self.tl_sdv
)

optimizer._initial_samples = self._initial_samples

optimizer.sampled = self.sampled[:]

if hasattr(self, "tl_sdv"):
optimizer.tl_sdv = self.tl_sdv

if hasattr(self, "gains_"):
optimizer.gains_ = np.copy(self.gains_)
if self.Xi:
Expand Down
13 changes: 8 additions & 5 deletions skopt/space/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ def __init__(self, low, high, prior="uniform", base=10, transform=None,
if high <= low:
raise ValueError("the lower bound {} has to be less than the"
" upper bound {}".format(low, high))
if prior not in ["uniform", "log-uniform"]:
raise ValueError("prior should be 'uniform' or 'log-uniform'"
if prior not in ["uniform", "log-uniform","normal"]:
raise ValueError("prior should be 'normal', 'uniform' or 'log-uniform'"
" got {}".format(prior))
self.low = low
self.high = high
Expand Down Expand Up @@ -1080,16 +1080,20 @@ def rvs(self, n_samples=1, random_state=None):
points : list of lists, shape=(n_points, n_dims)
Points sampled from the space.
"""

rng = check_random_state(random_state)
if self.is_config_space:
req_points = []
if self.tl_sdv is None:
confs = self.config_space.sample_configuration(n_samples)
else:
confs = self.tl_sdv.sample(n_samples) # we have to check and fix this!
confs_t = self.tl_sdv.sample(n_samples) # we have to check and fix this!
confs = confs_t.to_dict('records')
print('successfully sampling with tl_sdv! ')

if n_samples == 1:
confs = [confs]

print(confs[0])
hps_names = self.config_space.get_hyperparameter_names()
for conf in confs:
point = []
Expand All @@ -1101,7 +1105,6 @@ def rvs(self, n_samples=1, random_state=None):
val = conf[hps_name]
point.append(val)
req_points.append(point)

return req_points
else:
if self.tl_sdv is None:
Expand Down

0 comments on commit 5eed19b

Please sign in to comment.