@@ -2404,6 +2404,33 @@ def test_samples_range_matches_compute_features_requirements(
2404
2404
class TestAdditiveBasis (CombinedBasis ):
2405
2405
cls = {"eval" : AdditiveBasis , "conv" : AdditiveBasis }
2406
2406
2407
+ @pytest .mark .parametrize (
2408
+ "basis_a" , list_all_basis_classes ("Eval" ) + list_all_basis_classes ("Conv" )
2409
+ )
2410
+ @pytest .mark .parametrize (
2411
+ "basis_b" , list_all_basis_classes ("Eval" ) + list_all_basis_classes ("Conv" )
2412
+ )
2413
+ def test_input_shape_product_init (
2414
+ self , basis_a , basis_b , basis_class_specific_params
2415
+ ):
2416
+ basis_a_obj = self .instantiate_basis (
2417
+ 5 , basis_a , basis_class_specific_params , window_size = 10
2418
+ )
2419
+ basis_b_obj = self .instantiate_basis (
2420
+ 6 , basis_b , basis_class_specific_params , window_size = 10
2421
+ )
2422
+ add = basis_a_obj + basis_b_obj
2423
+ assert add ._input_shape_product is None
2424
+ basis_a_obj .set_input_shape (())
2425
+ add = basis_a_obj + basis_b_obj
2426
+ assert add ._input_shape_product is None
2427
+ basis_b_obj .set_input_shape (())
2428
+ add = basis_a_obj + basis_b_obj
2429
+ assert add ._input_shape_product == (1 , 1 )
2430
+ basis_b_obj .set_input_shape ((1 , 2 , 3 ))
2431
+ add = basis_a_obj + basis_b_obj
2432
+ assert add ._input_shape_product == (1 , 6 )
2433
+
2407
2434
@pytest .mark .parametrize ("basis_a" , list_all_basis_classes ())
2408
2435
@pytest .mark .parametrize ("basis_b" , list_all_basis_classes ())
2409
2436
def test_len (self , basis_a , basis_b , basis_class_specific_params ):
@@ -3513,6 +3540,33 @@ def test_repr_label(self, label, basis_class_specific_params):
3513
3540
class TestMultiplicativeBasis (CombinedBasis ):
3514
3541
cls = {"eval" : MultiplicativeBasis , "conv" : MultiplicativeBasis }
3515
3542
3543
+ @pytest .mark .parametrize (
3544
+ "basis_a" , list_all_basis_classes ("Eval" ) + list_all_basis_classes ("Conv" )
3545
+ )
3546
+ @pytest .mark .parametrize (
3547
+ "basis_b" , list_all_basis_classes ("Eval" ) + list_all_basis_classes ("Conv" )
3548
+ )
3549
+ def test_input_shape_product_init (
3550
+ self , basis_a , basis_b , basis_class_specific_params
3551
+ ):
3552
+ basis_a_obj = self .instantiate_basis (
3553
+ 5 , basis_a , basis_class_specific_params , window_size = 10
3554
+ )
3555
+ basis_b_obj = self .instantiate_basis (
3556
+ 6 , basis_b , basis_class_specific_params , window_size = 10
3557
+ )
3558
+ mul = basis_a_obj * basis_b_obj
3559
+ assert mul ._input_shape_product is None
3560
+ basis_a_obj .set_input_shape (())
3561
+ mul = basis_a_obj * basis_b_obj
3562
+ assert mul ._input_shape_product is None
3563
+ basis_b_obj .set_input_shape (())
3564
+ mul = basis_a_obj * basis_b_obj
3565
+ assert mul ._input_shape_product == (1 , 1 )
3566
+ basis_b_obj .set_input_shape ((1 , 2 , 3 ))
3567
+ mul = basis_a_obj * basis_b_obj
3568
+ assert mul ._input_shape_product == (1 , 6 )
3569
+
3516
3570
@pytest .mark .parametrize ("basis_a" , list_all_basis_classes ())
3517
3571
@pytest .mark .parametrize ("basis_b" , list_all_basis_classes ())
3518
3572
def test_len (self , basis_a , basis_b , basis_class_specific_params ):
0 commit comments