
Replace OptaxOptimistixProximalGradient with a port of jaxopt.ProximalGradient #380

@bagibence

Description

As Optimistix doesn't implement a proximal gradient solver, the Optimistix backend currently uses a custom implementation, OptaxOptimistixProximalGradient, which takes an Optax SGD step with a line search and then applies the proximal operator.
This seems to work in practice, but it is not as theoretically sound as jaxopt.ProximalGradient, which implements the FISTA algorithm.
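For reference, here is a minimal sketch of FISTA in JAX. It assumes a fixed step size (no backtracking line search) and uses an L1 penalty as the nonsmooth term, so its proximal operator is elementwise soft-thresholding. `fista` and `soft_threshold` are hypothetical names for illustration only, not JAXopt's or Optimistix's API.

```python
import jax
import jax.numpy as jnp


def soft_threshold(x, thresh):
    # Proximal operator of thresh * ||x||_1 (elementwise soft-thresholding).
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - thresh, 0.0)


def fista(fun, x0, l1reg, stepsize, maxiter=500):
    """Minimise fun(x) + l1reg * ||x||_1 with FISTA and a fixed step size.

    Assumes fun is smooth and stepsize <= 1 / L, where L is the Lipschitz
    constant of grad(fun). Hypothetical helper, for illustration only.
    """
    grad_fun = jax.grad(fun)

    def body(_, state):
        x, y, t = state
        # Proximal gradient step taken from the extrapolated point y.
        x_new = soft_threshold(y - stepsize * grad_fun(y), stepsize * l1reg)
        # FISTA (Nesterov) momentum: update t and extrapolate.
        t_new = (1.0 + jnp.sqrt(1.0 + 4.0 * t**2)) / 2.0
        y_new = x_new + ((t - 1.0) / t_new) * (x_new - x)
        return x_new, y_new, t_new

    t0 = jnp.asarray(1.0, dtype=x0.dtype)
    x, _, _ = jax.lax.fori_loop(0, maxiter, body, (x0, x0, t0))
    return x


# Usage: for fun(x) = 0.5 * ||x - b||^2 the exact minimiser is
# soft_threshold(b, l1reg), so x_hat should be close to [2.0, 0.0, 0.5].
b = jnp.array([3.0, -0.5, 1.5])
x_hat = fista(lambda x: 0.5 * jnp.sum((x - b) ** 2), jnp.zeros(3),
              l1reg=1.0, stepsize=0.5)
```

Dropping the momentum update (always taking the prox-gradient step from `x` instead of `y`) gives plain, unaccelerated proximal gradient descent. JAXopt's solver additionally accepts an arbitrary prox function and can pick the step size by backtracking; both are left out of this sketch.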

To be able to eventually stop vendoring (and drop support for) JAXopt, we decided to add a port of JAXopt's ProximalGradient that adheres to Optimistix's interface, and to use it in place of OptaxOptimistixProximalGradient when the Optimistix backend is selected.
