Skip to content

Commit

Permalink
update for handing tl with new params
Browse files Browse the repository at this point in the history
  • Loading branch information
pbalapra committed Feb 13, 2022
1 parent 5eed19b commit b2b0ba2
Showing 1 changed file with 45 additions and 4 deletions.
49 changes: 45 additions & 4 deletions skopt/space/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sklearn.utils import check_random_state
from sklearn.utils.fixes import sp_version
from ConfigSpace.util import get_one_exchange_neighbourhood, get_random_neighbor, deactivate_inactive_hyperparameters


if type(sp_version) is not tuple: # Version object since sklearn>=2.3.x
Expand Down Expand Up @@ -1081,21 +1082,59 @@ def rvs(self, n_samples=1, random_state=None):
Points sampled from the space.
"""

#n_samples = 100

rng = check_random_state(random_state)
if self.is_config_space:
req_points = []
if self.tl_sdv is None:
confs = self.config_space.sample_configuration(n_samples)
else:
confs_t = self.tl_sdv.sample(n_samples) # we have to check and fix this!
confs = confs_t.to_dict('records')
confs = self.tl_sdv.sample(n_samples)
print('successfully sampling with tl_sdv! ')

if n_samples == 1:
confs = [confs]
print(confs[0])

#print(confs)

hps_names = self.config_space.get_hyperparameter_names()
for conf in confs:
sdv_names = confs.columns

new_hps_names = list(set(hps_names)-set(sdv_names))
#print(new_hps_names)

rs = np.random.RandomState()

This comment has been minimized.

Copy link
@Deathn0t

Deathn0t Feb 14, 2022

Member

Hello @pbalapra , Is this new rs = np.random.RandomState (also it does not use any seed) necessary? There is rng = check_random_state(random_state) defined just above.


# randomly sample the new hyperparameters
for name in new_hps_names:
hp = self.config_space.get_hyperparameter(name)
rvs = []
for i in range(n_samples):
v = hp._sample(rs)
rv = hp._transform(v)
rvs.append(rv)
confs[name] = rvs

# reoder the column names
confs = confs[hps_names]
#print(confs)

confs = confs.to_dict('records')
for idx, conf in enumerate(confs):
cf = deactivate_inactive_hyperparameters(conf,self.config_space)
confs[idx] = cf.get_dictionary()

# check if other conditions are not met; generate valid 1-exchange neighbor; need to test and develop the logic
if 0:

This comment has been minimized.

Copy link
@Deathn0t

Deathn0t Feb 14, 2022

Member

@pbalapra This if statement is never used?

print('conf invalid...generating valid 1-exchange neighbor')
neighborhood = get_one_exchange_neighbourhood(cf,1)
for new_config in neighborhood:
print(new_config)
print(new_config.is_valid_configuration())
confs[idx] = new_config.get_dictionary()

for idx, conf in enumerate(confs):
point = []
for hps_name in hps_names:
val = np.nan
Expand All @@ -1105,6 +1144,8 @@ def rvs(self, n_samples=1, random_state=None):
val = conf[hps_name]
point.append(val)
req_points.append(point)
#print(req_points[0])

return req_points
else:
if self.tl_sdv is None:
Expand Down

0 comments on commit b2b0ba2

Please sign in to comment.