Skip to content

Commit

Permalink
Shortened the ClippedAdam centered variance test and added an option …
Browse files Browse the repository at this point in the history
…to run the full test with plots via a pytest command line option.
  • Loading branch information
BenZickel committed Jan 25, 2025
1 parent efa56d9 commit 4d8d1c1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
10 changes: 10 additions & 0 deletions tests/optim/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@ def pytest_collection_modifyitems(items):
item.add_marker(pytest.mark.stage("unit"))
if "init" not in item.keywords:
item.add_marker(pytest.mark.init(rng_seed=123))


def pytest_addoption(parser):
    """Register the ``--plot`` command line option.

    The option defaults to the string ``"FALSE"``; passing any other value
    (e.g. ``--plot True``) is interpreted downstream (see
    ``pytest_generate_tests``) as a request to produce plots.
    """
    parser.addoption(
        "--plot",
        action="store",
        default="FALSE",
        help='Any value other than "FALSE" enables plotting in tests '
        "that use the `plot` fixture.",
    )


def pytest_generate_tests(metafunc):
    """Parametrize the ``plot`` fixture from the ``--plot`` command line option.

    ``plot`` is ``True`` iff ``--plot`` was set to anything other than the
    default ``"FALSE"``. Note the fixture is always parametrized (with
    ``[False]`` by default) for tests that request it.
    """
    # The original guard `option_value is not None` was dead code: the value
    # is the result of a `!=` comparison and therefore always a bool.
    if "plot" in metafunc.fixturenames:
        metafunc.parametrize("plot", [metafunc.config.option.plot != "FALSE"])
21 changes: 16 additions & 5 deletions tests/optim/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,20 @@ def step(svi, optimizer):
assert_equal(actual, expected)


def test_centered_clipped_adam(plot_results=False):
def test_centered_clipped_adam(plot):
"""
Test the centered variance option of the ClippedAdam optimizer.
In order to create plots run pytest with the plot command line
option set to True, i.e. by executing
'pytest tests/optim/test_optim.py::test_centered_clipped_adam --plot True'
"""
if not plot:
lr_vec = [0.1, 0.001]
else:
lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]

w = torch.Tensor([1, 500])

def loss_fn(p):
Expand Down Expand Up @@ -484,14 +497,12 @@ def get_convergence_vec(lr_vec, centered_variance):
ultimate_loss_vec.append(ultimate_loss)
convergence_rate_vec.append(convergence_rate)
convergence_iter_vec.append(convergence_iter)
print(lr, centered_variance, ultimate_loss, convergence_rate)
return (
torch.Tensor(ultimate_loss_vec),
torch.Tensor(convergence_rate_vec),
convergence_iter_vec,
)

lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
(
centered_ultimate_loss_vec,
centered_convergence_rate_vec,
Expand All @@ -508,10 +519,10 @@ def get_convergence_vec(lr_vec, centered_variance):
# Verify convergence rate improvement
assert (
(centered_convergence_rate_vec / convergence_rate_vec)
> (torch.Tensor([1.2] * len(lr_vec)).cumprod(0))
> ((0.12 / torch.Tensor(lr_vec)).log() * 1.08)
).all()

if plot_results:
if plot:
from matplotlib import pyplot as plt

plt.figure(figsize=(6, 8))
Expand Down

0 comments on commit 4d8d1c1

Please sign in to comment.