feat(optim): Adadelta RAdam Adamax optimizer support (#171)
Co-authored-by: Benjamin-eecs <[email protected]>
JieRen98 and Benjamin-eecs committed Jul 22, 2023
1 parent 3f68378 commit ef51cc8
Showing 30 changed files with 1,775 additions and 13 deletions.
4 changes: 0 additions & 4 deletions .pre-commit-config.yaml
@@ -24,10 +24,6 @@ repos:
- id: detect-private-key
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v16.0.6
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
hooks:
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Implement `Adadelta`, `RAdam`, `Adamax` optimizer by [@JieRen98](https://github.com/JieRen98) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#171](https://github.com/metaopt/torchopt/pull/171).

### Changed

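For context, the new functional optimizers follow TorchOpt's existing init/update pattern, as exercised by the tests added in this commit (see tests/test_alias.py below). A minimal sketch of one training step with adadelta; the linear model and random batch are hypothetical placeholders, not part of this commit:

import functorch
import torch
import torch.nn.functional as F
import torchopt

# Hypothetical toy model and batch, for illustration only.
net = torch.nn.Linear(4, 2)
xs, ys = torch.randn(8, 4), torch.randint(2, (8,))

fmodel, params = functorch.make_functional(net)
optim = torchopt.adadelta(lr=1e-3, rho=0.9, weight_decay=1e-2)  # new in this commit
optim_state = optim.init(params)

loss = F.cross_entropy(fmodel(params, xs), ys)
grads = torch.autograd.grad(loss, params)
updates, optim_state = optim.update(grads, optim_state, params=params)
params = torchopt.apply_updates(params, updates)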
60 changes: 60 additions & 0 deletions docs/source/api/api.rst
@@ -30,9 +30,12 @@ Functional Optimizers
.. autosummary::

FuncOptimizer
adadelta
adagrad
adam
adamw
adamax
radam
rmsprop
sgd

@@ -42,6 +45,11 @@ Wrapper for Function Optimizer
.. autoclass:: FuncOptimizer
:members:

Functional AdaDelta Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: adadelta

Functional AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -57,6 +65,16 @@ Functional AdamW Optimizer

.. autofunction:: adamw

Functional AdaMax Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: adamax

Functional RAdam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: radam

Functional RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -76,12 +94,23 @@ Classic Optimizers

.. autosummary::

AdaDelta
Adadelta
AdaGrad
Adagrad
Adam
AdamW
AdaMax
Adamax
RAdam
RMSProp
SGD

Classic AdaDelta Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdaDelta

Classic AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -97,6 +126,16 @@ Classic AdamW Optimizer

.. autoclass:: AdamW

Classic AdaMax Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdaMax

Classic RAdam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RAdam

Classic RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -116,12 +155,23 @@ Differentiable Meta-Optimizers

.. autosummary::

MetaAdaDelta
MetaAdadelta
MetaAdaGrad
MetaAdagrad
MetaAdam
MetaAdamW
MetaAdaMax
MetaAdamax
MetaRAdam
MetaRMSProp
MetaSGD

Differentiable Meta-AdaDelta Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaAdaDelta

Differentiable Meta-AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -137,6 +187,16 @@ Differentiable Meta-AdamW Optimizer

.. autoclass:: MetaAdamW

Differentiable Meta-AdaMax Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaAdaMax

Differentiable Meta-RAdam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaRAdam

Differentiable Meta-RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

6 changes: 6 additions & 0 deletions docs/source/explicit_diff/explicit_diff.rst
@@ -53,9 +53,15 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho
.. autosummary::

torchopt.MetaOptimizer
torchopt.MetaAdaDelta
torchopt.MetaAdadelta
torchopt.MetaAdaGrad
torchopt.MetaAdagrad
torchopt.MetaAdam
torchopt.MetaAdamW
torchopt.MetaAdaMax
torchopt.MetaAdamax
torchopt.MetaRAdam
torchopt.MetaRMSProp
torchopt.MetaSGD

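The Meta* classes added above plug into the differentiable torchopt.MetaOptimizer interface, where step() takes the loss directly and updates the wrapped module's parameters while keeping the computation graph. A minimal sketch, assuming MetaAdaDelta accepts the same constructor arguments as the functional adadelta; the inner model and loss are hypothetical placeholders:

import torch
import torchopt

net = torch.nn.Linear(4, 1)                  # hypothetical inner model
optim = torchopt.MetaAdaDelta(net, lr=1e-2)  # differentiable meta-optimizer added in this commit

xs = torch.randn(8, 4)
inner_loss = net(xs).pow(2).mean()
optim.step(inner_loss)  # differentiable in-place update of net's parameters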
9 changes: 9 additions & 0 deletions docs/source/optimizer/optim.rst
@@ -18,9 +18,12 @@ Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`,
.. autosummary::

torchopt.FuncOptimizer
torchopt.adadelta
torchopt.adagrad
torchopt.adam
torchopt.adamw
torchopt.adamax
torchopt.radam
torchopt.rmsprop
torchopt.sgd

@@ -85,9 +88,15 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi
.. autosummary::

torchopt.Optimizer
torchopt.AdaDelta
torchopt.Adadelta
torchopt.AdaGrad
torchopt.Adagrad
torchopt.Adam
torchopt.AdamW
torchopt.AdaMax
torchopt.Adamax
torchopt.RAdam
torchopt.RMSProp
torchopt.SGD

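Like the existing classic optimizers, the new AdaDelta/Adadelta, AdaMax/Adamax, and RAdam classes expose the familiar zero_grad()/step() workflow mentioned above. A minimal sketch, assuming the constructor mirrors torch.optim.Adamax; the model and loss are hypothetical placeholders:

import torch
import torchopt

net = torch.nn.Linear(4, 1)  # hypothetical model
optim = torchopt.AdaMax(net.parameters(), lr=1e-3, betas=(0.9, 0.999))

xs = torch.randn(8, 4)
loss = net(xs).pow(2).mean()

optim.zero_grad()
loss.backward()
optim.step()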
7 changes: 7 additions & 0 deletions docs/source/spelling_wordlist.txt
@@ -175,3 +175,10 @@ ctx
Duchi
invertible
AdaGrad
Adadelta
Zeiler
radam
adamax
RAdam
AdaDelta
AdaMax
171 changes: 171 additions & 0 deletions tests/test_alias.py
@@ -144,6 +144,63 @@ def test_sgd(
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
rho=[0.9, 0.95],
eps=[1e-8],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
use_chain_flat=[True, False],
)
def test_adadelta(
dtype: torch.dtype,
lr: float,
rho: float,
eps: float,
inplace: bool,
weight_decay: float,
use_chain_flat: bool,
) -> None:
_set_use_chain_flat(use_chain_flat)

model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.adadelta(
lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
)
optim_state = optim.init(params)
optim_ref = torch.optim.Adadelta(
model_ref.parameters(),
lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
@@ -210,6 +267,120 @@ def test_adam(
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
betas=[(0.9, 0.999), (0.95, 0.9995)],
eps=[1e-8],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
use_chain_flat=[True, False],
)
def test_radam(
dtype: torch.dtype,
lr: float,
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
use_chain_flat: bool,
) -> None:
_set_use_chain_flat(use_chain_flat)

model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.radam(
lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)
optim_state = optim.init(params)
optim_ref = torch.optim.RAdam(
model_ref.parameters(),
lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
betas=[(0.9, 0.999), (0.95, 0.9995)],
eps=[1e-8],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
use_chain_flat=[True, False],
)
def test_adamax(
dtype: torch.dtype,
lr: float,
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
use_chain_flat: bool,
) -> None:
_set_use_chain_flat(use_chain_flat)

model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.adamax(
lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)
optim_state = optim.init(params)
optim_ref = torch.optim.Adamax(
model_ref.parameters(),
lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
outer_lr=[1e-2, 1e-3, 1e-4],