From 157f50b2abe48f5eb04a83ba0dc25b505b664253 Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Tue, 5 Nov 2024 16:45:13 -0500 Subject: [PATCH] refactored test --- test/test_stepsize.py | 39 ++++++++++----------------------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/test/test_stepsize.py b/test/test_stepsize.py index 288bb04f7..25af774fa 100644 --- a/test/test_stepsize.py +++ b/test/test_stepsize.py @@ -48,29 +48,29 @@ def test_adaptive_stepsize_shrinks(cutback_factor, target): assert np.isclose(new_value, expected_value) -@pytest.mark.parametrize("max_stepsize, growth_factor, target", [(4, 3, 1)]) -def test_max_stepsize(max_stepsize, growth_factor, target): +@pytest.mark.parametrize("nb_its, target", [(1, 4), (5, 4), (4, 4)]) +def test_max_stepsize(nb_its, target): """Checks that the stepsize is capped at max stepsize. Args: - max_stepsize (float): maximum stepsize - growth_factor (float): the growth factor + nb_its (int): the current number of iterations target (int): the target number of iterations """ - my_stepsize = F.Stepsize(initial_value=2) - my_stepsize.max_stepsize = max_stepsize - my_stepsize.growth_factor = growth_factor + my_stepsize = F.Stepsize(initial_value=1) + my_stepsize.max_stepsize = 4 + my_stepsize.growth_factor = 1.1 + my_stepsize.cutback_factor = 0.9 my_stepsize.target_nb_iterations = target - current_value = 2 + current_value = 10 new_value = my_stepsize.modify_value( value=current_value, - nb_iterations=my_stepsize.target_nb_iterations - 1, + nb_iterations=nb_its, ) - expected_value = max_stepsize + expected_value = my_stepsize.max_stepsize assert new_value == expected_value @@ -170,22 +170,3 @@ def test_no_milestones(): # Test that setting milestones to None works stepsize.milestones = None assert stepsize.milestones is None - - -def test_modify_for_stepsize(): - """Tests that modify_value returns max_stepsize - when max_stepsize is less than next stepsize. - """ - my_stepsize = F.Stepsize(initial_value=1) - my_stepsize.max_stepsize = 1 - my_stepsize.growth_factor = 1.1 - my_stepsize.target_nb_iterations = 4 - - current_value = 2 - new_value = my_stepsize.modify_value( - value=current_value, - nb_iterations=my_stepsize.target_nb_iterations, - ) - - expected_value = 1 - assert new_value == expected_value