diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py index 8755d5c7..f4c09a0f 100644 --- a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py +++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py @@ -14,6 +14,7 @@ class OptaxHMCState(NamedTuple): """Optax state for the HMC integrator.""" + count: Array rng_key: PRNGKeyArray momentum: PyTree @@ -38,6 +39,7 @@ def hmc_integrator( step_schedule: StepSchedule A function that takes training step as input and returns the step size. """ + def init_fn(params): return OptaxHMCState( count=jnp.zeros([], jnp.int32), @@ -82,7 +84,7 @@ def mh_correction(): momentum, _ = jax.flatten_util.ravel_pytree(momentum) kinetic = 0.5 * jnp.dot(momentum, momentum) hamiltonian = kinetic + state.log_prob - accept_prob = jnp.minimum(1., jnp.exp(hamiltonian - state.hamiltonian)) + accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian)) def _accept(): empty_updates = jax.tree_util.tree_map(jnp.zeros_like, params)