adding 'softmax_sampling' based on stochastic policies
Deathn0t committed Jan 19, 2022
1 parent 50117e6 commit a715fb2
Showing 2 changed files with 12 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ doc/auto_examples
 
 # vim users
 .*.swp
+.vscode/settings.json
17 changes: 11 additions & 6 deletions skopt/optimizer/optimizer.py
@@ -262,17 +262,17 @@ def __init__(
             else:
                 acq_optimizer = "sampling"
 
-        if acq_optimizer not in ["lbfgs", "sampling"]:
+        if acq_optimizer not in ["lbfgs", "sampling", "softmax_sampling"]:
             raise ValueError(
                 "Expected acq_optimizer to be 'lbfgs' or "
-                "'sampling', got {0}".format(acq_optimizer)
+                "'sampling' or 'softmax_sampling', got {0}".format(acq_optimizer)
             )
 
-        if not has_gradients(self.base_estimator_) and acq_optimizer != "sampling":
+        if not has_gradients(self.base_estimator_) and not("sampling" in acq_optimizer):
             raise ValueError(
-                "The regressor {0} should run with "
-                "acq_optimizer"
-                "='sampling'.".format(type(base_estimator))
+                "The regressor {0} should run with a 'sampling' "
+                "acq_optimizer such as "
+                "'sampling' or 'softmax_sampling'.".format(type(base_estimator))
             )
         self.acq_optimizer = acq_optimizer

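With the relaxed check above, any acq_optimizer value containing "sampling" is accepted for regressors without gradients, so a tree-based estimator such as "RF" can now be paired with the new value. A minimal usage sketch, assuming this patched skopt package is installed; the two-dimensional search space and the quadratic objective are illustrative placeholders:

    from skopt import Optimizer

    # Hypothetical example: a gradient-free base estimator ("RF") combined with
    # the stochastic acquisition optimizer added by this commit.
    opt = Optimizer(
        dimensions=[(-5.0, 5.0), (-5.0, 5.0)],
        base_estimator="RF",
        acq_optimizer="softmax_sampling",
        random_state=0,
    )

    x = opt.ask()                              # propose a point
    opt.tell(x, sum(xi ** 2 for xi in x))      # report its objective value
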
@@ -655,6 +655,11 @@ def _tell(self, x, y, fit=True):
             if self.acq_optimizer == "sampling":
                 next_x = X[np.argmin(values)]
 
+            elif self.acq_optimizer == "softmax_sampling":
+                probs = values / np.sum(values)
+                idx = np.argmax(self.rng.multinomial(1, probs))
+                next_x = X[idx]
+
             # Use BFGS to find the mimimum of the acquisition function, the
             # minimization starts from `n_restarts_optimizer` different
             # points and the best minimum is used
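
The new branch in _tell selects the next candidate stochastically rather than greedily: instead of taking the argmin of the acquisition values, it rescales the values into a probability vector and draws a single index from a multinomial distribution. A standalone sketch of that selection step, with X, values, and rng standing in for the arrays available inside _tell:

    import numpy as np

    rng = np.random.RandomState(0)

    # Stand-ins for the quantities available inside Optimizer._tell:
    # X holds candidate points sampled from the search space,
    # values holds the acquisition function evaluated at those points.
    X = rng.uniform(-2.0, 2.0, size=(5, 3))
    values = rng.uniform(0.1, 1.0, size=5)

    # Greedy choice used by acq_optimizer="sampling": take the best candidate.
    next_x_greedy = X[np.argmin(values)]

    # Stochastic choice added for acq_optimizer="softmax_sampling":
    # rescale the values to sum to one and draw one index from a multinomial.
    probs = values / np.sum(values)
    idx = np.argmax(rng.multinomial(1, probs))
    next_x_stochastic = X[idx]

Note that, as committed, the selection probabilities are the raw acquisition values normalized to sum to one rather than an explicit softmax (exponential) transform.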
