diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py
index e53ac8c4..d90f807e 100644
--- a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py
@@ -24,6 +24,11 @@
     Array,
     Batch,
 )
+from fortuna.utils.freeze import (
+    has_multiple_opt_state,
+    get_trainable_opt_state,
+    update_trainable_opt_state,
+)
 
 
 class HMCTrainer(MAPTrainer):
@@ -46,9 +51,12 @@ def training_step(
             unravel=unravel,
             **kwargs,
         )
-        state = state.replace(
-            opt_state=state.opt_state._replace(log_prob=aux["loss"]),
-        )
+        if has_multiple_opt_state(state):
+            opt_state = get_trainable_opt_state(state)._replace(log_prob=aux["loss"])
+            state = update_trainable_opt_state(state, opt_state)
+        else:
+            opt_state = state.opt_state._replace(log_prob=aux["loss"])
+            state = state.replace(opt_state=opt_state)
         return state, aux
 
     def __str__(self):
diff --git a/fortuna/utils/freeze.py b/fortuna/utils/freeze.py
index a5c3c8f8..31501a87 100644
--- a/fortuna/utils/freeze.py
+++ b/fortuna/utils/freeze.py
@@ -18,7 +18,9 @@
 from optax import (
     multi_transform,
     set_to_zero,
+    MultiTransformState,
 )
+from optax._src.wrappers import MaskedState
 
 from fortuna.typing import (
     AnyKey,
@@ -27,6 +29,8 @@
     Params,
 )
 
+from fortuna.prob_model.posterior.state import PosteriorState
+
 
 def all_values_in_labels(values: Iterable, labels: Any) -> None:
     """
@@ -81,6 +85,65 @@ def freeze_optimizer(
     return multi_transform(partition_optimizers, partition_params)
 
 
+def has_multiple_opt_state(state: PosteriorState):
+    """
+    Check if a given posterior state contains multiple optimizer states.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+
+    Returns
+    -------
+    bool
+    """
+    return isinstance(state.opt_state, MultiTransformState)
+
+
+def get_trainable_opt_state(state: PosteriorState):
+    """
+    Get the trainable optimizer state.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+
+    Returns
+    -------
+    opt_state: Any
+        An instance of the trainable optimizer state.
+    """
+    return state.opt_state.inner_states["trainable"].inner_state
+
+
+def update_trainable_opt_state(state: PosteriorState, opt_state: Any):
+    """
+    Update the trainable optimizer state.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+    opt_state: Any
+        An instance of the trainable optimizer state.
+
+    Returns
+    -------
+    PosteriorState
+        An updated posterior state.
+    """
+    trainable_state = MaskedState(inner_state=opt_state)
+    opt_state = MultiTransformState(
+        inner_states={
+            k: (trainable_state if k == "trainable" else v)
+            for k, v in state.opt_state.inner_states.items()
+        }
+    )
+    return state.replace(opt_state=opt_state)
+
+
 def get_trainable_paths(
     params: Params,
     freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]],
diff --git a/tests/fortuna/prob_model/test_train.py b/tests/fortuna/prob_model/test_train.py
index a02ffb55..8da92759 100755
--- a/tests/fortuna/prob_model/test_train.py
+++ b/tests/fortuna/prob_model/test_train.py
@@ -386,7 +386,7 @@ def dryrun_task(task, method):
     )
     state = (
         prob_model.posterior.state.get()
-        if method not in ["deep_ensemble", "sghmc", "cyclical_sgld"]
+        if method not in ["deep_ensemble", "sghmc", "cyclical_sgld", "hmc"]
         else prob_model.posterior.state.get(-1)
     )
     model_editor_params = state.params["model_editor"]["params"].unfreeze()
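Note on the state layout the new helpers assume: `optax.multi_transform` keeps one optimizer state per partition label inside a `MultiTransformState`, wrapping each of them in a `MaskedState`. The sketch below illustrates this; the toy parameters, optimizers, and the `"backbone"`/`"head"` names are hypothetical, while the `"trainable"`/`"frozen"` labels mirror the ones assigned by `freeze_optimizer`.

```python
# Minimal sketch of the optax state layout assumed by the new helpers.
# Toy params and optimizers; only the "trainable"/"frozen" labels match Fortuna's.
import jax.numpy as jnp
import optax

params = {"backbone": jnp.zeros(3), "head": jnp.zeros(2)}
labels = {"backbone": "frozen", "head": "trainable"}

tx = optax.multi_transform(
    {"trainable": optax.sgd(1e-2), "frozen": optax.set_to_zero()},
    labels,
)
opt_state = tx.init(params)

# multi_transform keeps one MaskedState per label:
assert isinstance(opt_state, optax.MultiTransformState)
inner = opt_state.inner_states["trainable"].inner_state
# `inner` is what get_trainable_opt_state returns; update_trainable_opt_state
# re-wraps a modified inner state in a MaskedState and rebuilds the dict.
```

This is also why `update_trainable_opt_state` reconstructs the whole `MultiTransformState` instead of calling `_replace` on `state.opt_state` directly: the `log_prob` field lives on the inner HMC state, nested two wrappers deep.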