From b2b0ba2e19245ab3b82241741f8901476625b1e5 Mon Sep 17 00:00:00 2001 From: Prasanna Balaprakash Date: Sun, 13 Feb 2022 15:42:27 -0600 Subject: [PATCH] update for handing tl with new params --- skopt/space/space.py | 49 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/skopt/space/space.py b/skopt/space/space.py index 2bd000aa..e7df8347 100644 --- a/skopt/space/space.py +++ b/skopt/space/space.py @@ -9,6 +9,7 @@ from sklearn.utils import check_random_state from sklearn.utils.fixes import sp_version +from ConfigSpace.util import get_one_exchange_neighbourhood, get_random_neighbor, deactivate_inactive_hyperparameters if type(sp_version) is not tuple: # Version object since sklearn>=2.3.x @@ -1081,21 +1082,59 @@ def rvs(self, n_samples=1, random_state=None): Points sampled from the space. """ + #n_samples = 100 + rng = check_random_state(random_state) if self.is_config_space: req_points = [] if self.tl_sdv is None: confs = self.config_space.sample_configuration(n_samples) else: - confs_t = self.tl_sdv.sample(n_samples) # we have to check and fix this! - confs = confs_t.to_dict('records') + confs = self.tl_sdv.sample(n_samples) print('successfully sampling with tl_sdv! ') if n_samples == 1: confs = [confs] - print(confs[0]) + + #print(confs) + hps_names = self.config_space.get_hyperparameter_names() - for conf in confs: + sdv_names = confs.columns + + new_hps_names = list(set(hps_names)-set(sdv_names)) + #print(new_hps_names) + + rs = np.random.RandomState() + + # randomly sample the new hyperparameters + for name in new_hps_names: + hp = self.config_space.get_hyperparameter(name) + rvs = [] + for i in range(n_samples): + v = hp._sample(rs) + rv = hp._transform(v) + rvs.append(rv) + confs[name] = rvs + + # reoder the column names + confs = confs[hps_names] + #print(confs) + + confs = confs.to_dict('records') + for idx, conf in enumerate(confs): + cf = deactivate_inactive_hyperparameters(conf,self.config_space) + confs[idx] = cf.get_dictionary() + + # check if other conditions are not met; generate valid 1-exchange neighbor; need to test and develop the logic + if 0: + print('conf invalid...generating valid 1-exchange neighbor') + neighborhood = get_one_exchange_neighbourhood(cf,1) + for new_config in neighborhood: + print(new_config) + print(new_config.is_valid_configuration()) + confs[idx] = new_config.get_dictionary() + + for idx, conf in enumerate(confs): point = [] for hps_name in hps_names: val = np.nan @@ -1105,6 +1144,8 @@ def rvs(self, n_samples=1, random_state=None): val = conf[hps_name] point.append(val) req_points.append(point) + #print(req_points[0]) + return req_points else: if self.tl_sdv is None: