Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Jul 2, 2023
1 parent 902a2a1 commit 3c83290
Show file tree
Hide file tree
Showing 29 changed files with 619 additions and 347 deletions.
10 changes: 8 additions & 2 deletions benchmarks/transformers/prob_model_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,17 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:

model_editor = None
if args.enable_probit_model_editor:
probit_freeze_fun = lambda p, v: True if "classifier" in p else False if args.probit_last_layer_only else None
probit_freeze_fun = (
lambda p, v: True
if "classifier" in p
else False
if args.probit_last_layer_only
else None
)
model_editor = ProbitModelEditor(
freeze_fun=probit_freeze_fun,
init_log_var=args.probit_init_log_var,
stop_gradient=args.probit_stop_gradient
stop_gradient=args.probit_stop_gradient,
)

### TRAINING
Expand Down
4 changes: 1 addition & 3 deletions fortuna/calib_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ def load_state(self, checkpoint_dir: Path) -> None:
)
self.predictive.state = CalibStateRepository(checkpoint_dir=checkpoint_dir)

def save_state(
self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1
) -> None:
def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None:
return self.predictive.state.put(
self.predictive.state.get(),
checkpoint_dir=checkpoint_dir,
Expand Down
2 changes: 1 addition & 1 deletion fortuna/calib_model/calib_model_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from optax._src.base import PyTree

from fortuna.calib_model.state import CalibState
from fortuna.training.trainer import TrainerABC
from fortuna.training.mixins.jitted import JittedMixin
from fortuna.training.mixins.multi_device import MultiDeviceMixin
from fortuna.training.trainer import TrainerABC
from fortuna.typing import (
Array,
Batch,
Expand Down
1 change: 1 addition & 0 deletions fortuna/likelihood/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Tuple,
Union,
)

from flax.core import FrozenDict
from jax import (
jit,
Expand Down
Loading

0 comments on commit 3c83290

Please sign in to comment.