Skip to content

Commit

Permalink
[MRG] Fix bug with categorical spaces with different types of ca… (sc…
Browse files Browse the repository at this point in the history
…ikit-optimize#752)

* Fix bug with categorical spaces with different types of categories

* fix for python2
  • Loading branch information
stefanocereda authored and betatim committed Jan 14, 2020
1 parent 4d5379b commit 7720fda
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
9 changes: 7 additions & 2 deletions skopt/space/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,11 @@ def __init__(self, categories, prior=None, transform=None, name=None):
* `name` [str or None]:
Name associated with dimension, e.g., "colors".
"""
self.categories = tuple(categories)
if transform == 'identity':
self.categories = tuple([str(c) for c in categories])
else:
self.categories = tuple(categories)

self.name = name

if transform is None:
Expand All @@ -423,7 +427,8 @@ def __init__(self, categories, prior=None, transform=None, name=None):
self.transformer = CategoricalEncoder()
self.transformer.fit(self.categories)
else:
self.transformer = Identity()
self.transformer = Identity(dtype=type(categories[0]))

self.prior = prior

if prior is None:
Expand Down
17 changes: 14 additions & 3 deletions skopt/space/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,24 @@ def inverse_transform(self, X):


class Identity(Transformer):
"""Identity transform."""
"""Identity transform.
If dtype is different from None the transform will cast everything to a
string and the inverse transform will cast to the type defined in dtype."""

def __init__(self, dtype=None):
super(Identity, self).__init__()
self.dtype = dtype

def transform(self, X):
return X
if self.dtype is None:
return X
return [str(x) for x in X]


def inverse_transform(self, Xt):
return Xt
if self.dtype is None:
return Xt
return [self.dtype(x) for x in Xt]


class Log10(Transformer):
Expand Down
8 changes: 8 additions & 0 deletions skopt/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ def test_normalize_dimensions_all_categorical():
assert space.is_categorical


@pytest.mark.fast_test
def test_categoricals_mixed_types():
domain = [[1, 2, 3, 4], ['a', 'b', 'c'], [True, False]]
x = [1, 'a', True]
space = normalize_dimensions(domain)
assert (space.inverse_transform(space.transform([x])) == [x])


@pytest.mark.fast_test
@pytest.mark.parametrize("dimensions, normalizations",
[(((1, 3), (1., 3.)),
Expand Down

0 comments on commit 7720fda

Please sign in to comment.