Implement Optax-based LBFGS #749

Draft
wants to merge 2 commits into main
171 changes: 165 additions & 6 deletions blackjax/optimizers/lbfgs.py
Expand Up @@ -17,6 +17,8 @@
import jax.numpy as jnp
import jax.random
import jaxopt
import optax
import optax.tree_utils as otu
from jax import lax
from jax.flatten_util import ravel_pytree
from jaxopt._src.lbfgs import LbfgsState
Expand All @@ -37,6 +39,22 @@
MIN_STEP_SIZE = 1e-3


class _OptaxLBFGSHistory(NamedTuple):
x: Array
f: Array
g: Array
alpha: Array
update_mask: Array
# store intermediate values to perform checks
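# s/z: rolling memories of parameter and gradient differences; s_l/z_l: the
# entries written at the current step; last: slot index written; iter: step count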
not_converged: Array
s: Array
z: Array
s_l: Array
z_l: Array
last: Array
iter: Array


class LBFGSHistory(NamedTuple):
"""Container for the optimization path of a L-BFGS run

@@ -60,6 +78,7 @@ class LBFGSHistory(NamedTuple):
g: Array
alpha: Array
update_mask: Array
not_converged: Array  # used to clip the history for shorter inverse-Hessian calculations and BFGS sampling


def minimize_lbfgs(
@@ -148,6 +167,143 @@ def minimize_lbfgs(
return last_step, history


def optax_lbfgs(
fun: Callable,
x0: Array,
maxiter: int,
maxcor: int,
gtol: float,
ftol: float,
maxls: int,
# **lbfgs_kwargs,  # TODO: pass kwargs through to optax.scale_by_zoom_linesearch and optax.value_and_grad_from_state
):
linesearch = optax.scale_by_zoom_linesearch(
max_linesearch_steps=maxls,
verbose=True,
)
solver = optax.lbfgs(
memory_size=maxcor,
linesearch=linesearch,
)
value_and_grad_fun = optax.value_and_grad_from_state(fun)

def lbfgs_one_step(carry, i):
# the optax.lbfgs state is a 3-element tuple; state[0] holds the L-BFGS memory
# and state[2] the line-search value/grad/info cache
(params, state), previous_history = carry
value, grad = value_and_grad_fun(params, state=state)
updates, next_state = solver.update(
grad, state, params, value=value, grad=grad, value_fn=fun
)

# cast num_linesearch_steps to a fixed dtype so both lax.cond branches return identical pytrees
info = next_state[2].info._replace(
num_linesearch_steps=jnp.asarray(
next_state[2].info.num_linesearch_steps, dtype=jnp.int32
)
)

next_state = (next_state[0], next_state[1], next_state[2]._replace(info=info))

# L-BFGS uses a rolling history; compute the memory slot index for this step.
iter = state[0].count
# `last` is the memory slot where the most recent update was written
last = jnp.max(jnp.array([iter - 1, 0], dtype=jnp.int32)) % maxcor
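# e.g. with maxcor=6 the slot cycles 0, 0, 1, ..., 5, 0, 1, ... as iter goes 0, 1, 2, ..., 6, 7, 8, ...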
next_params = optax.apply_updates(params, updates)

# Recover alpha and update mask
s_l = next_state[0].diff_params_memory[last]
z_l = next_state[0].diff_updates_memory[last]
alpha_lm1 = previous_history.alpha
alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l)

# TODO: check correct calc for g
# g = next_state[2].grad
# g = state[2].grad
# g = grad
# g = previous_history.g
# g = previous_history.g + z_l
# g = state[2].grad + z_l

not_converged = check_convergence(state, next_state, iter)
history = _OptaxLBFGSHistory(
x=next_params,
f=next_state[2].value,
g=next_state[2].grad,
alpha=alpha_l,
update_mask=mask_l,
not_converged=not_converged,
s=next_state[0].diff_params_memory,
z=next_state[0].diff_updates_memory,
s_l=s_l,
z_l=z_l,
last=jnp.asarray(last, dtype=jnp.int32),
iter=jnp.asarray(iter, dtype=jnp.int32),
)
return ((next_params, next_state), history), not_converged

def check_convergence(state, next_state, iter):
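# not converged while ||grad|| > gtol, the relative objective change
# |f_k - f_{k+1}| / max(|f_k|, |f_{k+1}|, 1) > ftol, and iter < maxiter;
# iteration 0 always counts as not converged so the scan takes at least one step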
f_delta = (
jnp.abs(state[2].value - next_state[2].value)
/ jnp.asarray(
[jnp.abs(state[2].value), jnp.abs(next_state[2].value), 1.0]
).max()
)
next_state_grad = otu.tree_get(next_state[2], "grad")
error = otu.tree_l2_norm(next_state_grad)
return jnp.array(
(iter == 0) | ((error > gtol) & (f_delta > ftol) & (iter < maxiter)),
dtype=bool,
)

def non_op(carry, i):
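# no-op branch for lax.cond once converged; it only normalizes the dtype of
# num_linesearch_steps so both branches return pytrees with identical structure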
(params, state), previous_history = carry

info = state[2].info._replace(
num_linesearch_steps=jnp.asarray(
state[2].info.num_linesearch_steps, dtype=jnp.int32
)
)
state = (state[0], state[1], state[2]._replace(info=info))

return ((params, state), previous_history), jnp.array(False, dtype=bool)

def scan_body(tup, i):
carry, not_converged = tup
next_tup = jax.lax.cond(not_converged, lbfgs_one_step, non_op, carry, i)
return next_tup, next_tup[0][-1]

init_state = solver.init(x0)
init_history = _OptaxLBFGSHistory(
x=init_state[0].params,
f=init_state[2].value,
g=init_state[2].grad,
alpha=jnp.ones_like(x0),
update_mask=jnp.zeros_like(x0, dtype=bool),
not_converged=jnp.array(True, dtype=bool),
s=init_state[0].diff_params_memory,
z=init_state[0].diff_updates_memory,
s_l=jnp.zeros_like(x0),
z_l=jnp.zeros_like(x0),
last=jnp.asarray(-1, dtype=jnp.int32),
iter=jnp.asarray(-1, dtype=jnp.int32),
)

# Use lax.scan to accumulate history
(((final_params, final_state), _), _), history = jax.lax.scan(
scan_body,
(((x0, init_state), init_history), True),
jnp.arange(maxiter),
length=maxiter,
)

history = jax.tree.map(
lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
init_history,
history,
)
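# the stacked history now has one extra leading row for the initial state
# (maxiter + 1 rows in total), so the path includes the starting point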
return (final_params, final_state), history


def _minimize_lbfgs(
fun: Callable,
x0: Array,
@@ -181,19 +337,21 @@ def lbfgs_one_step(carry, i):
alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l)

current_grad = previous_history.g + z_l

# check convergence
f_delta = (
jnp.abs(state.value - next_state.value)
/ jnp.asarray([jnp.abs(state.value), jnp.abs(next_state.value), 1.0]).max()
)
not_converged = (next_state.error > gtol) & (f_delta > ftol) & (i < maxiter)
history = LBFGSHistory(
x=next_params,
f=next_state.value,
g=current_grad,
alpha=alpha_l,
update_mask=mask_l,
not_converged=jnp.array(not_converged, dtype=bool),
)
return (OptStep(params=next_params, state=next_state), history), not_converged

def non_op(carry, it):
@@ -224,6 +382,7 @@ def scan_body(tup, it):
g=grad0,
alpha=jnp.ones_like(x0),
update_mask=jnp.zeros_like(x0, dtype=bool),
not_converged=jnp.array(True, dtype=bool),
)

((last_step, _), _), history = lax.scan(
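For context, a minimal usage sketch of the new optax_lbfgs helper (not part of the diff; the Rosenbrock objective and argument values mirror the test added below, and the import path assumes the module location shown above):

import jax.numpy as jnp
from blackjax.optimizers.lbfgs import optax_lbfgs

def rosenbrock(w):
    # same objective as example_fun in tests/optimizers/test_optimizers.py
    return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)

(final_params, final_state), history = optax_lbfgs(
    rosenbrock,
    jnp.zeros((8,)),
    maxiter=1000,
    maxcor=6,
    gtol=1e-8,
    ftol=1e-5,
    maxls=1000,
)

# history stacks maxiter + 1 entries (initial state first); the number of steps
# actually taken can be read off as history.not_converged.sum().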
13 changes: 6 additions & 7 deletions blackjax/vi/pathfinder.py
Expand Up @@ -138,7 +138,11 @@ def approximate(
**lbfgs_kwargs,
)

# Get positions and gradients of the optimization path (including the starting point).
# index of the first converged entry in the history (= number of not-yet-converged entries)
lbfgs_converged_idx = history.not_converged.sum()
# truncate history to the point of convergence
history = jax.tree.map(lambda x: x[:lbfgs_converged_idx], history)

position = history.x
grad_position = history.g
alpha = history.alpha
@@ -172,19 +176,14 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):
# Index and reshape S and Z to be sliding window view shape=(maxiter,
# maxcor, param_dim), so we can vmap over all the iterations.
# This is in effect numpy.lib.stride_tricks.sliding_window_view
path_size = maxiter + 1
path_size = lbfgs_converged_idx
index = jnp.arange(path_size)[:, None] + jnp.arange(maxcor)[None, :]
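# e.g. path_size=3, maxcor=2 -> index = [[0, 1], [1, 2], [2, 3]]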
s_j = s_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
z_j = z_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
rng_keys = jax.random.split(rng_key, path_size)
elbo, beta, gamma = jax.vmap(path_finder_body_fn)(
rng_keys, s_j, z_j, alpha, position, grad_position
)
elbo = jnp.where(
(jnp.arange(path_size) < (status.iter_num)) & jnp.isfinite(elbo),
elbo,
-jnp.inf,
)

unravel_fn_mapped = jax.vmap(unravel_fn)
pathfinder_result = PathfinderState(
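A small toy illustration of the truncation added to approximate above (hypothetical history pytree, not part of the diff):

import jax
import jax.numpy as jnp

# pretend 5 steps were recorded and only the first 3 were still "not converged"
toy_history = {
    "x": jnp.arange(5.0),
    "not_converged": jnp.array([True, True, True, False, False]),
}
idx = toy_history["not_converged"].sum()  # 3
toy_history = jax.tree.map(lambda leaf: leaf[:idx], toy_history)
# every leaf is clipped to its first 3 entries, so downstream code only sees
# the part of the optimization path preceding convergence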
58 changes: 57 additions & 1 deletion tests/optimizers/test_optimizers.py
@@ -1,4 +1,5 @@
"""Test optimizers."""

import functools

import chex
@@ -17,6 +18,7 @@
lbfgs_inverse_hessian_formula_2,
lbfgs_recover_alpha,
minimize_lbfgs,
optax_lbfgs,
)


@@ -154,5 +156,59 @@ def loss_fn(x):
np.testing.assert_allclose(inv_hess_1, inv_hess_2, rtol=0.01)


class TestOptaxLBFGS(chex.TestCase):
def test_optax_lbfgs(
self,
maxcor=6,
maxiter=1000,
ftol=1e-5,
gtol=1e-8,
maxls=1000,
):
"""Test the optax_lbfgs function for consistency in history and convergence."""

def example_fun(w):
return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)

x0_example = jnp.zeros((8,))

(final_params, final_state), history = optax_lbfgs(
example_fun,
x0_example,
maxcor=maxcor,
maxiter=maxiter,
ftol=ftol,
gtol=gtol,
maxls=maxls,
)

# check the rolling history: consecutive snapshots of the s memory should
# agree everywhere except the slot written at step l
L = history.iter.shape[0]
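# the first row is the prepended initial state, so rows 1..L-1 correspond to scan iterations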

for l in range(1, L):
last = history.last[l]
current_s = history.s[l]
sml = jnp.delete(current_s, last, axis=0)

previous_s = history.s[l - 1]
previous_sml = jnp.delete(previous_s, last, axis=0)

np.testing.assert_allclose(
previous_sml,
sml,
err_msg=f"l = {l}, last = {last}, previous_sml = {previous_sml}, sml = {sml}",
)

# check convergence to the known minimum of the Rosenbrock function (all ones)
expected_solution = jnp.ones((8,))
np.testing.assert_allclose(
final_params,
expected_solution,
rtol=1e-2,
err_msg="Final parameters did not converge to expected solution.",
)


if __name__ == "__main__":
absltest.main()