Add port of FISTA and NAG from JAXopt #411
base: development
Conversation
In JAXopt solvers it was already removed. It is created by the regularizer, not settable by users.
Also correct OptimistixOptax... names in the dev notes
Also remove unused code.
src/nemos/solvers/_fista_port.py (outdated)

```
@@ -0,0 +1,333 @@
"""Adaptation of JAXopt's ProximalGradient (FISTA) as an Optimistix IterativeSolver."""
```
I would rename this to `_fista_linesearch.py` or something like that. It's true that the algorithm is inspired by the JAXopt implementation, but you changed enough of it to match the Optimistix API that we can omit the fact that it is a port.
I renamed it to _fista.py and rephrased the module's docstring, keeping "Adapted from JAXopt." to credit them.
tests/conftest.py (outdated)

```python
**_common_solvers,
"GradientDescent": nmo.solvers.OptimistixOptaxGradientDescent,
"ProximalGradient": nmo.solvers.OptimistixOptaxProximalGradient,
# TODO: OptaxOptimistixGradientDescent is not tested
```
If I understand correctly, the default is to use OptimistixNAG, but we could still use the Optax GD if one wants to.
One way to test this is to use `unittest.mock` to patch the solver registry in the test:
```python
from unittest.mock import patch

import nemos as nmo

# In your test
with patch.dict(
    'nemos.solvers.solver_registry',
    {'GradientDescent': nmo.solvers.OptimistixOptaxGradientDescent},
):
    model = Model(solver_name='GradientDescent', ...)
    model.instantiate_solver()
    assert isinstance(model._solver, nmo.solvers.OptimistixOptaxGradientDescent)
```
The problem with that is that it would have to be done for every test that uses gradient descent.
I added a solution where only the tests that explicitly reference "GradientDescent" are rerun with the `OptimistixOptaxGradientDescent` implementation, and added that as a step to the CI job for the Optimistix backend.
Changed this so that it is now controlled by `pytest -o override_solver=GradientDescent:OptimistixOptaxGradientDescent`, and testing both implementations is done in the same tox environment.
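For anyone reading along, here is one way such an override could be wired up in `conftest.py`. This is an illustrative sketch only: the option name and the `name:class` format come from the comment above, while the hook and fixture names are assumptions and may differ from the actual implementation.

```python
# Hypothetical conftest.py sketch: read `-o override_solver=<name>:<class>`
# and expose it to tests as a mapping. Not the PR's actual code.
import pytest

import nemos as nmo


def pytest_addoption(parser):
    # Register the ini option so it can be set on the command line via `-o`.
    parser.addini(
        "override_solver",
        help="Swap a solver in the tests, e.g. GradientDescent:OptimistixOptaxGradientDescent",
        default="",
    )


@pytest.fixture(scope="session")
def solver_overrides(pytestconfig):
    """Return a {solver_name: solver_class} mapping parsed from the ini option."""
    raw = pytestconfig.getini("override_solver")
    if not raw:
        return {}
    name, cls_name = raw.split(":")
    return {name: getattr(nmo.solvers, cls_name)}
```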
As Optimistix doesn't implement `ProximalGradient`, this is currently set to use a custom implementation in `OptaxOptimistixProximalGradient` based on Optax's SGD with linesearch followed by the proximal operator. This seems to work in practice, but doesn't support the original Nesterov acceleration and is not as theoretically sound as `jaxopt.ProximalGradient`, which implements the FISTA algorithm.

This PR adds `OptimistixFISTA`, a port of JAXopt's `ProximalGradient` that adheres to Optimistix's interface, and replaces `OptaxOptimistixProximalGradient` with it when using the Optimistix backend. Similarly, to support the original Nesterov acceleration in gradient descent, a short modification of this is added as `OptimistixNAG`. NAG stands for Nesterov's Accelerated Gradient.
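For context, a single FISTA iteration (the accelerated proximal-gradient update that `OptimistixFISTA` is based on, following Beck & Teboulle / JAXopt) looks roughly like the sketch below. This is illustrative only; `grad_f`, `prox`, and `stepsize` are placeholders, not names from the PR.

```python
import jax.numpy as jnp


def fista_step(x, y, t, grad_f, prox, stepsize):
    """One accelerated proximal-gradient (FISTA) update."""
    # Proximal gradient step taken from the extrapolated point y.
    x_next = prox(y - stepsize * grad_f(y), stepsize)
    # Nesterov momentum schedule.
    t_next = 0.5 * (1.0 + jnp.sqrt(1.0 + 4.0 * t**2))
    # Extrapolate for the next iterate; with an identity prox this
    # reduces to Nesterov's accelerated gradient (the OptimistixNAG case).
    y_next = x_next + ((t - 1.0) / t_next) * (x_next - x)
    return x_next, y_next, t_next
```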
Additional changes:

- remove `OptimistixOptaxProximalGradient`
- remove `stateful_scale_by_learning_rate` (and related code), which was only needed for `OptimistixOptaxProximalGradient`
- change `OptimistixOptaxGradientDescent` to use `optax.scale_by_backtracking_linesearch` (a rough sketch of that setup follows below)
- `OptimistixOptaxGradientDescent` is kept but is replaced in the tests. Ideally, it would be tested the same way as `OptimistixNAG`.

@BalzaniEdoardo, @billbrod is that fine with you, or do we need a way to include this in the tests without unnecessarily rerunning the others?
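For reference, the backtracking-linesearch combination mentioned in the list above can be assembled in Optax roughly as follows. This is an illustrative sketch with a toy objective, not the actual `OptimistixOptaxGradientDescent` code.

```python
import jax
import jax.numpy as jnp
import optax


def loss_fn(params):
    # Toy quadratic objective standing in for a model's loss.
    return jnp.sum((params - 1.0) ** 2)


params = jnp.zeros(3)
solver = optax.chain(
    optax.sgd(learning_rate=1.0),  # base update direction
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
state = solver.init(params)

for _ in range(10):
    value, grad = jax.value_and_grad(loss_fn)(params)
    # The linesearch transform consumes value, grad, and value_fn as extra args.
    updates, state = solver.update(
        grad, state, params, value=value, grad=grad, value_fn=loss_fn
    )
    params = optax.apply_updates(params, updates)
```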
Fixes #380