Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SMC: Joint tuning and pretuning #776

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions blackjax/smc/inner_kernel_tuning.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Callable, Dict, NamedTuple, Tuple

import jax

from blackjax.base import SamplingAlgorithm
from blackjax.smc.base import SMCInfo, SMCState
from blackjax.types import ArrayTree, PRNGKey
Expand Down Expand Up @@ -28,8 +30,11 @@ def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]],
mcmc_parameter_update_fn: Callable[
[PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree]
],
num_mcmc_steps: int = 10,
smc_returns_state_with_parameter_override=False,
**extra_parameters,
) -> Callable:
"""In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner
Expand All @@ -40,7 +45,8 @@ def build_kernel(
----------
smc_algorithm
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
a sampling algorithm that returns an SMCState and SMCInfo pair).
a sampling algorithm that returns an SMCState and SMCInfo pair). It is also possible for this
to return an StateWithParameterOverride, in such case smc_returns_state_with_parameter_override needs to be True
logprior_fn
A function that computes the log density of the prior distribution
loglikelihood_fn
Expand All @@ -54,7 +60,30 @@ def build_kernel(
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration.
extra_parameters:
parameters to be used for the creation of the smc_algorithm.
smc_returns_state_with_parameter_override:
a boolean indicating that the underlying smc_algorithm returns a smc_returns_state_with_parameter_override.
this is used in order to compose different adaptation mechanisms, such as pretuning with tuning.
"""
if smc_returns_state_with_parameter_override:

def extract_state_for_delegate(state):
return state

def compose_new_state(new_state, new_parameter_override):
composed_parameter_override = (
new_state.parameter_override | new_parameter_override
)
return StateWithParameterOverride(
new_state.sampler_state, composed_parameter_override
)

else:

def extract_state_for_delegate(state):
return state.sampler_state

def compose_new_state(new_state, new_parameter_override):
return StateWithParameterOverride(new_state, new_parameter_override)

def kernel(
rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters
Expand All @@ -69,9 +98,14 @@ def kernel(
num_mcmc_steps=num_mcmc_steps,
**extra_parameters,
).step
new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters)
new_parameter_override = mcmc_parameter_update_fn(new_state, info)
return StateWithParameterOverride(new_state, new_parameter_override), info
parameter_update_key, step_key = jax.random.split(rng_key, 2)
new_state, info = step_fn(
step_key, extract_state_for_delegate(state), **extra_step_parameters
)
new_parameter_override = mcmc_parameter_update_fn(
parameter_update_key, new_state, info
)
return compose_new_state(new_state, new_parameter_override), info

return kernel

Expand All @@ -83,9 +117,12 @@ def as_top_level_api(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]],
mcmc_parameter_update_fn: Callable[
[PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree]
],
initial_parameter_value,
num_mcmc_steps: int = 10,
smc_returns_state_with_parameter_override=False,
**extra_parameters,
) -> SamplingAlgorithm:
"""In the context of an SMC sampler (whose step_fn returning state
Expand Down Expand Up @@ -130,6 +167,7 @@ def as_top_level_api(
resampling_fn,
mcmc_parameter_update_fn,
num_mcmc_steps,
smc_returns_state_with_parameter_override,
**extra_parameters,
)

Expand Down
12 changes: 9 additions & 3 deletions blackjax/smc/pretuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,21 @@ def update_parameter_distribution(
)


def default_measure_factory(state):
inverse_mass_matrix = state.parameter_override["inverse_mass_matrix"]
if not (len(inverse_mass_matrix.shape) == 3 and inverse_mass_matrix.shape[0] == 1):
raise ValueError("ESJD only works if chains share the inverse_mass_matrix.")

return esjd(inverse_mass_matrix[0])


def build_pretune(
mcmc_init_fn: Callable,
mcmc_step_fn: Callable,
alpha: float,
sigma_parameters: ArrayLikeTree,
n_particles: int,
performance_of_chain_measure_factory: Callable = lambda state: esjd(
state.parameter_override["inverse_mass_matrix"]
),
performance_of_chain_measure_factory: Callable = default_measure_factory,
natural_parameters: Optional[List[str]] = None,
positive_parameters: Optional[List[str]] = None,
):
Expand Down
12 changes: 5 additions & 7 deletions blackjax/smc/tuning/from_particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"particles_means",
"particles_stds",
"particles_covariance_matrix",
"mass_matrix_from_particles",
"inverse_mass_matrix_from_particles",
]


Expand All @@ -28,18 +28,16 @@ def particles_covariance_matrix(particles):
return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False)


def mass_matrix_from_particles(particles) -> Array:
def inverse_mass_matrix_from_particles(particles) -> Array:
"""
Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf
Computing a mass matrix to be used in HMC from particles.
Given the particles covariance matrix, set all non-diagonal elements as zero,
take the inverse, and keep the diagonal.
Computing an inverse mass matrix to be used in HMC from particles.

Returns
-------
A mass Matrix
An inverse mass matrix
"""
return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0))
return jnp.diag(jnp.var(particles_as_rows(particles), axis=0))


def particles_as_rows(particles):
Expand Down
Loading