feat(optim): Adadelta RAdam Adamax optimizer support #171

Merged: 39 commits, Jul 22, 2023
Commits
2bbe86e
feat(optim): support adadelta
JieRen98 Jun 18, 2023
d7d8e9c
feat(optim): support adadelta
JieRen98 Jun 19, 2023
6e48ede
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Jun 19, 2023
1af1159
feat(optim): support adadelta
JieRen98 Jun 19, 2023
bc2de8f
feat(optim): support adadelta
JieRen98 Jun 19, 2023
5094ff2
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Jun 19, 2023
ecc0019
feat(optim): support adadelta
JieRen98 Jun 24, 2023
f427051
feat(optim): support RAdam
JieRen98 Jun 29, 2023
b189e19
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Jun 29, 2023
4630e17
Merge branch 'main' into pr/JieRen98/171
XuehaiPan Jul 1, 2023
19ea7e4
fix(optim): fix RAdam
JieRen98 Jul 2, 2023
61479db
fix(optim): fix RAdam
JieRen98 Jul 2, 2023
1ad5985
fix(optim): fix RAdam
JieRen98 Jul 2, 2023
d3c7fda
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Jul 2, 2023
93b4a47
feat(optim): support adamax
JieRen98 Jul 2, 2023
1a3f651
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Jul 2, 2023
30fec6a
fix(optim): fix lint
JieRen98 Jul 3, 2023
eb66f8f
fix(optim): add step counter
JieRen98 Jul 3, 2023
77ab347
fix(optim): fix return type
JieRen98 Jul 3, 2023
545cfe5
fix(optim): fix for None case
JieRen98 Jul 5, 2023
ea449dc
Merge remote-tracking branch 'upstream/main' into main
Benjamin-eecs Jul 22, 2023
67326f3
chore: update CHANGELOG
Benjamin-eecs Jul 22, 2023
8bd00ee
docs: update optim.rst
Benjamin-eecs Jul 22, 2023
537fe63
fix: pass lint
Benjamin-eecs Jul 22, 2023
978f640
docs: update api.rst
Benjamin-eecs Jul 22, 2023
921b76e
docs: update explicit_diff.rst
Benjamin-eecs Jul 22, 2023
47c9ba3
chore: update CHANGELOG
Benjamin-eecs Jul 22, 2023
a31e7da
fix: pass lint
Benjamin-eecs Jul 22, 2023
ce9fbb1
chore: update CHANGELOG
Benjamin-eecs Jul 22, 2023
90b149a
test: update import
Benjamin-eecs Jul 22, 2023
db65c57
fix: update naming
Benjamin-eecs Jul 22, 2023
fcd408a
test: update import
Benjamin-eecs Jul 22, 2023
76d91c9
fix: update docstring
Benjamin-eecs Jul 22, 2023
b5e5b9c
docs: fix naming
Benjamin-eecs Jul 22, 2023
d407d2c
fix: update docstring
Benjamin-eecs Jul 22, 2023
f157e2b
fix: pass lint
Benjamin-eecs Jul 22, 2023
9b3b8c2
fix: pass lint
Benjamin-eecs Jul 22, 2023
e222f6d
fix: pass lint
Benjamin-eecs Jul 22, 2023
3e90e72
chore: remove clang-format in pre-commit
Benjamin-eecs Jul 22, 2023
4 changes: 0 additions & 4 deletions .pre-commit-config.yaml
@@ -24,10 +24,6 @@ repos:
- id: detect-private-key
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v16.0.6
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
hooks:
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Implement `Adadelta`, `RAdam`, `Adamax` optimizer by [@JieRen98](https://github.com/JieRen98) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#171](https://github.com/metaopt/torchopt/pull/171).

### Changed

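
The functional API for the new optimizers follows the same init/update/apply pattern as the existing ones. A minimal sketch mirroring the test loop in `tests/test_alias.py` below; the linear model and random batch here are placeholders:

```python
import functorch
import torch
import torch.nn.functional as F
import torchopt

model = torch.nn.Linear(4, 2)                       # placeholder model
xs, ys = torch.randn(8, 4), torch.randint(2, (8,))  # placeholder batch

fmodel, params = functorch.make_functional(model)
optim = torchopt.adadelta(1e-3, rho=0.9, eps=1e-8, weight_decay=1e-2)
optim_state = optim.init(params)

loss = F.cross_entropy(fmodel(params, xs), ys)
grads = torch.autograd.grad(loss, params)
updates, optim_state = optim.update(grads, optim_state, params=params)
params = torchopt.apply_updates(params, updates)
```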
60 changes: 60 additions & 0 deletions docs/source/api/api.rst
@@ -30,9 +30,12 @@ Functional Optimizers
.. autosummary::

FuncOptimizer
adadelta
adagrad
adam
adamw
adamax
radam
rmsprop
sgd

@@ -42,6 +45,11 @@ Wrapper for Function Optimizer
.. autoclass:: FuncOptimizer
:members:

Functional AdaDelta Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: adadelta

Functional AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -57,6 +65,16 @@ Functional AdamW Optimizer

.. autofunction:: adamw

Functional AdaMax Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: adamax

Functional RAdam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: radam

Functional RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -76,12 +94,23 @@ Classic Optimizers

.. autosummary::

AdaDelta
Adadelta
AdaGrad
Adagrad
Adam
AdamW
AdaMax
Adamax
RAdam
RMSProp
SGD

Classic AdaDelta Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdaDelta

Classic AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -97,6 +126,16 @@ Classic AdamW Optimizer

.. autoclass:: AdamW

Classic AdaMax Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdaMax

Classic RAdam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RAdam

Classic RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -116,12 +155,23 @@ Differentiable Meta-Optimizers

.. autosummary::

MetaAdaDelta
MetaAdadelta
MetaAdaGrad
MetaAdagrad
MetaAdam
MetaAdamW
MetaAdaMax
MetaAdamax
MetaRAdam
MetaRMSProp
MetaSGD

Differentiable Meta-AdaDelta Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaAdaDelta

Differentiable Meta-AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -137,6 +187,16 @@ Differentiable Meta-AdamW Optimizer

.. autoclass:: MetaAdamW

Differentiable Meta-AdaMax Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaAdaMax

Differentiable Meta-RAdam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaRAdam

Differentiable Meta-RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

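
The classic optimizers documented above keep the familiar PyTorch interface (``zero_grad()`` / ``step()``), as noted in the optim.rst changes further down. A hedged sketch, assuming the new classic constructors accept the same hyperparameters as their functional counterparts:

```python
import torch
import torch.nn.functional as F
import torchopt

model = torch.nn.Linear(4, 2)                       # placeholder model
optim = torchopt.AdaDelta(model.parameters(), lr=1e-3, rho=0.9, weight_decay=1e-2)
# torchopt.RAdam and torchopt.Adamax are constructed the same way,
# with betas=(0.9, 0.999) instead of rho.

xs, ys = torch.randn(8, 4), torch.randint(2, (8,))  # placeholder batch
loss = F.cross_entropy(model(xs), ys)

optim.zero_grad()
loss.backward()
optim.step()
```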
6 changes: 6 additions & 0 deletions docs/source/explicit_diff/explicit_diff.rst
@@ -53,9 +53,15 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho
.. autosummary::

torchopt.MetaOptimizer
torchopt.MetaAdaDelta
torchopt.MetaAdadelta
torchopt.MetaAdaGrad
torchopt.MetaAdagrad
torchopt.MetaAdam
torchopt.MetaAdamW
torchopt.MetaAdaMax
torchopt.MetaAdamax
torchopt.MetaRAdam
torchopt.MetaRMSProp
torchopt.MetaSGD

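
For the differentiable meta-optimizers listed above, a sketch under the assumption that the new classes follow the existing :class:`torchopt.MetaOptimizer` interface, where ``step()`` takes the inner loss and differentiably updates the wrapped module's parameters in place:

```python
import torch
import torch.nn.functional as F
import torchopt

net = torch.nn.Linear(4, 2)                         # placeholder inner model
meta_optim = torchopt.MetaAdaDelta(net, lr=1e-2)    # assumed to mirror the MetaOptimizer signature

xs, ys = torch.randn(8, 4), torch.randint(2, (8,))  # placeholder batch
inner_loss = F.cross_entropy(net(xs), ys)
meta_optim.step(inner_loss)  # differentiable in-place update of net's parameters

# Because the update is differentiable, an outer (meta) loss computed from the
# updated net can backpropagate through this step.
outer_loss = F.cross_entropy(net(xs), ys)
```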
9 changes: 9 additions & 0 deletions docs/source/optimizer/optim.rst
@@ -18,9 +18,12 @@ Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`,
.. autosummary::

torchopt.FuncOptimizer
torchopt.adadelta
torchopt.adagrad
torchopt.adam
torchopt.adamw
torchopt.adamax
torchopt.radam
torchopt.rmsprop
torchopt.sgd

@@ -85,9 +88,15 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi
.. autosummary::

torchopt.Optimizer
torchopt.AdaDelta
torchopt.Adadelta
torchopt.AdaGrad
torchopt.Adagrad
torchopt.Adam
torchopt.AdamW
torchopt.AdaMax
torchopt.Adamax
torchopt.RAdam
torchopt.RMSProp
torchopt.SGD

7 changes: 7 additions & 0 deletions docs/source/spelling_wordlist.txt
@@ -175,3 +175,10 @@ ctx
Duchi
invertible
AdaGrad
Adadelta
Zeiler
radam
adamax
RAdam
AdaDelta
AdaMax
171 changes: 171 additions & 0 deletions tests/test_alias.py
@@ -144,6 +144,63 @@ def test_sgd(
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
rho=[0.9, 0.95],
eps=[1e-8],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
use_chain_flat=[True, False],
)
def test_adadelta(
dtype: torch.dtype,
lr: float,
rho: float,
eps: float,
inplace: bool,
weight_decay: float,
use_chain_flat: bool,
) -> None:
_set_use_chain_flat(use_chain_flat)

model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.adadelta(
lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
)
optim_state = optim.init(params)
optim_ref = torch.optim.Adadelta(
model_ref.parameters(),
lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
@@ -210,6 +267,120 @@ def test_adam(
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
betas=[(0.9, 0.999), (0.95, 0.9995)],
eps=[1e-8],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
use_chain_flat=[True, False],
)
def test_radam(
dtype: torch.dtype,
lr: float,
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
use_chain_flat: bool,
) -> None:
_set_use_chain_flat(use_chain_flat)

model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.radam(
lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)
optim_state = optim.init(params)
optim_ref = torch.optim.RAdam(
model_ref.parameters(),
lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
betas=[(0.9, 0.999), (0.95, 0.9995)],
eps=[1e-8],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
use_chain_flat=[True, False],
)
def test_adamax(
dtype: torch.dtype,
lr: float,
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
use_chain_flat: bool,
) -> None:
_set_use_chain_flat(use_chain_flat)

model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.adamax(
lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)
optim_state = optim.init(params)
optim_ref = torch.optim.Adamax(
model_ref.parameters(),
lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
outer_lr=[1e-2, 1e-3, 1e-4],