From 3f0cbb7956765f1c25e0785beb1b6ad8a7092038 Mon Sep 17 00:00:00 2001
From: Hugo Simon-Onfroy <85559558+hsimonfroy@users.noreply.github.com>
Date: Wed, 19 Feb 2025 00:14:35 +0100
Subject: [PATCH] MCLMC adaptation total num steps and initial guess (#778)

* total_num_tuning_integrator_steps

* Initial params for MCLMC adaptation
---
 blackjax/adaptation/mclmc_adaptation.py | 34 +++++++++++++++++--------
 blackjax/diagnostics.py                 |  3 +++
 2 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index 60fd46359..fa644898a 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -51,6 +51,7 @@ def mclmc_find_L_and_step_size(
     trust_in_estimate=1.5,
     num_effective_samples=150,
     diagonal_preconditioning=True,
+    params=None,
 ):
     """
     Finds the optimal value of the parameters for the MCLMC algorithm.
@@ -79,6 +80,8 @@
         The number of effective samples for the MCMC algorithm.
     diagonal_preconditioning
         Whether to do diagonal preconditioning (i.e. a mass matrix)
+    params
+        Initial params to start tuning from (optional)
 
     Returns
     -------
@@ -105,10 +108,19 @@
     )
     """
     dim = pytree_size(state.position)
-    params = MCLMCAdaptationState(
-        jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,))
-    )
+    if params is None:
+        params = MCLMCAdaptationState(
+            jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,))
+        )
+
     part1_key, part2_key = jax.random.split(rng_key, 2)
+    total_num_tuning_integrator_steps = 0
+
+    num_steps1, num_steps2 = round(num_steps * frac_tune1), round(
+        num_steps * frac_tune2
+    )
+    num_steps2 += diagonal_preconditioning * (num_steps2 // 3)
+    num_steps3 = round(num_steps * frac_tune3)
 
     state, params = make_L_step_size_adaptation(
         kernel=mclmc_kernel,
@@ -120,13 +132,15 @@
         num_effective_samples=num_effective_samples,
         diagonal_preconditioning=diagonal_preconditioning,
     )(state, params, num_steps, part1_key)
+    total_num_tuning_integrator_steps += num_steps1 + num_steps2
 
-    if frac_tune3 != 0:
+    if num_steps3 >= 2:  # at least 2 samples for ESS estimation
         state, params = make_adaptation_L(
             mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
         )(state, params, num_steps, part2_key)
+        total_num_tuning_integrator_steps += num_steps3
 
-    return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3)
+    return state, params, total_num_tuning_integrator_steps
 
 
 def make_L_step_size_adaptation(
@@ -225,10 +239,10 @@ def step(iteration_state, weight_and_key):
     )[0]
 
     def L_step_size_adaptation(state, params, num_steps, rng_key):
-        num_steps1, num_steps2 = (
-            int(num_steps * frac_tune1) + 1,
-            int(num_steps * frac_tune2) + 1,
+        num_steps1, num_steps2 = round(num_steps * frac_tune1), round(
+            num_steps * frac_tune2
         )
+
         L_step_size_adaptation_keys = jax.random.split(
             rng_key, num_steps1 + num_steps2 + 1
         )
@@ -259,7 +273,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
             L = jnp.sqrt(dim)
 
             # readjust the stepsize
-            steps = num_steps2 // 3  # we do some small number of steps
+            steps = round(num_steps2 / 3)  # we do some small number of steps
             keys = jax.random.split(final_key, steps)
             state, params, _, (_, average) = run_steps(
                 xs=(jnp.ones(steps), keys), state=state, params=params
@@ -274,7 +288,7 @@ def make_adaptation_L(kernel, frac, Lfactor):
     """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""
 
     def adaptation_L(state, params, num_steps, key):
-        num_steps_3 = int(num_steps * frac)
+        num_steps_3 = round(num_steps * frac)
         adaptation_L_keys = jax.random.split(key, num_steps_3)
 
         def step(state, key):
diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py
index 93480302e..257ce759c 100644
--- a/blackjax/diagnostics.py
+++ b/blackjax/diagnostics.py
@@ -115,6 +115,9 @@ def effective_sample_size(
     sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis
     num_chains = input_shape[chain_axis]
     num_samples = input_shape[sample_axis]
+    assert (
+        num_samples > 1
+    ), f"The input array must have at least 2 samples, got only {num_samples}."
 
     mean_across_chain = input_array.mean(axis=sample_axis, keepdims=True)
     # Compute autocovariance estimates for every lag for the input array using FFT.
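
A minimal usage sketch of the new tuning interface follows. The toy model, the
`blackjax.mcmc.mclmc.init` / `build_kernel` wiring, and all numeric values are
assumptions taken from the documented MCLMC workflow, not part of this patch;
the patch itself only adds the `params=` argument and replaces the approximate
step count `num_steps * (frac_tune1 + frac_tune2 + frac_tune3)` with the exact
`total_num_tuning_integrator_steps` as the third return value.

    # Illustrative warm-start of the MCLMC tuner; the model and kernel wiring
    # are assumptions, only `params=` and the third return value are new here.
    import jax
    import jax.numpy as jnp

    import blackjax
    from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState

    dim = 10
    logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))  # toy Gaussian target
    init_key, tune_key = jax.random.split(jax.random.PRNGKey(0))

    initial_state = blackjax.mcmc.mclmc.init(
        position=jnp.ones((dim,)), logdensity_fn=logdensity_fn, rng_key=init_key
    )

    # The tuner expects a kernel factory parameterized by the inverse mass
    # matrix, cf. `mclmc_kernel(params.inverse_mass_matrix)` in the diff above.
    kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel(
        logdensity_fn=logdensity_fn,
        integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
        inverse_mass_matrix=inverse_mass_matrix,
    )

    # New in this patch: start tuning from a user-supplied guess instead of
    # the default sqrt(dim)-based initialization.
    initial_guess = MCLMCAdaptationState(
        L=5.0, step_size=0.5, inverse_mass_matrix=jnp.ones((dim,))
    )

    state, params, total_num_tuning_integrator_steps = blackjax.mclmc_find_L_and_step_size(
        mclmc_kernel=kernel,
        num_steps=10_000,
        state=initial_state,
        rng_key=tune_key,
        params=initial_guess,
    )

    # New in this patch: exact tuning cost, useful when benchmarking.
    print(total_num_tuning_integrator_steps)

The exact count differs from the old closed-form estimate whenever diagonal
preconditioning adds its extra `num_steps2 // 3` refinement steps, or when the
third stage is skipped because `num_steps3 < 2`.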
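
The `diagnostics.py` hunk pairs with the `num_steps3 >= 2` guard above:
`effective_sample_size` needs at least two samples along `sample_axis`, and now
fails fast with a clear message instead of producing a meaningless estimate.
A quick illustrative check (array shape and values are assumptions):

    # Exercising the new guard in blackjax/diagnostics.py.
    import jax
    import jax.numpy as jnp

    from blackjax.diagnostics import effective_sample_size

    # Default convention: (chain_axis, sample_axis) = (0, 1).
    samples = jax.random.normal(jax.random.PRNGKey(0), (4, 1000))
    print(effective_sample_size(samples))  # fine: 1000 samples per chain

    try:
        effective_sample_size(samples[:, :1])  # a single sample per chain
    except AssertionError as err:
        print(err)  # "The input array must have at least 2 samples, got only 1."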