FunMC: Extract the resampling logic into its own function, remove old `systematic_resample`.

PiperOrigin-RevId: 721110369
SiegeLordEx authored and tensorflower-gardener committed Jan 29, 2025
1 parent f3147cd commit 21e1c4c
Showing 5 changed files with 119 additions and 156 deletions.
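As a quick orientation before the per-file diffs: the new `resample` returns the same `((state, log_weights), ancestor_idx)` structure as the removed `systematic_resample`, so for callers the migration is essentially a rename. A minimal sketch of a call site before and after (placeholder variable names; the `using_jax` import path and the top-level `fun_mc.resample` re-export are assumptions, consistent with the notebook hunk below):

import jax
import jax.numpy as jnp
from fun_mc import using_jax as fun_mc  # assumed import path

particles = jnp.arange(4)
log_weights = jnp.zeros(4)
seed = jax.random.key(0)

# Before this commit:
#   (new_particles, new_log_weights), ancestor_idx = fun_mc.systematic_resample(
#       particles, log_weights, seed)
# After this commit (systematic resampling remains the default scheme):
(new_particles, new_log_weights), ancestor_idx = fun_mc.resample(
    particles, log_weights, seed)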
@@ -3377,7 +3377,7 @@
" if (self.auto_resample.value or self.resample) or not jnp.all(\n",
" jnp.isfinite(extra.target_log_prob)\n",
" ):\n",
" (_, _), ancestor_idx = fun_mc.systematic_resample(\n",
" (_, _), ancestor_idx = fun_mc.resample(\n",
" (),\n",
" resample_strength * extra.target_log_prob,\n",
" jax.random.key(self.step),\n",
55 changes: 0 additions & 55 deletions spinoffs/fun_mc/fun_mc/fun_mc_lib.py
@@ -128,7 +128,6 @@
'SimpleDualAveragesState',
'splitting_integrator_step',
'State',
'systematic_resample',
'trace',
'transform_log_prob_fn',
'TransitionOperator',
@@ -3466,60 +3465,6 @@ def clip_part(v):
)


@util.named_call
def systematic_resample(
particles: State,
log_weights: FloatArray,
seed: Any,
do_resample: Optional[BooleanArray] = None,
) -> tuple[tuple[State, FloatArray], IntArray]:
"""Systematically resamples particles in proportion to their weights.
This uses the algorithm from [1].
Args:
particles: The particles.
log_weights: Un-normalized weights.
seed: PRNG seed.
do_resample: Whether to perform the resample. If None, resampling is
performed unconditionally.
Returns:
particles_and_weights: tuple of resampled particles and weights.
ancestor_idx: Indices from which the returned particles were sampled.
#### References
[1] Maskell, S., Alun-Jones, B., & Macleod, M. (2006). A Single Instruction
Multiple Data Particle Filter. 2006 IEEE Nonlinear Statistical Signal
Processing Workshop. https://doi.org/10.1109/NSSPW.2006.4378818
"""
log_weights = jnp.asarray(log_weights)
log_weights = jnp.where(
jnp.isnan(log_weights),
jnp.array(-float('inf'), log_weights.dtype),
log_weights,
)
probs = jax.nn.softmax(log_weights)
num_particles = probs.shape[0]

shift = util.random_uniform([], log_weights.dtype, seed)
pie = jnp.cumsum(probs) * num_particles + shift
repeats = jnp.array(util.diff(jnp.floor(pie), prepend=0), jnp.int32)
parent_idxs = util.repeat(
jnp.arange(num_particles), repeats, total_repeat_length=num_particles
)
if do_resample is not None:
parent_idxs = jnp.where(do_resample, parent_idxs, jnp.arange(num_particles))
new_particles = util.map_tree(lambda x: x[parent_idxs], particles)
new_log_weights = jnp.full(
log_weights.shape, tfp.math.reduce_logmeanexp(log_weights)
)
if do_resample is not None:
new_log_weights = jnp.where(do_resample, new_log_weights, log_weights)
return (new_particles, new_log_weights), parent_idxs


class GeometricAnnealingPathExtra(NamedTuple):
"""Extra outputs of `geometric_annealing_path`.
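The index computation in the removed `systematic_resample` above (and in the `systematic_resampling` helper that remains in `fun_mc.smc`, which adds a random permutation on top) is compact enough to trace by hand. A standalone, illustrative NumPy sketch with the shift fixed at 0.5 instead of drawn uniformly, using the weight vector from the deleted test in the next file:

import numpy as np

probs = np.array([0.0, 0.5, 0.2, 0.3, 0.0])  # normalized weights, num_particles = 5
num_particles = probs.shape[0]
shift = 0.5  # the library draws this from Uniform[0, 1)

pie = np.cumsum(probs) * num_particles + shift              # [0.5, 3.0, 4.0, 5.5, 5.5]
repeats = np.diff(np.floor(pie), prepend=0).astype(int)     # [0, 3, 1, 1, 0]
parent_idxs = np.repeat(np.arange(num_particles), repeats)  # [1, 1, 1, 2, 3]

# Particle i is copied roughly probs[i] * num_particles times, so the
# zero-weight particles (0 and 4) can never become ancestors.

The library version additionally maps NaN log weights to -inf and passes `total_repeat_length` to `repeat` so the output shape stays static under `jit`.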
56 changes: 0 additions & 56 deletions spinoffs/fun_mc/fun_mc/fun_mc_test.py
@@ -2039,62 +2039,6 @@ def eval_fn(x):
self.assertAllCloseNested(value, fn(x))
self.assertAllCloseNested(expected_grad, grad)

def testSystematicResample(self):
probs = self._constant([0.0, 0.5, 0.2, 0.3, 0.0])
log_weights = jnp.log(probs)
particles = jnp.arange(probs.shape[0])

@jax.jit
def body(seed):
seed, resample_seed = util.split_seed(seed, 2)
(new_particles, new_log_weights), _ = fun_mc.systematic_resample(
particles, log_weights, resample_seed
)
return seed, (new_particles, new_log_weights)

_, (new_particles, new_log_weights) = fun_mc.trace(
self._make_seed(_test_seed()), body, 1000, trace_mask=(True, False)
)

new_particles_probs = jnp.mean(
jnp.array(new_particles[..., jnp.newaxis] == particles, jnp.float32),
(0, 1),
)

self.assertAllClose(new_particles_probs, probs, atol=0.05)
self.assertEqual(new_particles_probs[0], 0.0)
self.assertEqual(new_particles_probs[-1], 0.0)
self.assertAllClose(
new_log_weights,
jnp.full(probs.shape, tfp.math.reduce_logmeanexp(log_weights)),
)

def testSystematicResampleAncestors(self):
log_weights = self._constant([-float('inf'), 0.0])
particles = jnp.arange(log_weights.shape[0])
seed = self._make_seed(_test_seed())

(new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
particles, log_weights, seed=seed
)
self.assertAllEqual(new_particles, jnp.ones_like(particles))
self.assertAllEqual(new_log_weights, jnp.log(self._constant([0.5, 0.5])))
self.assertAllEqual(ancestors, jnp.ones_like(particles))

(new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
particles, log_weights, do_resample=True, seed=seed
)
self.assertAllEqual(new_particles, jnp.ones_like(particles))
self.assertAllEqual(new_log_weights, jnp.log(self._constant([0.5, 0.5])))
self.assertAllEqual(ancestors, jnp.ones_like(particles))

(new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
particles, log_weights, do_resample=False, seed=seed
)
self.assertAllEqual(new_particles, particles)
self.assertAllEqual(new_log_weights, log_weights)
self.assertAllEqual(ancestors, particles)


@test_util.multi_backend_test(globals(), 'fun_mc_test')
class FunMCTest32(FunMCTest):
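The frequency check removed above still has a natural counterpart under the new API. A rough sketch of the same idea against `resample`, using a plain Python loop rather than `fun_mc.trace`; the `using_jax` import path and the top-level `fun_mc.resample` re-export are assumptions, not part of this commit:

import jax
import jax.numpy as jnp
from fun_mc import using_jax as fun_mc  # assumed import path

probs = jnp.array([0.0, 0.5, 0.2, 0.3, 0.0])
log_weights = jnp.log(probs)
particles = jnp.arange(probs.shape[0])

counts = jnp.zeros(probs.shape[0])
for i in range(1000):
  (new_particles, _), _ = fun_mc.resample(
      particles, log_weights, jax.random.key(i))
  counts = counts + jnp.bincount(new_particles, length=probs.shape[0])

freqs = counts / counts.sum()
# freqs should land within a few percent of probs, with entries 0 and 4 exactly 0,
# mirroring the assertions in the deleted testSystematicResample.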
135 changes: 91 additions & 44 deletions spinoffs/fun_mc/fun_mc/smc.py
@@ -48,6 +48,7 @@
'conditional_systematic_resampling',
'effective_sample_size_predicate',
'ParticleGatherFn',
'resample',
'ResamplingPredicate',
'SampleAncestorsFn',
'sequential_monte_carlo_init',
@@ -78,6 +79,8 @@ def systematic_resampling(
) -> Int[Array, 'num_particles']:
"""Generate parent indices via systematic resampling.
This uses the algorithm from [1].
Args:
log_weights: Unnormalized log-scale weights.
seed: PRNG seed.
@@ -87,13 +90,14 @@
Returns:
parent_idxs: parent indices such that the marginal probability that a
randomly chosen element will be `i` is equal to `softmax(log_weights)[i]`.
#### References
[1] Maskell, S., Alun-Jones, B., & Macleod, M. (2006). A Single Instruction
Multiple Data Particle Filter. 2006 IEEE Nonlinear Statistical Signal
Processing Workshop. https://doi.org/10.1109/NSSPW.2006.4378818
"""
shift_seed, permute_seed = util.split_seed(seed, 2)
log_weights = jnp.where(
jnp.isnan(log_weights),
jnp.array(-float('inf'), log_weights.dtype),
log_weights,
)
probs = jax.nn.softmax(log_weights)
# A common situation is all -inf log_weights, which creates a NaN vector.
probs = jnp.where(
@@ -146,11 +150,6 @@ def conditional_systematic_resampling(
https://www.jstor.org/stable/43590414
"""
mixture_seed, shift_seed, permute_seed = util.split_seed(seed, 3)
log_weights = jnp.where(
jnp.isnan(log_weights),
jnp.array(-float('inf'), log_weights.dtype),
log_weights,
)
probs = jax.nn.softmax(log_weights)
num_particles = log_weights.shape[0]

@@ -377,7 +376,7 @@ def __call__(


@types.runtime_typed
def _defalt_pytree_gather(
def _default_pytree_gather(
state: State,
indices: Int[Array, 'num_particles'],
) -> State:
@@ -395,6 +394,75 @@ def _defalt_pytree_gather(
return util.map_tree(lambda x: x[indices], state)


@types.runtime_typed
def resample(
state: State,
log_weights: Float[Array, 'num_particles'],
seed: Seed,
do_resample: BoolScalar = True,
sample_ancestors_fn: SampleAncestorsFn = systematic_resampling,
state_gather_fn: ParticleGatherFn[State] = _default_pytree_gather,
) -> tuple[
tuple[State, Float[Array, 'num_particles']], Int[Array, 'num_particles']
]:
"""Possibly resamples state according to the log_weights.
The state should represent the same number of particles as implied by the
length of `log_weights`. If resampling occurs, the new log weights are
log-mean-exp of the incoming log weights. Otherwise, they are unchanged. By
default, this function performs systematic resampling.
Args:
state: The particles.
log_weights: Un-normalized log weights. NaN log weights are treated as -inf.
seed: Random seed.
do_resample: Whether to resample.
sample_ancestors_fn: Ancestor index sampling function.
state_gather_fn: State gather function.
Returns:
state_and_log_weights: tuple of the resampled state and log weights.
ancestor_idx: Indices that indicate which elements of the original state the
returned state particles were sampled from.
"""

def do_resample_fn(
state,
log_weights,
seed,
):
log_weights = jnp.where(
jnp.isnan(log_weights),
jnp.array(-float('inf'), log_weights.dtype),
log_weights,
)
ancestor_idxs = sample_ancestors_fn(log_weights, seed)
new_state = state_gather_fn(state, ancestor_idxs)
num_particles = log_weights.shape[0]
new_log_weights = jnp.full(
(num_particles,), tfp.math.reduce_logmeanexp(log_weights)
)
return (new_state, new_log_weights), ancestor_idxs

def dont_resample_fn(
state,
log_weights,
seed,
):
del seed
num_particles = log_weights.shape[0]
return (state, log_weights), jnp.arange(num_particles)

return _smart_cond(
do_resample,
do_resample_fn,
dont_resample_fn,
state,
log_weights,
seed,
)


@types.runtime_typed
def sequential_monte_carlo_init(
state: State,
@@ -430,7 +498,7 @@ def sequential_monte_carlo_step(
seed: Seed,
resampling_pred: ResamplingPredicate = effective_sample_size_predicate,
sample_ancestors_fn: SampleAncestorsFn = systematic_resampling,
state_gather_fn: ParticleGatherFn[State] = _defalt_pytree_gather,
state_gather_fn: ParticleGatherFn[State] = _default_pytree_gather,
) -> tuple[
SequentialMonteCarloState[State], SequentialMonteCarloExtra[State, Extra]
]:
@@ -461,43 +529,21 @@ def sequential_monte_carlo_step(
"""
resample_seed, kernel_seed = util.split_seed(seed, 2)

def do_resample(
state,
log_weights,
seed,
):
ancestor_idxs = sample_ancestors_fn(log_weights, seed)
new_state = state_gather_fn(state, ancestor_idxs)
num_particles = log_weights.shape[0]
new_log_weights = jnp.full(
(num_particles,), tfp.math.reduce_logmeanexp(log_weights)
)
return (new_state, ancestor_idxs, new_log_weights)

def dont_resample(
state,
log_weights,
seed,
):
del seed
num_particles = log_weights.shape[0]
return state, jnp.arange(num_particles), log_weights

# NOTE: We don't explicitly disable resampling at the first step. However, if
# we initialize the log weights to zeros, either of
# 1. resampling according to the effective sample size criterion and
# 2. using systematic resampling effectively disables resampling at the first
# step.
# First-step resampling can always be forced via the `resampling_pred`.
should_resample = resampling_pred(smc_state)
state_after_resampling, ancestor_idxs, log_weights_after_resampling = (
_smart_cond(
should_resample,
do_resample,
dont_resample,
smc_state.state,
smc_state.log_weights,
resample_seed,
do_resample = resampling_pred(smc_state)
(state_after_resampling, log_weights_after_resampling), ancestor_idxs = (
resample(
state=smc_state.state,
log_weights=smc_state.log_weights,
do_resample=do_resample,
seed=resample_seed,
sample_ancestors_fn=sample_ancestors_fn,
state_gather_fn=state_gather_fn,
)
)
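The NOTE above leans on a property of zero-initialized log weights that is quick to verify: uniform weights give the largest possible effective sample size, so an ESS-based predicate has nothing to trigger on, and systematic resampling of equal weights copies each particle exactly once anyway. A small sketch of the arithmetic using the generic ESS estimate (not necessarily the exact expression inside `effective_sample_size_predicate`):

import jax
import jax.numpy as jnp

num_particles = 4
log_weights = jnp.zeros(num_particles)   # zero-initialized first-step weights
probs = jax.nn.softmax(log_weights)      # uniform: 1 / num_particles each
ess = 1.0 / jnp.sum(probs**2)            # == num_particles, the maximum possible value
# Any resampling threshold below num_particles therefore stays quiet at step 0.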

@@ -516,7 +562,7 @@ def dont_resample(
smc_extra = SequentialMonteCarloExtra(
incremental_log_weights=incremental_log_weights,
kernel_extra=kernel_extra,
resampled=should_resample,
resampled=do_resample,
ancestor_idxs=ancestor_idxs,
state_after_resampling=state_after_resampling,
log_weights_after_resampling=log_weights_after_resampling,
@@ -711,6 +757,7 @@ def inner_kernel(state, stage, tlp_fn, seed):
)


@types.runtime_typed
def _smart_cond(
pred: BoolScalar,
true_fn: Callable[..., T],
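Since `resample` now exposes `sample_ancestors_fn` and `state_gather_fn` as arguments, other resampling schemes can be plugged in without touching the SMC step. A hedged usage sketch swapping in multinomial ancestor sampling; the `using_jax` import path and the top-level `fun_mc.resample` re-export are assumptions (the tests in the next file call it as `smc.resample`):

import jax
import jax.numpy as jnp
from fun_mc import using_jax as fun_mc  # assumed import path

def multinomial_ancestors(log_weights, seed):
  # Draw each ancestor index independently with probability softmax(log_weights).
  return jax.random.categorical(seed, log_weights, shape=log_weights.shape)

state = jnp.arange(4)
log_weights = jnp.array([0.0, 0.0, 1.0, 1.0])
(new_state, new_log_weights), ancestor_idxs = fun_mc.resample(
    state=state,
    log_weights=log_weights,
    seed=jax.random.key(0),
    sample_ancestors_fn=multinomial_ancestors,
)
# new_log_weights are all logmeanexp(log_weights); passing do_resample=False
# instead returns the state, weights, and identity ancestor indices unchanged.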
27 changes: 27 additions & 0 deletions spinoffs/fun_mc/fun_mc/smc_test.py
@@ -272,6 +272,33 @@ def kernel(seed):
)
self.assertAllClose(rejection_freqs, conditional_freqs, atol=0.05)

def test_resample(self):
state = jnp.array([3, 2, 1, 0])
log_weights = jnp.array([-jnp.inf, float('NaN'), 1.0, 1.0], self._dtype)
seed = _test_seed()

(new_state, new_log_weights), ancestor_idxs = smc.resample(
state=state, log_weights=log_weights, seed=seed
)

self.assertAllTrue(new_state != 3)
self.assertAllTrue(new_state != 2)
self.assertAllTrue(~jnp.isnan(new_log_weights))
self.assertAllEqual(3 - new_state, ancestor_idxs)

def test_resample_but_dont(self):
state = jnp.array([3, 2, 1, 0])
log_weights = jnp.array([-jnp.inf, float('NaN'), 1.0, 1.0], self._dtype)
seed = _test_seed()

(new_state, new_log_weights), ancestor_idxs = smc.resample(
state=state, log_weights=log_weights, do_resample=False, seed=seed
)

self.assertAllEqual(new_state, state)
self.assertAllEqual(new_log_weights, log_weights)
self.assertAllEqual(ancestor_idxs, jnp.arange(state.shape[0]))

def test_smc_runs_and_shapes_correct(self):
num_particles = 3
num_timesteps = 20
