Skip to content

Commit 6b3ead8

Browse files
added tests for input shape init for composite
1 parent 5c3fe7e commit 6b3ead8

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

src/nemos/basis/_basis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
self._label = str(label)
148148

149149
# specified only after inputs/input shapes are provided
150-
self._input_shape_product = None
150+
self._input_shape_product = getattr(self, "_input_shape_product", None)
151151

152152
# initialize parent to None. This should not end in "_" because it is
153153
# a permanent property of a basis, defined at composite basis init

tests/test_basis.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,6 +2404,33 @@ def test_samples_range_matches_compute_features_requirements(
24042404
class TestAdditiveBasis(CombinedBasis):
24052405
cls = {"eval": AdditiveBasis, "conv": AdditiveBasis}
24062406

2407+
@pytest.mark.parametrize(
2408+
"basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")
2409+
)
2410+
@pytest.mark.parametrize(
2411+
"basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")
2412+
)
2413+
def test_input_shape_product_init(
2414+
self, basis_a, basis_b, basis_class_specific_params
2415+
):
2416+
basis_a_obj = self.instantiate_basis(
2417+
5, basis_a, basis_class_specific_params, window_size=10
2418+
)
2419+
basis_b_obj = self.instantiate_basis(
2420+
6, basis_b, basis_class_specific_params, window_size=10
2421+
)
2422+
add = basis_a_obj + basis_b_obj
2423+
assert add._input_shape_product is None
2424+
basis_a_obj.set_input_shape(())
2425+
add = basis_a_obj + basis_b_obj
2426+
assert add._input_shape_product is None
2427+
basis_b_obj.set_input_shape(())
2428+
add = basis_a_obj + basis_b_obj
2429+
assert add._input_shape_product == (1, 1)
2430+
basis_b_obj.set_input_shape((1, 2, 3))
2431+
add = basis_a_obj + basis_b_obj
2432+
assert add._input_shape_product == (1, 6)
2433+
24072434
@pytest.mark.parametrize("basis_a", list_all_basis_classes())
24082435
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
24092436
def test_len(self, basis_a, basis_b, basis_class_specific_params):
@@ -3513,6 +3540,33 @@ def test_repr_label(self, label, basis_class_specific_params):
35133540
class TestMultiplicativeBasis(CombinedBasis):
35143541
cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis}
35153542

3543+
@pytest.mark.parametrize(
3544+
"basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")
3545+
)
3546+
@pytest.mark.parametrize(
3547+
"basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")
3548+
)
3549+
def test_input_shape_product_init(
3550+
self, basis_a, basis_b, basis_class_specific_params
3551+
):
3552+
basis_a_obj = self.instantiate_basis(
3553+
5, basis_a, basis_class_specific_params, window_size=10
3554+
)
3555+
basis_b_obj = self.instantiate_basis(
3556+
6, basis_b, basis_class_specific_params, window_size=10
3557+
)
3558+
mul = basis_a_obj * basis_b_obj
3559+
assert mul._input_shape_product is None
3560+
basis_a_obj.set_input_shape(())
3561+
mul = basis_a_obj * basis_b_obj
3562+
assert mul._input_shape_product is None
3563+
basis_b_obj.set_input_shape(())
3564+
mul = basis_a_obj * basis_b_obj
3565+
assert mul._input_shape_product == (1, 1)
3566+
basis_b_obj.set_input_shape((1, 2, 3))
3567+
mul = basis_a_obj * basis_b_obj
3568+
assert mul._input_shape_product == (1, 6)
3569+
35163570
@pytest.mark.parametrize("basis_a", list_all_basis_classes())
35173571
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
35183572
def test_len(self, basis_a, basis_b, basis_class_specific_params):

0 commit comments

Comments
 (0)