
# AdaHessian optimizer for Jax and Flax

Jax and Flax implementations of the AdaHessian optimizer, a second-order optimizer for neural networks.

## Installation

You can install this library with:

```
pip install git+https://github.com/nestordemeure/AdaHessianJax.git
```

## Using AdaHessian with Jax

The implementation provides both a fast way to evaluate the diagonal of the Hessian of a function and an optimizer API that stays close to `jax.experimental.optimizers`, in the `adahessianJax.jaxOptimizer` namespace:

```python
import jax
# gets the jax.experimental.optimizers version of the optimizer
from adahessianJax import grad_and_hessian
from adahessianJax.jaxOptimizer import adahessian

# builds the usual (init, update, get_params) optimizer triple, no need to pass a learning rate
opt_init, opt_update, get_params = adahessian()

# initializes the optimizer with the initial value of the parameters to optimize
opt_state = opt_init(init_params)
rng = jax.random.PRNGKey(0)

# uses the optimizer; note that we pass the gradient AND a Hessian
params = get_params(opt_state)
rng, rng_step = jax.random.split(rng)
gradient, hessian = grad_and_hessian(loss, (params, batch), rng_step)
opt_state = opt_update(i, gradient, hessian, opt_state)
```
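
The step above slots into an ordinary training loop and can usually be wrapped in `jax.jit`. Here is a minimal sketch, assuming a hypothetical `loss(params, batch)` function and `batches` data iterator, and assuming the Hessian estimation traces cleanly under `jit` in your setting:

```python
import jax

@jax.jit
def update_step(i, opt_state, batch, rng_step):
    # one optimization step: estimate the gradient and Hessian diagonal, then update
    params = get_params(opt_state)
    gradient, hessian = grad_and_hessian(loss, (params, batch), rng_step)
    return opt_update(i, gradient, hessian, opt_state)

for i, batch in enumerate(batches):  # `batches` is a hypothetical data iterator
    rng, rng_step = jax.random.split(rng)
    opt_state = update_step(i, opt_state, batch, rng_step)
```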

The example folder contains JAX's MNIST classification example, updated so it can be run with either Adam or AdaHessian in order to compare the two optimizers.

## Using AdaHessian with Flax

The implementation provides both a fast way to evaluate the diagonal of the Hessian of a function and an optimizer API that stays close to Flax optimizers, in the `adahessianJax.flaxOptimizer` namespace:

```python
import jax
# gets the Flax version of the optimizer
from adahessianJax import grad_and_hessian
from adahessianJax.flaxOptimizer import Adahessian

# defines the optimizer, no need to pass a learning rate
optimizer_def = Adahessian()

# initializes the optimizer with the initial value of the parameters to optimize
optimizer = optimizer_def.create(init_params)
rng = jax.random.PRNGKey(0)

# uses the optimizer; note that we pass the gradient AND a Hessian
params = optimizer.target
rng, rng_step = jax.random.split(rng)
gradient, hessian = grad_and_hessian(loss, (params, batch), rng_step)
optimizer = optimizer.apply_gradient(gradient, hessian)
```
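
As with the Jax version, the whole step can typically be jitted; a minimal sketch under the same assumptions (hypothetical `loss(params, batch)` and `batches` iterator):

```python
import jax

@jax.jit
def train_step(optimizer, batch, rng_step):
    # estimates the gradient and Hessian diagonal, then applies one update
    gradient, hessian = grad_and_hessian(loss, (optimizer.target, batch), rng_step)
    return optimizer.apply_gradient(gradient, hessian)

for batch in batches:  # `batches` is a hypothetical data iterator
    rng, rng_step = jax.random.split(rng)
    optimizer = train_step(optimizer, batch, rng_step)
```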

## Documentation

### jax.adahessian

| Argument | Description |
|----------|-------------|
| `step_size` (float, optional) | learning rate (default: `1e-3`) |
| `b1` (float, optional) | the exponential decay rate for the first moment estimates (default: `0.9`) |
| `b2` (float, optional) | the exponential decay rate for the squared Hessian estimates (default: `0.999`) |
| `eps` (float, optional) | term added to the denominator to improve numerical stability (default: `1e-8`) |
| `weight_decay` (float, optional) | weight decay (L2 penalty) (default: `0.0`) |
| `hessian_power` (float, optional) | Hessian power (default: `1.0`) |

Returns an `(init_fun, update_fun, get_params)` triple of functions modeling the optimizer, similar to the `jax.experimental.optimizers` API. The only difference is that `update_fun` takes both a `gradient` and a `hessian` parameter.
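
All arguments are optional, so hyperparameters can be overridden selectively; a short sketch (the values below are purely illustrative, not tuned recommendations):

```python
from adahessianJax.jaxOptimizer import adahessian

# illustrative hyperparameters, not tuned recommendations
opt_init, opt_update, get_params = adahessian(step_size=1e-2, weight_decay=1e-4)
```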

### flax.Adahessian

| Argument | Description |
|----------|-------------|
| `learning_rate` (float, optional) | learning rate (default: `1e-3`) |
| `beta1` (float, optional) | the exponential decay rate for the first moment estimates (default: `0.9`) |
| `beta2` (float, optional) | the exponential decay rate for the squared Hessian estimates (default: `0.999`) |
| `eps` (float, optional) | term added to the denominator to improve numerical stability (default: `1e-8`) |
| `weight_decay` (float, optional) | weight decay (L2 penalty) (default: `0.0`) |
| `hessian_power` (float, optional) | Hessian power (default: `1.0`) |

Returns an optimizer definition, similar to the Flax optimizers API. The only difference is that `apply_gradient` takes both a `gradient` and a `hessian` parameter.
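
As with the Jax version, hyperparameters can be overridden at construction time; a short sketch (the values below are purely illustrative):

```python
from adahessianJax.flaxOptimizer import Adahessian

# illustrative hyperparameters, not tuned recommendations
optimizer_def = Adahessian(learning_rate=1e-2, weight_decay=1e-4)
optimizer = optimizer_def.create(init_params)  # init_params from your model
```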

### grad_and_hessian

| Argument | Description |
|----------|-------------|
| `fun` (Callable) | function to be differentiated |
| `fun_input` (Tuple) | value at which the gradient and Hessian of `fun` should be evaluated |
| `rng` (ndarray) | a PRNGKey used as the random key |
| `argnum` (int, optional) | specifies which positional argument to differentiate with respect to (default: `0`) |

Returns a pair `(gradient, hessian)` where the first element is the gradient of `fun` evaluated at `fun_input` and the second element is the diagonal of its Hessian.

This function expects `fun_input` to be a tuple and will fail otherwise. One can pass `(fun_input,)` if `fun` has a single input that is not already a tuple.
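
For example, for a hypothetical single-argument `loss(params)`:

```python
import jax
from adahessianJax import grad_and_hessian

rng = jax.random.PRNGKey(0)
# `params` is the single input of `loss`, so it is wrapped in a 1-tuple
gradient, hessian = grad_and_hessian(loss, (params,), rng)
```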

### value_grad_and_hessian

| Argument | Description |
|----------|-------------|
| `fun` (Callable) | function to be differentiated |
| `fun_input` (Tuple) | value at which the gradient and Hessian of `fun` should be evaluated |
| `rng` (ndarray) | a PRNGKey used as the random key |
| `argnum` (int, optional) | specifies which positional argument to differentiate with respect to (default: `0`) |

Returns a triplet `(value, gradient, hessian)` where the first element is the value of `fun` evaluated at `fun_input`, the second element is its gradient, and the third element is the diagonal of its Hessian.

This function expects `fun_input` to be a tuple and will fail otherwise. One can pass `(fun_input,)` if `fun` has a single input that is not already a tuple.
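
This variant is convenient when the loss value is also needed, for example for logging. A sketch, assuming `value_grad_and_hessian` is exported alongside `grad_and_hessian` and reusing the hypothetical `loss` and `batch` from above:

```python
import jax
from adahessianJax import value_grad_and_hessian

rng = jax.random.PRNGKey(0)
value, gradient, hessian = value_grad_and_hessian(loss, (params, batch), rng)
print(f"current loss: {value}")
```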