Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 18, 2024
1 parent de2c230 commit 6f5ddef
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions tests/test_transformer_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,17 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, basis_class_specific_param

# changing an attribute in bas should not change trans_bas
if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]:
bas_a._basis1.n_basis_funcs = 10
assert trans_bas_a._basis._basis1.n_basis_funcs == 5
bas_a.basis1.n_basis_funcs = 10
assert trans_bas_a._basis.basis1.n_basis_funcs == 5

# changing an attribute in the transformer basis should not change the original
bas_b = CombinedBasis().instantiate_basis(
5, basis_cls, basis_class_specific_params, window_size=10
)
bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality))
trans_bas_b = bas_b.to_transformer()
trans_bas_b._basis._basis1.n_basis_funcs = 100
assert bas_b._basis1.n_basis_funcs == 5
trans_bas_b._basis.basis1.n_basis_funcs = 100
assert bas_b.basis1.n_basis_funcs == 5
else:
bas_a.n_basis_funcs = 10
assert trans_bas_a.n_basis_funcs == 5
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_transformerbasis_getattr(
)
if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]:
for bas in [
getattr(trans_basis._basis, attr) for attr in ("_basis1", "_basis2")
getattr(trans_basis._basis, attr) for attr in ("basis1", "basis2")
]:
assert bas.n_basis_funcs == n_basis_funcs
else:
Expand Down Expand Up @@ -292,8 +292,8 @@ def test_transformerbasis_addition(basis_cls, basis_class_specific_params):
== trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality
)
if basis_cls not in [basis.AdditiveBasis, basis.MultiplicativeBasis]:
assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a
assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b
assert trans_bas_sum.basis1.n_basis_funcs == n_basis_funcs_a
assert trans_bas_sum.basis2.n_basis_funcs == n_basis_funcs_b


@pytest.mark.parametrize(
Expand Down Expand Up @@ -327,8 +327,8 @@ def test_transformerbasis_multiplication(basis_cls, basis_class_specific_params)
== trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality
)
if basis_cls not in [basis.AdditiveBasis, basis.MultiplicativeBasis]:
assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a
assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b
assert trans_bas_prod.basis1.n_basis_funcs == n_basis_funcs_a
assert trans_bas_prod.basis2.n_basis_funcs == n_basis_funcs_b


@pytest.mark.parametrize(
Expand Down Expand Up @@ -436,7 +436,7 @@ def test_transformerbasis_pickle(
assert isinstance(trans_bas2, basis.TransformerBasis)
if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]:
for bas in [
getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2")
getattr(trans_bas2._basis, attr) for attr in ("basis1", "basis2")
]:
assert bas.n_basis_funcs == n_basis_funcs
else:
Expand Down Expand Up @@ -739,7 +739,7 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params):

# set basis & refit
if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)):
pipe.set_params(bas__basis2__n_basis_funcs=4)
pipe.set_params(bas_basis2__n_basis_funcs=4)
assert (
bas.basis2.n_basis_funcs == 5
) # make sure that the change did not affect bas
Expand Down

0 comments on commit 6f5ddef

Please sign in to comment.