Skip to content

Commit

Permalink
lint the code
Browse files Browse the repository at this point in the history
  • Loading branch information
master committed Jul 30, 2023
1 parent a96283a commit cf2c710
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class OptaxHMCState(NamedTuple):
"""Optax state for the HMC integrator."""

count: Array
rng_key: PRNGKeyArray
momentum: PyTree
Expand All @@ -38,6 +39,7 @@ def hmc_integrator(
step_schedule: StepSchedule
A function that takes training step as input and returns the step size.
"""

def init_fn(params):
return OptaxHMCState(
count=jnp.zeros([], jnp.int32),
Expand Down Expand Up @@ -82,7 +84,7 @@ def mh_correction():
momentum, _ = jax.flatten_util.ravel_pytree(momentum)
kinetic = 0.5 * jnp.dot(momentum, momentum)
hamiltonian = kinetic + state.log_prob
accept_prob = jnp.minimum(1., jnp.exp(hamiltonian - state.hamiltonian))
accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))

def _accept():
empty_updates = jax.tree_util.tree_map(jnp.zeros_like, params)
Expand Down

0 comments on commit cf2c710

Please sign in to comment.