Skip to content

Commit

Permalink
update __eq__ function for parameters to use it to cover test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
danrgll committed Dec 28, 2023
1 parent edcb529 commit 83ee506
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 63 deletions.
7 changes: 6 additions & 1 deletion neps/search_spaces/hyperparameters/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ def id(self):
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.choices == other.choices and self.value == other.value
return (self.choices == other.choices
and self.value == other.value
and self.is_fidelity == other.is_fidelity
and self.default == other.default
and self.default_confidence_score == other.default_confidence_score
)

def __repr__(self):
return f"<Categorical, choices: {self.choices}, value: {self.value}>"
Expand Down
3 changes: 2 additions & 1 deletion neps/search_spaces/hyperparameters/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(self, value: Union[float, int, str], is_fidelity: bool = False):
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.value == other.value
return (self.value == other.value
and self.is_fidelity == other.is_fidelity)

def __repr__(self):
return f"<Constant, value: {self.id}>"
Expand Down
3 changes: 3 additions & 0 deletions neps/search_spaces/hyperparameters/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def __eq__(self, other):
self.lower == other.lower
and self.upper == other.upper
and self.log == other.log
and self.is_fidelity == other.is_fidelity
and self.value == other.value
and self.default == other.default
and self.default_confidence_score == other.default_confidence_score
)

def __repr__(self):
Expand Down
83 changes: 22 additions & 61 deletions tests/test_yaml_search_space/test_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,20 @@ def test_correct_yaml_file(path):
"""Test the function with a correctly formatted YAML file."""
pipeline_space = pipeline_space_from_yaml(path)
assert isinstance(pipeline_space, dict)
assert isinstance(pipeline_space["param_float1"], FloatParameter)
assert pipeline_space["param_float1"].lower == 0.00001
assert pipeline_space["param_float1"].upper == 0.1
assert pipeline_space["param_float1"].log is True
assert pipeline_space["param_float1"].is_fidelity is False
assert pipeline_space["param_float1"].default is None
assert pipeline_space["param_float1"].default_confidence_score == 0.5
assert isinstance(pipeline_space["param_int1"], IntegerParameter)
assert pipeline_space["param_int1"].lower == -3
assert pipeline_space["param_int1"].upper == 30
assert pipeline_space["param_int1"].log is False
assert pipeline_space["param_int1"].is_fidelity is True
assert pipeline_space["param_int1"].default is None
assert pipeline_space["param_int1"].default_confidence_score == 0.5
assert isinstance(pipeline_space["param_int2"], IntegerParameter)
assert pipeline_space["param_int2"].lower == 100
assert pipeline_space["param_int2"].upper == 30000
assert pipeline_space["param_int2"].log is True
assert pipeline_space["param_int2"].is_fidelity is False
assert pipeline_space["param_int2"].default is None
assert pipeline_space["param_int2"].default_confidence_score == 0.5
assert isinstance(pipeline_space["param_float2"], FloatParameter)
assert pipeline_space["param_float2"].lower == 3.3e-5
assert pipeline_space["param_float2"].upper == 0.15
assert pipeline_space["param_float2"].log is False
assert pipeline_space["param_float2"].is_fidelity is False
assert pipeline_space["param_float2"].default is None
assert pipeline_space["param_float2"].default_confidence_score == 0.5
assert isinstance(pipeline_space["param_cat"], CategoricalParameter)
assert pipeline_space["param_cat"].choices == [2, "sgd", 10e-3]
assert pipeline_space["param_cat"].is_fidelity is False
assert pipeline_space["param_cat"].default is None
assert pipeline_space["param_cat"].default_confidence_score == 2
assert isinstance(pipeline_space["param_const1"], ConstantParameter)
assert pipeline_space["param_const1"].value == 0.5
assert pipeline_space["param_const1"].is_fidelity is False
assert isinstance(pipeline_space["param_const2"], ConstantParameter)
assert pipeline_space["param_const2"].value == 1e3
assert pipeline_space["param_const2"].is_fidelity is True
float1 = FloatParameter(0.00001, 0.1, True, False)
assert float1.__eq__(pipeline_space["param_float1"]) is True
int1 = IntegerParameter(-3, 30, False, True)
assert int1.__eq__(pipeline_space["param_int1"]) is True
int2 = IntegerParameter(100, 30000, True, False)
assert int2.__eq__(pipeline_space["param_int2"]) is True
float2 = FloatParameter(3.3e-5, 0.15, False, False)
assert float2.__eq__(pipeline_space["param_float2"]) is True
cat1 = CategoricalParameter([2, "sgd", 10e-3], False)
assert cat1.__eq__(pipeline_space["param_cat"]) is True
const1 = ConstantParameter(0.5, False)
assert const1.__eq__(pipeline_space["param_const1"]) is True
const2 = ConstantParameter(1e3, True)
assert const2.__eq__(pipeline_space["param_const2"]) is True

test_correct_yaml_file(BASE_PATH + "correct_config.yaml")
test_correct_yaml_file(
Expand All @@ -71,28 +46,14 @@ def test_correct_including_priors_yaml_file():
BASE_PATH + "correct_config_including_priors.yml"
)
assert isinstance(pipeline_space, dict)
assert isinstance(pipeline_space["learning_rate"], FloatParameter)
assert pipeline_space["learning_rate"].lower == 0.00001
assert pipeline_space["learning_rate"].upper == 0.1
assert pipeline_space["learning_rate"].log is True
assert pipeline_space["learning_rate"].is_fidelity is False
assert pipeline_space["learning_rate"].default == 3.3e-2
assert pipeline_space["learning_rate"].default_confidence_score == 0.125
assert isinstance(pipeline_space["num_epochs"], IntegerParameter)
assert pipeline_space["num_epochs"].lower == 3
assert pipeline_space["num_epochs"].upper == 30
assert pipeline_space["num_epochs"].log is False
assert pipeline_space["num_epochs"].is_fidelity is True
assert pipeline_space["num_epochs"].default == 10
assert pipeline_space["num_epochs"].default_confidence_score == 0.5
assert isinstance(pipeline_space["optimizer"], CategoricalParameter)
assert pipeline_space["optimizer"].choices == ["adam", 90e-3, "rmsprop"]
assert pipeline_space["optimizer"].is_fidelity is False
assert pipeline_space["optimizer"].default == 90e-3
assert pipeline_space["optimizer"].default_confidence_score == 4
assert isinstance(pipeline_space["dropout_rate"], ConstantParameter)
assert pipeline_space["dropout_rate"].value == 1e3
assert pipeline_space["dropout_rate"].default == 1e3
float1 = FloatParameter(0.00001, 0.1, True, False, 3.3e-2, "high")
assert float1.__eq__(pipeline_space["learning_rate"]) is True
int1 = IntegerParameter(3, 30, False, True, 10)
assert int1.__eq__(pipeline_space["num_epochs"]) is True
cat1 = CategoricalParameter(["adam", 90e-3, "rmsprop"], False, 90e-3, "medium")
assert cat1.__eq__(pipeline_space["optimizer"]) is True
const1 = ConstantParameter(1e3, True)
assert const1.__eq__(pipeline_space["dropout_rate"]) is True


@pytest.mark.neps_api
Expand Down

0 comments on commit 83ee506

Please sign in to comment.