From a715fb238684ae2ab171eb8eeb6699f5670a1204 Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Wed, 19 Jan 2022 15:54:35 +0100 Subject: [PATCH] adding 'softmax_sampling' based on stochastic policies --- .gitignore | 1 + skopt/optimizer/optimizer.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 8cd458d2d..2f70da6b2 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ doc/auto_examples # vim users .*.swp +.vscode/settings.json diff --git a/skopt/optimizer/optimizer.py b/skopt/optimizer/optimizer.py index 28923d728..c150299b8 100644 --- a/skopt/optimizer/optimizer.py +++ b/skopt/optimizer/optimizer.py @@ -262,17 +262,17 @@ def __init__( else: acq_optimizer = "sampling" - if acq_optimizer not in ["lbfgs", "sampling"]: + if acq_optimizer not in ["lbfgs", "sampling", "softmax_sampling"]: raise ValueError( "Expected acq_optimizer to be 'lbfgs' or " - "'sampling', got {0}".format(acq_optimizer) + "'sampling' or 'softmax_sampling', got {0}".format(acq_optimizer) ) - if not has_gradients(self.base_estimator_) and acq_optimizer != "sampling": + if not has_gradients(self.base_estimator_) and not("sampling" in acq_optimizer): raise ValueError( - "The regressor {0} should run with " - "acq_optimizer" - "='sampling'.".format(type(base_estimator)) + "The regressor {0} should run with a 'sampling' " + "acq_optimizer such as " + "'sampling' or 'softmax_sampling'.".format(type(base_estimator)) ) self.acq_optimizer = acq_optimizer @@ -655,6 +655,11 @@ def _tell(self, x, y, fit=True): if self.acq_optimizer == "sampling": next_x = X[np.argmin(values)] + elif self.acq_optimizer == "softmax_sampling": + probs = values / np.sum(values) + idx = np.argmax(self.rng.multinomial(1, probs)) + next_x = X[idx] + # Use BFGS to find the mimimum of the acquisition function, the # minimization starts from `n_restarts_optimizer` different # points and the best minimum is used