Skip to content

Commit

Permalink
MCLMC adaptation total num steps and initial guess (#778)
Browse files Browse the repository at this point in the history
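* Return total_num_tuning_integrator_steps, accumulated per tuning stage that actually runs, instead of num_steps * (frac_tune1 + frac_tune2 + frac_tune3)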

* Accept optional initial `params` for MCLMC adaptation (the previous default initial guess is used when `params` is None)
  • Loading branch information
hsimonfroy authored Feb 18, 2025
1 parent a053bed commit 3f0cbb7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
34 changes: 24 additions & 10 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def mclmc_find_L_and_step_size(
trust_in_estimate=1.5,
num_effective_samples=150,
diagonal_preconditioning=True,
params=None,
):
"""
Finds the optimal value of the parameters for the MCLMC algorithm.
Expand Down Expand Up @@ -79,6 +80,8 @@ def mclmc_find_L_and_step_size(
The number of effective samples for the MCMC algorithm.
diagonal_preconditioning
Whether to do diagonal preconditioning (i.e. a mass matrix)
params
Initial params to start tuning from (optional)
Returns
-------
Expand All @@ -105,10 +108,19 @@ def mclmc_find_L_and_step_size(
)
"""
dim = pytree_size(state.position)
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,))
)
if params is None:
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,))
)

part1_key, part2_key = jax.random.split(rng_key, 2)
total_num_tuning_integrator_steps = 0

num_steps1, num_steps2 = round(num_steps * frac_tune1), round(
num_steps * frac_tune2
)
num_steps2 += diagonal_preconditioning * (num_steps2 // 3)
num_steps3 = round(num_steps * frac_tune3)

state, params = make_L_step_size_adaptation(
kernel=mclmc_kernel,
Expand All @@ -120,13 +132,15 @@ def mclmc_find_L_and_step_size(
num_effective_samples=num_effective_samples,
diagonal_preconditioning=diagonal_preconditioning,
)(state, params, num_steps, part1_key)
total_num_tuning_integrator_steps += num_steps1 + num_steps2

if frac_tune3 != 0:
if num_steps3 >= 2: # at least 2 samples for ESS estimation
state, params = make_adaptation_L(
mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key)
total_num_tuning_integrator_steps += num_steps3

return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3)
return state, params, total_num_tuning_integrator_steps


def make_L_step_size_adaptation(
Expand Down Expand Up @@ -225,10 +239,10 @@ def step(iteration_state, weight_and_key):
)[0]

def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = (
int(num_steps * frac_tune1) + 1,
int(num_steps * frac_tune2) + 1,
num_steps1, num_steps2 = round(num_steps * frac_tune1), round(
num_steps * frac_tune2
)

L_step_size_adaptation_keys = jax.random.split(
rng_key, num_steps1 + num_steps2 + 1
)
Expand Down Expand Up @@ -259,7 +273,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L = jnp.sqrt(dim)

# readjust the stepsize
steps = num_steps2 // 3 # we do some small number of steps
steps = round(num_steps2 / 3) # we do some small number of steps
keys = jax.random.split(final_key, steps)
state, params, _, (_, average) = run_steps(
xs=(jnp.ones(steps), keys), state=state, params=params
Expand All @@ -274,7 +288,7 @@ def make_adaptation_L(kernel, frac, Lfactor):
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

def adaptation_L(state, params, num_steps, key):
num_steps_3 = int(num_steps * frac)
num_steps_3 = round(num_steps * frac)
adaptation_L_keys = jax.random.split(key, num_steps_3)

def step(state, key):
Expand Down
3 changes: 3 additions & 0 deletions blackjax/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def effective_sample_size(
sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis
num_chains = input_shape[chain_axis]
num_samples = input_shape[sample_axis]
assert (
num_samples > 1
), f"The input array must have at least 2 samples, got only {num_samples}."

mean_across_chain = input_array.mean(axis=sample_axis, keepdims=True)
# Compute autocovariance estimates for every lag for the input array using FFT.
Expand Down

0 comments on commit 3f0cbb7

Please sign in to comment.