Conversation

@bagibence
Collaborator

As Optimistix doesn't implement ProximalGradient, this is currently handled by a custom implementation, OptimistixOptaxProximalGradient, based on Optax's SGD with a linesearch followed by the proximal operator.
This works in practice, but it doesn't support the original Nesterov acceleration and is not as theoretically sound as jaxopt.ProximalGradient, which implements the FISTA algorithm.

This PR adds OptimistixFISTA, a port of JAXopt's ProximalGradient that adheres to Optimistix's interface, and replaces OptimistixOptaxProximalGradient with it when using the Optimistix backend.

Similarly, to support the original Nesterov acceleration in gradient descent, OptimistixNAG (Nesterov's Accelerated Gradient), a small modification of OptimistixFISTA, is added.
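
For context, here is a minimal, illustrative sketch of the FISTA-style update that OptimistixFISTA is built around (the function and argument names are mine, not the OptimistixFISTA API; the real solver also runs a linesearch to pick the stepsize):

import jax.numpy as jnp

def fista_step(x, y, t, stepsize, grad_fn, prox):
    # Proximal gradient step taken at the extrapolated point y.
    x_next = prox(y - stepsize * grad_fn(y), stepsize)
    # Nesterov momentum coefficient update.
    t_next = 0.5 * (1.0 + jnp.sqrt(1.0 + 4.0 * t**2))
    # Extrapolate using the last two iterates.
    y_next = x_next + ((t - 1.0) / t_next) * (x_next - x)
    return x_next, y_next, t_next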

Additional changes:

  • Remove OptimistixOptaxProximalGradient
  • Remove stateful_scale_by_learning_rate (and related code) which was only needed for OptimistixOptaxProximalGradient
  • Switch OptimistixOptaxGradientDescent to use optax.scale_by_backtracking_linesearch (see the sketch after this list)
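
As a rough sketch (not the actual nemos code), chaining Optax's SGD with the backtracking linesearch transform looks roughly like this; the loss function below is a toy placeholder:

import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
    # Toy quadratic loss used only for illustration.
    return jnp.sum((params - 1.0) ** 2)

params = jnp.zeros(3)
solver = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
opt_state = solver.init(params)

value, grad = jax.value_and_grad(loss_fn)(params)
# The linesearch transform needs the current value, gradient, and objective.
updates, opt_state = solver.update(
    grad, opt_state, params, value=value, grad=grad, value_fn=loss_fn
)
params = optax.apply_updates(params, updates)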

OptimistixOptaxGradientDescent is kept in the codebase, but the tests now use OptimistixNAG in its place. Ideally, it would be tested the same way as OptimistixNAG.
@BalzaniEdoardo, @billbrod is that fine with you, or would you prefer a way to include it in the tests without unnecessarily rerunning the others?

Fixes #380

@@ -0,0 +1,333 @@
"""Adaptation of JAXopt's ProximalGradient (FISTA) as an Optimistix IterativeSolver."""
Collaborator

I would rename this to _fista_linesearch.py or something like that. It's true that the algorithm is inspired by the jaxopt implementation, but you changed enough of it to match the optimistix API that we can omit the fact that it is a port.

Collaborator Author

I renamed it to _fista.py and rephrased the module's docstring, keeping "Adapted from JAXopt." to credit them.

**_common_solvers,
"GradientDescent": nmo.solvers.OptimistixOptaxGradientDescent,
"ProximalGradient": nmo.solvers.OptimistixOptaxProximalGradient,
# TODO: OptimistixOptaxGradientDescent is not tested
Collaborator
@BalzaniEdoardo Oct 30, 2025

If I understand correctly, the default now uses OptimistixNAG, but we could still use the Optax GD if we wanted to.

One way to test this is to use unittest.mock to patch the solver registry in the test:

from unittest.mock import patch

import nemos as nmo

# In your test: temporarily swap the registry entry for "GradientDescent"
with patch.dict('nemos.solvers.solver_registry',
                {'GradientDescent': nmo.solvers.OptimistixOptaxGradientDescent}):
    model = Model(solver_name='GradientDescent', ...)
    model.instantiate_solver()

    assert isinstance(model._solver, nmo.solvers.OptimistixOptaxGradientDescent)

Collaborator Author

The problem with that is that it would have to be done for every test that uses gradient descent.
I added a solution where only tests that explicitly reference "GradientDescent" are rerun with the OptimistixOptaxGradientDescent implementation, and added it as a step to the CI job for the Optimistix backend.

Collaborator Author

Changed this so that it is now controlled by pytest -o override_solver=GradientDescent:OptimistixOptaxGradientDescent, and both implementations are tested in the same tox environment.
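
For readers following along, here is a hypothetical sketch of how such a pytest override could be wired up in conftest.py; the option name matches the comment above, but the hook body and registry lookup are illustrative, not necessarily the actual nemos test code:

# conftest.py (illustrative; not the actual nemos test setup)
from unittest.mock import patch

import pytest

def pytest_addoption(parser):
    # Register an ini option so it can be set via `pytest -o override_solver=...`.
    parser.addini(
        "override_solver",
        help="Swap a registry entry, e.g. GradientDescent:OptimistixOptaxGradientDescent",
        default="",
    )

@pytest.fixture(autouse=True)
def override_solver_registry(request):
    spec = request.config.getini("override_solver")
    if not spec:
        yield
        return
    name, replacement = spec.split(":")
    import nemos as nmo
    # Patch the registry for the duration of each test (registry path taken
    # from the suggestion above; assumed, not verified against nemos).
    with patch.dict(nmo.solvers.solver_registry, {name: getattr(nmo.solvers, replacement)}):
        yield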
