Skip to content

Commit

Permalink
refactored test
Browse files Browse the repository at this point in the history
  • Loading branch information
RemDelaporteMathurin committed Nov 5, 2024
1 parent c5c7760 commit 157f50b
Showing 1 changed file with 10 additions and 29 deletions.
39 changes: 10 additions & 29 deletions test/test_stepsize.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,29 @@ def test_adaptive_stepsize_shrinks(cutback_factor, target):
assert np.isclose(new_value, expected_value)


@pytest.mark.parametrize("max_stepsize, growth_factor, target", [(4, 3, 1)])
def test_max_stepsize(max_stepsize, growth_factor, target):
@pytest.mark.parametrize("nb_its, target", [(1, 4), (5, 4), (4, 4)])
def test_max_stepsize(nb_its, target):
"""Checks that the stepsize is capped at max
stepsize.
Args:
max_stepsize (float): maximum stepsize
growth_factor (float): the growth factor
nb_its (int): the current number of iterations
target (int): the target number of iterations
"""

my_stepsize = F.Stepsize(initial_value=2)
my_stepsize.max_stepsize = max_stepsize
my_stepsize.growth_factor = growth_factor
my_stepsize = F.Stepsize(initial_value=1)
my_stepsize.max_stepsize = 4
my_stepsize.growth_factor = 1.1
my_stepsize.cutback_factor = 0.9
my_stepsize.target_nb_iterations = target

current_value = 2
current_value = 10
new_value = my_stepsize.modify_value(
value=current_value,
nb_iterations=my_stepsize.target_nb_iterations - 1,
nb_iterations=nb_its,
)

expected_value = max_stepsize
expected_value = my_stepsize.max_stepsize
assert new_value == expected_value


Expand Down Expand Up @@ -170,22 +170,3 @@ def test_no_milestones():
# Test that setting milestones to None works
stepsize.milestones = None
assert stepsize.milestones is None


def test_modify_for_stepsize():
"""Tests that modify_value returns max_stepsize
when max_stepsize is less than next stepsize.
"""
my_stepsize = F.Stepsize(initial_value=1)
my_stepsize.max_stepsize = 1
my_stepsize.growth_factor = 1.1
my_stepsize.target_nb_iterations = 4

current_value = 2
new_value = my_stepsize.modify_value(
value=current_value,
nb_iterations=my_stepsize.target_nb_iterations,
)

expected_value = 1
assert new_value == expected_value

0 comments on commit 157f50b

Please sign in to comment.