Fix optimizer freeze support in HMC
master committed Jul 30, 2023
1 parent c2643a5 commit 04d7a0a
Showing 3 changed files with 75 additions and 4 deletions.
14 changes: 11 additions & 3 deletions fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py
@@ -24,6 +24,11 @@
     Array,
     Batch,
 )
+from fortuna.utils.freeze import (
+    has_multiple_opt_state,
+    get_trainable_opt_state,
+    update_trainable_opt_state,
+)
 
 
 class HMCTrainer(MAPTrainer):
@@ -46,9 +51,12 @@ def training_step(
             unravel=unravel,
             **kwargs,
         )
-        state = state.replace(
-            opt_state=state.opt_state._replace(log_prob=aux["loss"]),
-        )
+        if has_multiple_opt_state(state):
+            opt_state = get_trainable_opt_state(state)._replace(log_prob=aux["loss"])
+            state = update_trainable_opt_state(state, opt_state)
+        else:
+            opt_state = state.opt_state._replace(log_prob=aux["loss"])
+            state = state.replace(opt_state=opt_state)
         return state, aux
 
     def __str__(self):
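Why the branch is needed: when parameters are frozen via `freeze_optimizer`, the optimizer is built with `optax.multi_transform`, so `state.opt_state` is a `MultiTransformState` wrapping one inner state per label group rather than the HMC integrator state itself, and calling `._replace(log_prob=...)` directly on it fails. A minimal sketch of the two shapes of `opt_state` (the `ToyHMCState` namedtuple is hypothetical, standing in for the real integrator state; the optax calls are standard):

```python
# Sketch only: ToyHMCState is a hypothetical stand-in for the HMC
# integrator state, which carries a log_prob field.
from typing import NamedTuple

import jax.numpy as jnp
import optax


class ToyHMCState(NamedTuple):
    log_prob: jnp.ndarray


# Unfrozen case: opt_state is the integrator state itself, so the old
# direct update works.
opt_state = ToyHMCState(log_prob=jnp.array(0.0))
opt_state = opt_state._replace(log_prob=jnp.array(-1.23))

# Frozen case: multi_transform nests each group's state inside
# MultiTransformState.inner_states[label] as a MaskedState wrapper,
# so the integrator state must be unpacked before calling _replace.
params = {"w": jnp.zeros(3), "b": jnp.zeros(1)}
tx = optax.multi_transform(
    {"trainable": optax.sgd(1e-3), "frozen": optax.set_to_zero()},
    {"w": "trainable", "b": "frozen"},
)
nested = tx.init(params)
print(type(nested).__name__)  # MultiTransformState
inner = nested.inner_states["trainable"].inner_state  # what the helpers reach
```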
63 changes: 63 additions & 0 deletions fortuna/utils/freeze.py
@@ -16,7 +16,9 @@
 from optax import (
     multi_transform,
     set_to_zero,
+    MultiTransformState,
 )
+from optax._src.wrappers import MaskedState
 
 from fortuna.typing import (
     AnyKey,
@@ -25,6 +27,8 @@
     Params,
 )
 
+from fortuna.prob_model.posterior.state import PosteriorState
+
 
 def all_values_in_labels(values: Iterable, labels: Any) -> None:
     """
@@ -79,6 +83,65 @@ def freeze_optimizer(
     return multi_transform(partition_optimizers, partition_params)
 
 
+def has_multiple_opt_state(state: PosteriorState):
+    """
+    Check if a given posterior state contains multiple optimizer states.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+
+    Returns
+    -------
+    bool
+    """
+    return isinstance(state.opt_state, MultiTransformState)
+
+
+def get_trainable_opt_state(state: PosteriorState):
+    """
+    Get a trainable optimizer state.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+
+    Returns
+    -------
+    opt_state: Any
+        An instance of trainable optimizer state.
+    """
+    return state.opt_state.inner_states["trainable"].inner_state
+
+
+def update_trainable_opt_state(state: PosteriorState, opt_state: Any):
+    """
+    Update a trainable optimizer state.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+    opt_state: Any
+        An instance of trainable optimizer state.
+
+    Returns
+    -------
+    PosteriorState
+        An updated posterior state.
+    """
+    trainable_state = MaskedState(inner_state=opt_state)
+    opt_state = MultiTransformState(
+        inner_states={
+            k: (trainable_state if k == "trainable" else v)
+            for k, v in state.opt_state.inner_states.items()
+        }
+    )
+    return state.replace(opt_state=opt_state)
+
+
 def get_trainable_paths(
     params: Params,
     freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]],
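A hedged usage sketch of the three helpers above (the `DummyState` dataclass is hypothetical and stands in only for the `opt_state`/`replace` interface of `PosteriorState`; the optax setup mirrors what `freeze_optimizer` builds):

```python
# Illustrative only: DummyState is a hypothetical stand-in for
# PosteriorState; the helpers only rely on its .opt_state attribute
# and .replace() method.
from dataclasses import dataclass, replace as dc_replace
from typing import Any

import jax.numpy as jnp
import optax

from fortuna.utils.freeze import (
    has_multiple_opt_state,
    get_trainable_opt_state,
    update_trainable_opt_state,
)


@dataclass(frozen=True)
class DummyState:
    opt_state: Any

    def replace(self, **kwargs):
        return dc_replace(self, **kwargs)


params = {"backbone": jnp.zeros(4), "head": jnp.zeros(2)}
tx = optax.multi_transform(
    {"trainable": optax.sgd(1e-3), "frozen": optax.set_to_zero()},
    {"backbone": "frozen", "head": "trainable"},
)
state = DummyState(opt_state=tx.init(params))

assert has_multiple_opt_state(state)      # opt_state is a MultiTransformState
inner = get_trainable_opt_state(state)    # unwraps inner_states["trainable"]
state = update_trainable_opt_state(state, inner)  # re-wraps and swaps it back
```

The "frozen" group's state is left untouched by `update_trainable_opt_state`, which is what makes it safe for the HMC trainer to mutate only the integrator's `log_prob` field.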
2 changes: 1 addition & 1 deletion tests/fortuna/prob_model/test_train.py
@@ -386,7 +386,7 @@ def dryrun_task(task, method):
     )
     state = (
         prob_model.posterior.state.get()
-        if method not in ["deep_ensemble", "sghmc", "cyclical_sgld"]
+        if method not in ["deep_ensemble", "sghmc", "cyclical_sgld", "hmc"]
         else prob_model.posterior.state.get(-1)
     )
     model_editor_params = state.params["model_editor"]["params"].unfreeze()
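A hedged reading of this test change: like the other sample-based methods in the list, HMC appears to store one posterior state per retained sample, so the test fetches the most recent one with `get(-1)`. A minimal sketch of the selection pattern, using only what the diff itself shows:

```python
# Sketch of the selection logic above, assuming sample-based posteriors
# expose indexed access to their stored states (as the diff suggests).
SAMPLE_BASED = ["deep_ensemble", "sghmc", "cyclical_sgld", "hmc"]


def latest_state(posterior, method):
    # Single-state posteriors: get() returns the one stored state.
    # Sample-based posteriors: get(-1) returns the last sample's state.
    if method in SAMPLE_BASED:
        return posterior.state.get(-1)
    return posterior.state.get()
```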
