diff --git a/benchmarks/transformers/masked_language_modeling.py b/benchmarks/transformers/masked_language_modeling.py index 066d2303..539b0156 100644 --- a/benchmarks/transformers/masked_language_modeling.py +++ b/benchmarks/transformers/masked_language_modeling.py @@ -193,13 +193,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: try: logger.info(list(pathlib.Path(args.restore_checkpoint_dir).rglob("*"))) - restore_checkpoint_path = unpack_model_tar( + restore_checkpoint_dir = unpack_model_tar( list(pathlib.Path(args.restore_checkpoint_dir).rglob("*"))[0] ) - logger.info(list(pathlib.Path(restore_checkpoint_path).rglob("*"))) + logger.info(list(pathlib.Path(restore_checkpoint_dir).rglob("*"))) except: logger.info("No checkpoint to restore") - restore_checkpoint_path = None + restore_checkpoint_dir = None tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) @@ -303,11 +303,11 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: #### TRAIN! #### ##################################### def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray: - if preds.ndim > 2: + if preds.ndim > 1: raise ValueError( """`preds` must be a one-dimensional array of predicted classes.""" ) - if targets.ndim > 2: + if targets.ndim > 1: raise ValueError( """`targets` must be a one-dimensional array of target classes.""" ) @@ -341,7 +341,7 @@ def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray: save_checkpoint_dir=args.save_checkpoint_dir, save_every_n_steps=args.save_every_n_steps, keep_top_n_checkpoints=args.keep_top_n_checkpoints, - restore_checkpoint_path=restore_checkpoint_path, + restore_checkpoint_dir=restore_checkpoint_dir, ), ) if args.last_layer_only and ( @@ -357,7 +357,7 @@ def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray: and args.last_layer_only else None, ) - if restore_checkpoint_path is not None: + if restore_checkpoint_dir is not None: fit_config.optimizer = last_layer_optimizer train_kwargs = {"fit_config": fit_config} else: diff --git a/benchmarks/transformers/prob_model_text_classification.py b/benchmarks/transformers/prob_model_text_classification.py index 20e1a72b..59e747ff 100644 --- a/benchmarks/transformers/prob_model_text_classification.py +++ b/benchmarks/transformers/prob_model_text_classification.py @@ -36,6 +36,7 @@ accuracy, expected_calibration_error, ) +from fortuna.model_editor import ProbitModelEditor from fortuna.prob_model import ( ADVIPosteriorApproximator, DeepEnsemblePosteriorApproximator, @@ -213,6 +214,11 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: parser.add_argument("--sgmcmc_polynomial_schedule_gamma", type=float, default=0.55) parser.add_argument("--sgmcmc_preconditioner", type=strbool, default=False) parser.add_argument("--sghmc_momentum_decay", type=float, default=0.01) + # model editor + parser.add_argument("--enable_probit_model_editor", type=strbool, default=False) + parser.add_argument("--probit_init_log_var", type=float, default=-5) + parser.add_argument("--probit_stop_gradient", type=strbool, default=False) + parser.add_argument("--probit_last_layer_only", type=strbool, default=False) # optimizer parser.add_argument("--learning_rate", type=float, default=2e-5) parser.add_argument("--adam_eps", type=float, default=1e-8) @@ -234,13 +240,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: try: logger.info(list(pathlib.Path(args.load_model_dir).rglob("*"))) - restore_checkpoint_path = unpack_model_tar( + restore_checkpoint_dir = 
unpack_model_tar( list(pathlib.Path(args.load_model_dir).rglob("*"))[0] ) - logger.info(list(pathlib.Path(restore_checkpoint_path).rglob("*"))) + logger.info(list(pathlib.Path(restore_checkpoint_dir).rglob("*"))) except: logger.info("No checkpoint to restore") - restore_checkpoint_path = None + restore_checkpoint_dir = None tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) @@ -392,6 +398,21 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: ), } + model_editor = None + if args.enable_probit_model_editor: + probit_freeze_fun = ( + lambda p, v: True + if "classifier" in p + else False + if args.probit_last_layer_only + else None + ) + model_editor = ProbitModelEditor( + freeze_fun=probit_freeze_fun, + init_log_var=args.probit_init_log_var, + stop_gradient=args.probit_stop_gradient, + ) + ### TRAINING prob_model = ProbClassifier( model=model, @@ -400,6 +421,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: ], prior=IsotropicGaussianPrior(log_var=args.prior_log_var), output_calibrator=None, + model_editor=model_editor ) fit_config = FitConfig( @@ -422,7 +444,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: save_checkpoint_dir=args.output_data_dir, save_every_n_steps=args.save_every_n_steps, keep_top_n_checkpoints=args.keep_top_n_checkpoints, - restore_checkpoint_path=restore_checkpoint_path, + restore_checkpoint_dir=restore_checkpoint_dir, ), callbacks=[ ResetCovarianceCallback( @@ -453,7 +475,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: last_layer_optimizer = FitOptimizer( method=optimizer, n_epochs=args.num_train_epochs, freeze_fun=freeze_fun ) - if restore_checkpoint_path is not None: + if restore_checkpoint_dir is not None: fit_config.optimizer = last_layer_optimizer train_kwargs = {"fit_config": fit_config} else: @@ -478,11 +500,16 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: calib_data_loader=None, **train_kwargs, ) - elif restore_checkpoint_path is not None: - prob_model.load_state(restore_checkpoint_path) + elif restore_checkpoint_dir is not None: + prob_model.load_state(restore_checkpoint_dir) else: raise ValueError( - "Either restore_checkpoint_path or num_train_epochs > 0 should be specified." + "Either restore_checkpoint_dir or num_train_epochs > 0 should be specified." 
+ ) + + if args.enable_probit_model_editor: + logger.info( + f"Probit log-variance: {prob_model.posterior.state.get().params['model_editor']['params']['log_var']}" ) ### IN-D PERFORMANCE diff --git a/benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/default.yaml b/benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/default.yaml deleted file mode 100644 index 5a1d5526..00000000 --- a/benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/default.yaml +++ /dev/null @@ -1,45 +0,0 @@ -defaults: - - task/sentiment - - model/roberta - - method/sgmcmc_ll - - hyperparams/sghmc_ll - -dataset: - base_data_path: ~ - train_relative_path: "" - test_relative_path: "" - validation_relative_path: "" - - -model: - hparams: - tokenizer_max_length: 512 - max_grad_norm: 1 - adam_eps: 0.00000001 - adam_b2: 0.999 - gradient_checkpointing: "true" - save_every_n_steps: 20000 - keep_top_n_checkpoints: 1 - seed: 42 - disable_jit: False - devices: -1 - -sagemaker: - account_id: ~ - iam_role: ~ - entrypoint: "benchmarks/transformers/prob_model_text_classification.py" - instance_type: "ml.g5.2xlarge" - profile: "default" - region: "us-east-1" - job_name_suffix: ~ - metrics: - - {Name: "train_loss_step", Regex: 'loss: ([-+]?(\d+(\.\d*)?|\.\d+))'} - - {Name: "train_accuracy_step", Regex: 'accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} - - {Name: "val_loss", Regex: 'val_loss: ([-+]?(\d+(\.\d*)?|\.\d+))'} - - {Name: "val_accuracy", Regex: 'val_accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} - - {Name: "ind_accuracy", Regex: 'IND Test accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} - - {Name: "ind_ece", Regex: 'IND ECE: ([-+]?(\d+(\.\d*)?|\.\d+))'} - - {Name: "ood_accuracy", Regex: 'OOD Test accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} - - {Name: "ood_ece", Regex: 'OOD ECE: ([-+]?(\d+(\.\d*)?|\.\d+))'} - -output_data_path: ~ diff --git a/benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/model/bert.yaml b/benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/model/bert.yaml index c194c33a..f2705d42 100644 --- a/benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/model/bert.yaml +++ b/benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/model/bert.yaml @@ -4,6 +4,6 @@ hparams: per_device_eval_batch_size: 32 per_device_train_batch_size: 32 learning_rate: 2e-05 - num_warmup_steps: 10000 + num_warmup_steps: 500 prior_log_var: 100.0 weight_decay: 0.01 diff --git a/docs/source/references/output_calib_model/output_calib_model.rst b/docs/source/references/output_calib_model/output_calib_model.rst index 51334698..165a4019 100644 --- a/docs/source/references/output_calib_model/output_calib_model.rst +++ b/docs/source/references/output_calib_model/output_calib_model.rst @@ -8,19 +8,19 @@ Please find their references below. .. automodule:: fortuna.output_calib_model.classification :members: - :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint + :exclude-members: save_checkpoint, restore_checkpoint .. _output_calib_regressor: .. automodule:: fortuna.output_calib_model.regression :members: - :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint + :exclude-members: save_checkpoint, restore_checkpoint .. _output_calib_base: .. 
automodule:: fortuna.output_calib_model.base :members: - :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint + :exclude-members: save_checkpoint, restore_checkpoint .. toctree:: :maxdepth: 1 diff --git a/examples/scaling_up_bayesian_inference.pct.py b/examples/scaling_up_bayesian_inference.pct.py index 4cdb2952..3c4b2e3b 100644 --- a/examples/scaling_up_bayesian_inference.pct.py +++ b/examples/scaling_up_bayesian_inference.pct.py @@ -89,7 +89,7 @@ def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray: # We are ready to call `prob_model.train`, which will perform posterior inference under-the-hood. In order to do Bayesian inference on the last layer only and freeze the other parameters, all we need to do is to pass a function `freeze_fun` to the optimizer configuration object, deciding which parameters should be "frozen" and which should be "trainable". # -# In addition, we configure `map_fit_config` to make a preliminary run with MAP, and set the frozen parameters to a meaningful value. Alternatively, if any of these is available, you can also either restore an existing checkpoint by configuring `FitCheckpointer.restore_checkpoint_path`, or start from a current state by setting `FitCheckpointer.start_from_current_state` to `True`. +# In addition, we configure `map_fit_config` to make a preliminary run with MAP, and set the frozen parameters to a meaningful value. Alternatively, if any of these is available, you can also either restore an existing checkpoint by configuring `FitCheckpointer.restore_checkpoint_dir`, or start from a current state by setting `FitCheckpointer.start_from_current_state` to `True`. from fortuna.prob_model import FitConfig, FitOptimizer diff --git a/fortuna/calib_model/base.py b/fortuna/calib_model/base.py index 3c5a379a..bec79417 100644 --- a/fortuna/calib_model/base.py +++ b/fortuna/calib_model/base.py @@ -3,16 +3,17 @@ from typing import ( Callable, Optional, + Tuple, + Type ) from flax.core import FrozenDict import jax.numpy as jnp - -from fortuna.calib_model.calib_mixin import WithCalibCheckpointingMixin +from jax import eval_shape +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.calib_model.calib_model_calibrator import ( CalibModelCalibrator, - JittedCalibModelCalibrator, - MultiDeviceCalibModelCalibrator, + ShardedCalibModelCalibrator, ) from fortuna.calib_model.calib_state_repository import CalibStateRepository from fortuna.calib_model.config.base import Config @@ -29,11 +30,11 @@ Targets, Uncertainties, ) -from fortuna.utils.data import ( - get_input_shape, - get_inputs_from_shape, -) -from fortuna.utils.device import select_trainer_given_devices +import pathlib +from jax._src.prng import PRNGKeyArray +from orbax.checkpoint import CheckpointManager +from fortuna.utils.checkpoint import get_checkpoint_manager +from fortuna.utils.data import get_inputs_from_shape from fortuna.utils.freeze import get_trainable_paths from fortuna.utils.nested_dicts import ( nested_get, @@ -42,7 +43,7 @@ from fortuna.utils.random import RandomNumberGenerator -class CalibModel(WithCalibCheckpointingMixin, abc.ABC): +class CalibModel(abc.ABC): def __init__(self, seed: int = 0): """ A calibration model. @@ -78,16 +79,17 @@ def _calibrate( "`save_checkpoint_dir` must be passed when `dump_state` is set to True." 
) - trainer_cls = select_trainer_given_devices( - devices=config.processor.devices, - base_trainer_cls=CalibModelCalibrator, - jitted_trainer_cls=JittedCalibModelCalibrator, - multi_device_trainer_cls=MultiDeviceCalibModelCalibrator, - disable_jit=config.processor.disable_jit, + trainer_cls = ( + ShardedCalibModelCalibrator if not config.processor.disable_jit else CalibModelCalibrator ) trainer = trainer_cls( predict_fn=self.prob_output_layer.predict, + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager( + config.checkpointer.save_checkpoint_dir, + keep_top_n_checkpoints=config.checkpointer.keep_top_n_checkpoints, + ), uncertainty_fn=uncertainty_fn, save_checkpoint_dir=config.checkpointer.save_checkpoint_dir, save_every_n_steps=config.checkpointer.save_every_n_steps, @@ -100,28 +102,38 @@ def _calibrate( freeze_fun=config.optimizer.freeze_fun, ) - state = self._init_state(calib_data_loader, config) + checkpoint_restorer = ( + get_checkpoint_manager( + str( + pathlib.Path(config.checkpointer.restore_checkpoint_dir) + / config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=config.checkpointer.keep_top_n_checkpoints, + ) + if config.checkpointer.restore_checkpoint_dir is not None + else None + ) - if config.optimizer.freeze_fun is not None: - trainable_paths = get_trainable_paths( - state.params, config.optimizer.freeze_fun + if self._is_state_available_somewhere(config): + state = self._restore_state_from_somewhere( + config=config, + allowed_states=(CalibState,), + checkpoint_manager=checkpoint_restorer, + partition_manager=self.partition_manager, ) - state = state.replace( - opt_state=config.optimizer.method.init( - FrozenDict( - nested_set( - d={}, - key_paths=trainable_paths, - objs=tuple( - [ - nested_get(state.params.unfreeze(), path) - for path in trainable_paths - ] - ), - allow_nonexistent=True, - ) - ) + state = self._freeze_optimizer_in_state(state, config) + self.partition_manager.shapes_dtypes = eval_shape(lambda: state) + else: + input_shape = calib_data_loader.input_shape + + def init_state_fn(rng): + _state = self._init_state( + input_shape=input_shape, config=config, rng=rng ) + return self._freeze_optimizer_in_state(_state, config) + + state = self.partition_manager.init_sharded_state( + init_state_fn, self.rng.get() ) loss = Loss(self.likelihood, loss_fn=loss_fn) @@ -137,51 +149,86 @@ def _calibrate( rng=self.rng.get(), state=state, loss_fun=loss, - training_dataloader=calib_data_loader, + training_data_loader=calib_data_loader, training_dataset_size=n_calib_data, n_epochs=config.optimizer.n_epochs, metrics=config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=n_val_data, verbose=config.monitor.verbose, callbacks=config.callbacks, ) + self.predictive.state = CalibStateRepository( - config.checkpointer.save_checkpoint_dir - if config.checkpointer.dump_state is True - else None - ) - self.predictive.state.put( - state, keep=config.checkpointer.keep_top_n_checkpoints + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str( + pathlib.Path(config.checkpointer.save_checkpoint_dir) + / config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=config.checkpointer.keep_top_n_checkpoints, + ) + if config.checkpointer.save_checkpoint_dir is not None + and config.checkpointer.dump_state + else None, ) + if self.predictive.state.checkpoint_manager is None: + self.predictive.state.put(state, 
keep=config.checkpointer.keep_top_n_checkpoints) + if config.monitor.verbose: logging.info("Calibration completed.") return status - def load_state(self, checkpoint_path: Path) -> None: + def load_state( + self, + checkpoint_dir: Path, + keep_top_n_checkpoints: int = 2, + checkpoint_type: str = "last" + ) -> None: """ Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the probabilistic model. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to a checkpoint file or directory to restore. + keep_top_n_checkpoints : int + Number of past checkpoint files to keep. + checkpoint_type: str + Which checkpoint type to pass to the state. + There are two possible options: + + - "last": this is the state obtained at the end of training. + - "best": this is the best checkpoint with respect to the metric monitored by early stopping. Notice that + this might be available only if validation data is provided, and both checkpoint saving and early + stopping are enabled. + """ + self.predictive.state = CalibStateRepository( + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager(checkpoint_dir=str(pathlib.Path(checkpoint_dir) / checkpoint_type), keep_top_n_checkpoints=keep_top_n_checkpoints) + ) + self.partition_manager.shapes_dtypes = self.predictive.state.get_shapes_dtypes_checkpoint() + + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ - try: - self.restore_checkpoint(checkpoint_path) - except ValueError: + Save the state of the calibration model to a checkpoint directory. + + Parameters + ---------- + checkpoint_dir: Path + Path to checkpoint file or directory to restore. + keep_top_n_checkpoints: int + Number of past checkpoint files to keep. + """ + if self.predictive.state is None: raise ValueError( - f"No checkpoint was found in `checkpoint_path={checkpoint_path}`." + """No state available. 
You must first either fit the posterior distribution, or load a + saved checkpoint.""" ) - self.predictive.state = CalibStateRepository(checkpoint_dir=checkpoint_path) - - def save_state( - self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1 - ) -> None: - return self.predictive.state.put( + self.predictive.state.put( self.predictive.state.get(), - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, keep=keep_top_n_checkpoints, ) @@ -209,13 +256,12 @@ def _get_output_dim(self, input_shape: Shape, **kwargs) -> int: else outputs.shape[-1] ) - def _init(self, data_loader: DataLoader, config: Config): - for inputs, targets in data_loader: - input_shape = get_input_shape(inputs) - break + def _init_state(self, input_shape: Shape, config: Config, rng: Optional[PRNGKeyArray] = None): + if rng is None: + rng = self.rng.get() state = ModelManagerState.init_from_dict( - self.likelihood.model_manager.init(input_shape, rng=self.rng.get()) + self.likelihood.model_manager.init(input_shape, rng=rng) ) return CalibState.init( params=state.params, @@ -223,20 +269,60 @@ def _init(self, data_loader: DataLoader, config: Config): optimizer=config.optimizer.method, ) - def _init_state(self, calib_data_loader: DataLoader, config: Config) -> CalibState: - if config.checkpointer.restore_checkpoint_path is None: - if config.checkpointer.start_from_current_state: - state = self.predictive.state.get(optimizer=config.optimizer.method) - else: - state = self._init(calib_data_loader, config) - else: - if config.checkpointer.start_from_current_state: - logging.warning( - "`config.checkpointer.start_from_current_state` will be ignored since " - "`config.checkpointer.restore_checkpoint_path` is given." + def _freeze_optimizer_in_state( + self, + state: CalibState, + config: Config + ) -> CalibState: + if config.optimizer.freeze_fun is not None: + trainable_paths = get_trainable_paths( + state.params, config.optimizer.freeze_fun + ) + state = state.replace( + opt_state=config.optimizer.method.init( + FrozenDict( + nested_set( + d={}, + key_paths=trainable_paths, + objs=tuple( + [ + nested_get(state.params.unfreeze(), path) + for path in trainable_paths + ] + ), + allow_nonexistent=True, + ) + ) ) - state = self.restore_checkpoint( - restore_checkpoint_path=config.checkpointer.restore_checkpoint_path, - optimizer=config.optimizer.method, ) return state + + def _is_state_available_somewhere(self, config: Config) -> bool: + return ( + config.checkpointer.restore_checkpoint_dir is not None + or config.checkpointer.start_from_current_state + ) + + def _restore_state_from_somewhere( + self, + config: Config, + allowed_states: Optional[Tuple[Type[CalibState], ...]] = None, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, + ) -> CalibState: + if checkpoint_manager is not None: + repo = CalibStateRepository( + partition_manager=partition_manager, + checkpoint_manager=checkpoint_manager, + ) + state = repo.get(optimizer=config.optimizer.method) + elif config.checkpointer.start_from_current_state: + state = self.predictive.state.get(optimizer=config.optimizer.method) + + if allowed_states is not None and not isinstance(state, allowed_states): + raise ValueError( + f"The type of the restored checkpoint must be within {allowed_states}. " + f"However, the restored checkpoint has type {type(state)}." 
+ ) + + return state diff --git a/fortuna/calib_model/calib_mixin.py b/fortuna/calib_model/calib_mixin.py index 029e991e..a0c212ae 100644 --- a/fortuna/calib_model/calib_mixin.py +++ b/fortuna/calib_model/calib_mixin.py @@ -1,10 +1,9 @@ -import os from typing import Optional -from flax.training import checkpoints - -from fortuna.calib_model.state import CalibState -from fortuna.training.mixin import WithCheckpointingMixin +from fortuna.prob_model.posterior.name_to_posterior_state import NameToPosteriorState +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.calib_model.name_to_calib_state import NameToCalibState +from fortuna.training.name_to_train_state import NameToTrainState from fortuna.typing import ( OptaxOptimizer, Path, @@ -14,27 +13,22 @@ class WithCalibCheckpointingMixin(WithCheckpointingMixin): def restore_checkpoint( self, - restore_checkpoint_path: Path, + restore_checkpoint_dir: Path, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, - ) -> CalibState: - if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile( - restore_checkpoint_path - ): - raise ValueError( - f"`restore_checkpoint_path={restore_checkpoint_path}` was not found." - ) - d = checkpoints.restore_checkpoint( - ckpt_dir=str(restore_checkpoint_path), - target=None, - step=None, - prefix=prefix, - parallel=True, + name_to_train_state: NameToTrainState = NameToPosteriorState, + ): + return super().restore_checkpoint( + restore_checkpoint_dir=restore_checkpoint_dir, + optimizer=optimizer, + name_to_train_state=name_to_train_state, ) - if d is None: - raise ValueError( - f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." - ) - return CalibState.init_from_dict(d, optimizer, **kwargs) + def get_shapes_dtypes_checkpoint( + self, + restore_checkpoint_dir: Optional[Path] = None, + name_to_train_state: NameToTrainState = NameToCalibState, + ): + return super().get_shapes_dtypes_checkpoint( + restore_checkpoint_dir=restore_checkpoint_dir, + name_to_train_state=name_to_train_state, + ) diff --git a/fortuna/calib_model/calib_model_calibrator.py b/fortuna/calib_model/calib_model_calibrator.py index 81c58aa3..9e852ebb 100644 --- a/fortuna/calib_model/calib_model_calibrator.py +++ b/fortuna/calib_model/calib_model_calibrator.py @@ -13,11 +13,8 @@ from optax._src.base import PyTree from fortuna.calib_model.state import CalibState -from fortuna.training.trainer import ( - JittedMixin, - MultiDeviceMixin, - TrainerABC, -) +from fortuna.training.mixins.sharding import ShardingMixin +from fortuna.training.trainer import TrainerABC from fortuna.typing import ( Array, Batch, @@ -102,9 +99,5 @@ def validation_step( return dict(val_loss=loss) -class JittedCalibModelCalibrator(JittedMixin, CalibModelCalibrator): - pass - - -class MultiDeviceCalibModelCalibrator(MultiDeviceMixin, CalibModelCalibrator): +class ShardedCalibModelCalibrator(ShardingMixin, CalibModelCalibrator): pass diff --git a/fortuna/calib_model/classification.py b/fortuna/calib_model/classification.py index e8e4994e..6c8acb0e 100644 --- a/fortuna/calib_model/classification.py +++ b/fortuna/calib_model/classification.py @@ -21,11 +21,14 @@ ClassificationMaskedProbOutputLayer, ClassificationProbOutputLayer, ) +from fortuna.partitioner.base import Partitioner +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.typing import ( Outputs, Status, Targets, ) +from fortuna.calib_model.calib_mixin import WithCalibCheckpointingMixin 
from fortuna.utils.data import get_input_shape @@ -34,6 +37,7 @@ def __init__( self, model: nn.Module, model_editor: Optional[ModelEditor] = None, + partitioner: Partitioner = Partitioner(), seed: int = 0, ): r""" @@ -48,6 +52,8 @@ def __init__( a function :math:`f(w, x)`, where each component of :math:`f` corresponds to one of the classes. model_editor : ModelEditor A model_editor object. It takes the forward pass and transforms the outputs. + partitioner : Partitioner + A partitioner object for data, fully-sharded data, and model parallelization. seed: int A random seed. @@ -76,6 +82,7 @@ def __init__( prob_output_layer=self.prob_output_layer, output_calib_manager=None, ) + self.partition_manager = PartitionManager(partitioner) self.predictive = ClassificationPredictive(likelihood=self.likelihood) super().__init__(seed=seed) diff --git a/fortuna/calib_model/config/checkpointer.py b/fortuna/calib_model/config/checkpointer.py index 645e776d..810c294c 100644 --- a/fortuna/calib_model/config/checkpointer.py +++ b/fortuna/calib_model/config/checkpointer.py @@ -7,11 +7,12 @@ class Checkpointer: def __init__( self, save_checkpoint_dir: Optional[Path] = None, - restore_checkpoint_path: Optional[Path] = None, + restore_checkpoint_dir: Optional[Path] = None, start_from_current_state: bool = False, save_every_n_steps: Optional[int] = None, keep_top_n_checkpoints: Optional[int] = 2, dump_state: bool = False, + checkpoint_type: str = "last", ): """ An object to configure saving and restoring of checkpoints during the calibration process. @@ -20,10 +21,10 @@ def __init__( ---------- save_checkpoint_dir: Optional[Path] = None Save directory location. - restore_checkpoint_path: Optional[Path] + restore_checkpoint_dir: Optional[Path] Path to checkpoint file or directory to restore. start_from_current_state: bool = False - If True, the optimization will start from the current state. If `restore_checkpoint_path` is given, then + If True, the optimization will start from the current state. If `restore_checkpoint_dir` is given, then `start_from_current_state` is ignored. save_every_n_steps: int Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no checkpoint). keep_top_n_checkpoints: int Number of past checkpoint files to keep. dump_state: bool Dump the fitted calibration state as a checkpoint in `save_checkpoint_dir`. Any future call to the state will internally involve restoring it from memory. + checkpoint_type: str + Which checkpoint type to pass to the state. + There are two possible options: + + - "last": this is the state obtained at the end of training. + - "best": this is the best checkpoint with respect to the metric monitored by early stopping. Notice that + this might be available only if validation data is provided, and both checkpoint saving and early + stopping are enabled.
""" self.save_checkpoint_dir = save_checkpoint_dir self.save_every_n_steps = save_every_n_steps - self.restore_checkpoint_path = restore_checkpoint_path + self.restore_checkpoint_dir = restore_checkpoint_dir self.start_from_current_state = start_from_current_state self.keep_top_n_checkpoints = keep_top_n_checkpoints self.dump_state = dump_state + self.checkpoint_type = checkpoint_type diff --git a/fortuna/calib_model/loss.py b/fortuna/calib_model/loss.py index 915ae781..57abdf0c 100644 --- a/fortuna/calib_model/loss.py +++ b/fortuna/calib_model/loss.py @@ -86,7 +86,7 @@ def __call__( mutable=mutable, rng=rng, ) - if "mutable" in return_aux: + if train and mutable is not None: outputs, aux = outs mutable = aux["mutable"] else: @@ -104,17 +104,17 @@ def __call__( outputs=outputs, calib="calib_mutable" in return_aux, ) - if ( - calib_mutable is not None - and calib_mutable["output_calibrator"] is not None - and "calib_mutable" in return_aux - ): - outputs, aux["calib_mutable"] = outs - aux["calib_mutable"] = dict(output_calibrator=aux["calib_mutable"]) - else: - outputs = outs - if "calib_mutable" in return_aux: - aux["calib_mutable"] = dict(output_calibrator=None) + if ( + calib_mutable is not None + and calib_mutable["output_calibrator"] is not None + and "calib_mutable" in return_aux + ): + outputs, aux["calib_mutable"] = outs + aux["calib_mutable"] = dict(output_calibrator=aux["calib_mutable"]) + else: + outputs = outs + if "calib_mutable" in return_aux: + aux["calib_mutable"] = dict(output_calibrator=None) if "outputs" in return_aux: aux["outputs"] = outputs diff --git a/fortuna/calib_model/name_to_calib_state.py b/fortuna/calib_model/name_to_calib_state.py new file mode 100644 index 00000000..bd274676 --- /dev/null +++ b/fortuna/calib_model/name_to_calib_state.py @@ -0,0 +1,7 @@ +import enum + +from fortuna.calib_model.state import CalibState + + +class NameToCalibState(enum.Enum): + vars()[CalibState.__name__] = CalibState diff --git a/fortuna/calib_model/regression.py b/fortuna/calib_model/regression.py index cffc3e4f..fc570ed0 100644 --- a/fortuna/calib_model/regression.py +++ b/fortuna/calib_model/regression.py @@ -15,6 +15,9 @@ from fortuna.model.model_manager.regression import RegressionModelManager from fortuna.model_editor.base import ModelEditor from fortuna.prob_output_layer.regression import RegressionProbOutputLayer +from fortuna.partitioner.base import Partitioner +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.calib_model.calib_mixin import WithCalibCheckpointingMixin from fortuna.typing import ( Outputs, Status, @@ -28,6 +31,7 @@ def __init__( self, model: nn.Module, likelihood_log_variance_model: nn.Module, model_editor: Optional[ModelEditor] = None, + partitioner: Partitioner = Partitioner(), seed: int = 0, ): r""" @@ -46,6 +50,8 @@ def __init__( parameters. Then the model is described by a function :math:`\log\sigma^2(w, x)`. model_editor : ModelEditor A model_editor object. It takes the forward pass and transforms the outputs. + partitioner : Partitioner + A partitioner object for data, fully-sharded data, and model parallelization. seed: int A random seed.
@@ -76,6 +82,7 @@ def __init__( prob_output_layer=self.prob_output_layer, output_calib_manager=None, ) + self.partition_manager = PartitionManager(partitioner) self.predictive = RegressionPredictive(likelihood=self.likelihood) super().__init__(seed=seed) diff --git a/fortuna/calib_model/state.py b/fortuna/calib_model/state.py index 4a7be7b9..d8269d98 100644 --- a/fortuna/calib_model/state.py +++ b/fortuna/calib_model/state.py @@ -16,6 +16,8 @@ OptaxOptimizer, Params, ) +import jax.numpy as jnp +from fortuna.utils.strings import convert_string_to_jnp_array class CalibState(TrainState): @@ -23,6 +25,7 @@ class CalibState(TrainState): mutable: Optional[Mutable] = None calib_params: Optional[CalibParams] = None calib_mutable: Optional[CalibMutable] = None + encoded_name: jnp.ndarray = convert_string_to_jnp_array("CalibState") @classmethod def init( diff --git a/fortuna/data/dataset/data_collator.py b/fortuna/data/dataset/data_collator.py index cfc42d48..5c13f5a8 100644 --- a/fortuna/data/dataset/data_collator.py +++ b/fortuna/data/dataset/data_collator.py @@ -121,10 +121,10 @@ def __call__( return_tensors=TensorType.NUMPY, ) - if self.mlm: - # If special token mask has been preprocessed, pop it from the dict. - special_tokens_mask = batch.pop("special_tokens_mask", None) + # If special token mask has been preprocessed, pop it from the dict. + special_tokens_mask = batch.pop("special_tokens_mask", None) + if self.mlm: batch["input_ids"], batch["labels"] = self.mask_tokens( batch["input_ids"], special_tokens_mask=special_tokens_mask ) diff --git a/fortuna/data/dataset/huggingface_datasets.py b/fortuna/data/dataset/huggingface_datasets.py index 7d556459..885ab809 100644 --- a/fortuna/data/dataset/huggingface_datasets.py +++ b/fortuna/data/dataset/huggingface_datasets.py @@ -89,7 +89,7 @@ def data_collator(self): def get_data_loader( self, dataset: Dataset, - per_device_batch_size: int, + batch_size: int, rng: PRNGKeyArray, shuffle: bool = False, drop_last: bool = False, @@ -103,26 +103,26 @@ def get_data_loader( ---------- dataset: Dataset A tokenized dataset (see :meth:`.HuggingFaceClassificationDatasetABC.get_tokenized_datasets`). - per_device_batch_size: bool - Batch size for each device. + batch_size: int + Total batch size, possibly divided over multiple devices. rng: PRNGKeyArray Random number generator. shuffle: bool if True, shuffle the data so that each batch is a random sample from the dataset. drop_last: bool - if True, the last batch (which potentially is smaller then the default batch size) is dropped. + if True, the last batch (which potentially is smaller than the default batch size) is dropped. verbose: bool - Whether to show a progress bar while iterating over the dataloader or not. + Whether to show a progress bar while iterating over the data_loader or not.
Returns ------- HuggingFaceDataLoader - The dataloader + The data_loader """ iterable = IterableData.from_callable( lambda *args, **kwargs: self._get_data_loader( dataset, - batch_size=per_device_batch_size * jax.local_device_count(), + batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, rng=rng, diff --git a/fortuna/data/loader/base.py b/fortuna/data/loader/base.py index 73e1b1ca..3663fbf3 100644 --- a/fortuna/data/loader/base.py +++ b/fortuna/data/loader/base.py @@ -2,6 +2,7 @@ import abc from typing import ( + Any, Callable, Iterable, List, @@ -13,9 +14,14 @@ from flax import jax_utils import jax +from jax.sharding import PartitionSpec from jax.tree_util import tree_map -from fortuna.data.loader.utils import IterableData +from fortuna.data.loader.utils import ( + IterableData, + prefetch_to_mesh, +) +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.typing import ( Array, Batch, @@ -185,7 +191,7 @@ def from_tensorflow_data_loader(cls: Type[T], tf_data_loader) -> T: T A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`. """ - return cls(iterable=IterableData.from_tf_dataloader(tf_data_loader)) + return cls(iterable=IterableData.from_tf_data_loader(tf_data_loader)) @classmethod def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T: @@ -203,7 +209,7 @@ def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T: T A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`. """ - return cls(iterable=IterableData.from_torch_dataloader(torch_data_loader)) + return cls(iterable=IterableData.from_torch_data_loader(torch_data_loader)) @classmethod def from_inputs_loaders( @@ -553,6 +559,27 @@ def __iter__(self, *args, **kwargs): yield from loader +class ShardedPrefetchedLoader: + def __init__( + self, + loader, + partition_manager: Optional[PartitionManager] = None, + partition_spec: Optional[PartitionSpec] = None, + ): + self._loader = loader + self.partition_manager = partition_manager + self.partition_spec = partition_spec + + def __iter__(self, *args, **kwargs): + loader = prefetch_to_mesh( + iter(self._loader), + 2, + self.partition_manager.partitioner.mesh, + self.partition_spec, + ) + yield from loader + + class ConcatenatedLoader: def __init__( self, diff --git a/fortuna/data/loader/huggingface_loaders.py b/fortuna/data/loader/huggingface_loaders.py index 3a89801c..39e44461 100644 --- a/fortuna/data/loader/huggingface_loaders.py +++ b/fortuna/data/loader/huggingface_loaders.py @@ -35,7 +35,7 @@ def __init__( Parameters ---------- iterable : Union[Iterable[Dict[str, Array]], Iterable[Tuple[Dict[str, Array],Array]]] - A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_dataloader`. + A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_data_loader`. 
num_unique_labels: int Number of unique target labels in the task (classification only) num_inputs: Optional[int] diff --git a/fortuna/data/loader/utils.py b/fortuna/data/loader/utils.py index dee5c35d..b5ed985c 100644 --- a/fortuna/data/loader/utils.py +++ b/fortuna/data/loader/utils.py @@ -1,6 +1,8 @@ from __future__ import annotations +import collections from copy import deepcopy +import itertools from itertools import zip_longest from typing import ( Iterable, @@ -9,6 +11,12 @@ Union, ) +import jax +from jax.sharding import ( + Mesh, + NamedSharding, + PartitionSpec, +) import numpy as np from fortuna.typing import ( @@ -44,9 +52,9 @@ def _inner(): return cls(_inner) @classmethod - def from_tf_dataloader(cls, tf_dataloader) -> IterableData: + def from_tf_data_loader(cls, tf_data_loader) -> IterableData: def _inner(): - for batch_inputs, batch_targets in tf_dataloader: + for batch_inputs, batch_targets in tf_data_loader: if not isinstance(batch_inputs, dict): batch_inputs = batch_inputs.numpy() else: @@ -57,9 +65,9 @@ def _inner(): return cls(_inner) @classmethod - def from_torch_dataloader(cls, torch_dataloader) -> IterableData: + def from_torch_data_loader(cls, torch_data_loader) -> IterableData: def _inner(): - for batch_inputs, batch_targets in torch_dataloader: + for batch_inputs, batch_targets in torch_data_loader: if not isinstance(batch_inputs, dict): batch_inputs = batch_inputs.numpy() else: @@ -234,3 +242,29 @@ def _prefetch(self): if not self._ready: self._batch = self._generator.__next__() self._ready = True + + +def prefetch_to_mesh(iterator, size: int, mesh: Mesh, xs_spec): + queue = collections.deque() + + def _prefetch(xs): + return jax.device_put( + xs, + NamedSharding( + mesh, + xs_spec + if xs_spec is not None + else xs.sharding.spec + if hasattr(xs, "sharding") and hasattr(xs.sharding, "spec") + else PartitionSpec(), + ), + ) + + def enqueue(n):  # Enqueues *up to* `n` elements from the iterator. + for data in itertools.islice(iterator, n): + queue.append(jax.tree_util.tree_map(_prefetch, data)) + + enqueue(size)  # Fill up the buffer. + while queue: + yield queue.popleft() + enqueue(1) diff --git a/fortuna/likelihood/base.py b/fortuna/likelihood/base.py index 118439f3..0f1e60b9 100644 --- a/fortuna/likelihood/base.py +++ b/fortuna/likelihood/base.py @@ -8,6 +8,7 @@ Union, ) +from flax.core import FrozenDict from jax import ( jit, pmap, @@ -214,9 +215,10 @@ def _batched_log_joint_prob( mutable=mutable, rng=rng, ) - if "mutable" in return_aux: + if train and mutable is not None: outputs, aux = outs - mutable = aux["mutable"] + if "mutable" in return_aux: + mutable = aux["mutable"] else: outputs = outs @@ -238,11 +240,11 @@ def _batched_log_joint_prob( and "calib_mutable" in return_aux ): outputs, aux["calib_mutable"] = outs - aux["calib_mutable"] = dict(output_calibrator=aux["calib_mutable"]) + aux["calib_mutable"] = FrozenDict(output_calibrator=aux["calib_mutable"]) else: outputs = outs if "calib_mutable" in return_aux: - aux["calib_mutable"] = dict(output_calibrator=None) + aux["calib_mutable"] = FrozenDict(output_calibrator=None) log_joint_prob = jnp.sum( self.prob_output_layer.log_prob(outputs, targets, train=train, **kwargs) diff --git a/fortuna/metric/classification.py b/fortuna/metric/classification.py index 749f3180..1c0f953b 100755 --- a/fortuna/metric/classification.py +++ b/fortuna/metric/classification.py @@ -29,15 +29,7 @@ def accuracy(preds: Array, targets: Array) -> jnp.ndarray: jnp.ndarray The computed accuracy.
""" - if preds.ndim > 1: - raise ValueError( - """`preds` must be a one-dimensional array of predicted classes.""" - ) - if targets.ndim > 1: - raise ValueError( - """`targets` must be a one-dimensional array of target classes.""" - ) - return jnp.mean(preds == targets) + return jnp.mean(jnp.equal(preds, targets)) def compute_counts_confs_accs( @@ -211,12 +203,7 @@ def brier_score(probs: Array, targets: Union[TargetsLoader, Array]) -> jnp.ndarr jnp.ndarray The Brier score. """ - if probs.ndim != 2: - raise ValueError( - """`probs` must be a two-dimensional array of probabilities for each class and each data - point.""" - ) if type(targets) == TargetsLoader: targets = targets.to_array_targets() - targets = jax.nn.one_hot(targets, probs.shape[1]) - return jnp.mean(jnp.sum((probs - targets) ** 2, axis=1)) + targets = jax.nn.one_hot(targets, probs.shape[-1]) + return jnp.mean(jnp.sum((probs - targets) ** 2, axis=-1)) diff --git a/fortuna/model/llama.py b/fortuna/model/llama.py new file mode 100644 index 00000000..49e2427d --- /dev/null +++ b/fortuna/model/llama.py @@ -0,0 +1,1268 @@ +from functools import partial +import json +import os +from shutil import copyfile +import tempfile +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + Union, +) + +from flax.core.frozen_dict import ( + FrozenDict, + freeze, + unfreeze, +) +import flax.linen as nn +from flax.linen import ( + combine_masks, + make_causal_mask, + partitioning as nn_partitioning, +) +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import ( + flatten_dict, + unflatten_dict, +) +import jax +from jax import lax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as PS +import numpy as np +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxCausalLMOutput, +) +from transformers.modeling_flax_utils import FlaxPreTrainedModel +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + +LLAMA_STANDARD_CONFIGS = { + "3b": { + "vocab_size": 32000, + "hidden_size": 3200, + "intermediate_size": 8640, + "num_hidden_layers": 26, + "num_attention_heads": 32, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, + "7b": { + "vocab_size": 32000, + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, + "13b": { + "vocab_size": 32000, + "hidden_size": 5120, + "intermediate_size": 13824, + "num_hidden_layers": 40, + "num_attention_heads": 40, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, + "30b": { + "vocab_size": 32000, + "hidden_size": 6656, + "intermediate_size": 17920, + "num_hidden_layers": 60, + "num_attention_heads": 52, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, + "65b": { + "vocab_size": 32000, + "hidden_size": 8192, + "intermediate_size": 22016, + "num_hidden_layers": 80, + "num_attention_heads": 64, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-5, + "use_cache": 
True, + "tie_word_embeddings": False, + }, + "debug": { # A small model for debugging + "vocab_size": 32000, + "hidden_size": 128, + "intermediate_size": 256, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, +} + + +class LLaMAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~LLaMAModel`] or [`~TFLLaMAModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_sequence_length (`int`, *optional*, defaults to 2048): + Max sequence length for model (for RoPE computation) + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + ```python + >>> from transformers import LLaMAModel, LLaMAConfig + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LLaMAConfig() + >>> # Initializing a model from the llama-7b style configuration + >>> model = LLaMAModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "llama" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + max_sequence_length=2048, + rms_norm_eps=1e-6, + initializer_range=0.02, + use_cache=True, + # pad_token_id=-1, + bos_token_id=0, + eos_token_id=1, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + tie_word_embeddings=False, + remat_block="nothing_saveable", + remat_attention="", + remat_mlp="", + fcm_min_ratio=0.0, + fcm_max_ratio=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.remat_block = remat_block + self.remat_attention = remat_attention + self.remat_mlp = remat_mlp + self.fcm_min_ratio = fcm_min_ratio + self.fcm_max_ratio = fcm_max_ratio + super().__init__( + # pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class RMSNorm(nn.Module): + dim: int + eps: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.weight = self.param( + "kernel", + nn.initializers.ones, + (self.dim,), + self.param_dtype, + ) + + def _norm(self, x: jnp.ndarray) -> jnp.ndarray: + return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) + output = self._norm(x).astype(self.dtype) + weight = jnp.asarray(self.weight, self.dtype) + return output * weight + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32 +) -> jnp.ndarray: + freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) + t = np.arange(end) # type: ignore + freqs = np.outer(t, freqs).astype(dtype) # type: ignore + sin, cos = np.sin(freqs), np.cos(freqs) + freqs_cis = np.complex64(cos + 1j * sin) + return jnp.asarray(freqs_cis) + + +def apply_rotary_emb( + xq: jnp.ndarray, + xk: jnp.ndarray, + freqs_cis: jnp.ndarray, + dtype: jnp.dtype = jnp.float32, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) + reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) + + xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) + xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) + + # add head dim + freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:])) + + xq_out = xq_ * freqs_cis + xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape( + *xq_out.shape[:-1], -1 + ) + + xk_out = xk_ * freqs_cis + xk_out 
= jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape( + *xk_out.shape[:-1], -1 + ) + + return xq_out.astype(dtype), xk_out.astype(dtype) + + +class FlaxLLaMAAttention(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + config = self.config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.wq = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wk = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wv = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wo = nn.Dense( + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + + self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) + + self.causal_mask = make_causal_mask( + jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool" + ) + + self.freqs_cis = precompute_freqs_cis( + self.head_dim, + config.max_sequence_length * 2, + dtype=self.dtype, + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape( + hidden_states.shape[:2] + (self.num_heads, self.head_dim) + ) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable( + "cache", "cached_key", jnp.zeros, key.shape, key.dtype + ) + cached_value = self.variable( + "cache", "cached_value", jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32) + ) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + position_ids, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + fcm_mask=None, + ): + xq, xk, xv = ( + self.wq(hidden_states), + self.wk(hidden_states), + self.wv(hidden_states), + ) + + xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp")) + xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp")) + xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp")) + + xq = self._split_heads(xq) + xk = self._split_heads(xk) + xv = self._split_heads(xv) + + freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype) + + query_length, key_length = xq.shape[1], xk.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length), + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to( + causal_mask, (batch_size,) + causal_mask.shape[1:] + ) + + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) + + dropout_rng = None + if not deterministic and self.config.attn_pdrop > 0.0: + dropout_rng = self.make_rng("dropout") + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
+ if self.has_variable("cache", "cached_key") or init_cache: + xk, xv, attention_mask = self._concatenate_to_cache( + xk, xv, xq, attention_mask + ) + + # transform boolean mask into float mask + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype( + self.dtype + ), + ) + + # usual dot product attention + attn_weights = dot_product_attention_weights( + xq, + xk, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attn_pdrop, + deterministic=deterministic, + dtype=jnp.promote_types(self.dtype, jnp.float32), + precision=self.precision, + ) + attn_weights = with_sharding_constraint( + attn_weights, PS(("dp", "fsdp"), "mp", None, None) + ) + + attn_output = jnp.einsum( + "...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision + ) + attn_output = self._merge_heads(attn_output) + attn_output = self.wo(attn_output) + attn_output = self.resid_dropout(attn_output, deterministic=deterministic) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxLLaMAMLP(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self) -> None: + config = self.config + + self.w1 = nn.Dense( + config.intermediate_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.w2 = nn.Dense( + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.w3 = nn.Dense( + config.intermediate_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.dropout = nn.Dropout(rate=self.config.resid_pdrop) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + x = self.w2(nn.silu(self.w1(x)) * self.w3(x)) + x = self.dropout(x, deterministic=deterministic) + return x + + +class FlaxLLaMABlock(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self) -> None: + attention_module = FlaxLLaMAAttention + mlp_module = FlaxLLaMAMLP + if self.config.remat_attention != "": + attention_module = remat( + FlaxLLaMAAttention, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_attention), + ) + if self.config.remat_mlp != "": + mlp_module = remat( + FlaxLLaMAMLP, + static_argnums=(1,), + policy=get_gradient_checkpoint_policy(self.config.remat_mlp), + ) + + self.attention = attention_module( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.feed_forward = mlp_module( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.attention_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.ffn_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + def __call__( + self, + 
hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + fcm_mask: Optional[jnp.ndarray] = None, + ): + attn_outputs = self.attention( + self.attention_norm(hidden_states), + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + attn_output = attn_outputs[0] + hidden_states = hidden_states + attn_output + + feed_forward_hidden_states = self.feed_forward( + self.ffn_norm(hidden_states), + deterministic, + ) + hidden_states = hidden_states + feed_forward_hidden_states + + return (hidden_states,) + attn_outputs[1:] + + +class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LLaMAConfig + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: LLaMAConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__( + config, + module, + input_shape=input_shape, + seed=seed, + dtype=dtype, + _do_init=_do_init, + ) + + def init_weights( + self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None + ) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape + ) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, position_ids, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + input_ids, + attention_mask, + position_ids, + return_dict=False, + init_cache=True, + ) + return init_variables["cache"] + + @add_start_docstrings_to_model_forward("") + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + params: dict = None, + past_key_values: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError( + "Make sure to provide `position_ids` when passing `past_key_values`." + ) + + position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxLLaMABlockCollection(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + block = FlaxLLaMABlock + if self.config.remat_block != "": + block = remat( + FlaxLLaMABlock, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_block), + ) + self.blocks = [ + block( + self.config, + name=str(i), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if not deterministic and self.config.fcm_max_ratio > 0: + # Apply forgetful causal mask + batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] + fcm_ratio = jax.random.uniform( + self.make_rng("fcm"), + shape=(batch_size, 1, 1, 1), + minval=self.config.fcm_min_ratio, + maxval=self.config.fcm_max_ratio, + ) + fcm_mask = ( + jax.random.uniform( + self.make_rng("fcm"), shape=(batch_size, 1, seq_length, seq_length) + ) + > fcm_ratio + ) + fcm_mask = fcm_mask.at[:, :, :, 0].set(True) + fcm_mask = fcm_mask.astype("bool") + else: + fcm_mask = None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxGPTJModule` will filter them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + + +class FlaxLLaMAModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.embed_dim = self.config.hidden_size + + self.wte = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.dropout = nn.Dropout(rate=self.config.embd_pdrop) + self.h = FlaxLLaMABlockCollection( + self.config, + dtype=self.dtype, + 
param_dtype=self.param_dtype, + precision=self.precision, + ) + self.ln_f = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + input_embeds = self.wte(input_ids.astype("i4")) + + hidden_states = self.dropout(input_embeds, deterministic=deterministic) + + outputs = self.h( + hidden_states, + attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +@add_start_docstrings("", "") +class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel): + module_class = FlaxLLaMAModule + + +# append_call_sample_docstring( +# FlaxLLaMAModel, +# _TOKENIZER_FOR_DOC, +# _CHECKPOINT_FOR_DOC, +# FlaxCausalLMOutput, +# _CONFIG_FOR_DOC, +# ) + + +class FlaxLLaMAForCausalLMModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + precision=self.precision, + ) + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + batch_size, seq_length = input_ids.shape + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + position_ids = jnp.broadcast_to( + jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), + (batch_size, seq_length), + ) + outputs = self.transformer( + input_ids, + attention_mask, + position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T + lm_logits = self.lm_head.apply( + {"params": {"kernel": shared_kernel}}, hidden_states + ) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput( + logits=lm_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("", "") +class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel): + module_class = FlaxLLaMAForCausalLMModule + + def prepare_inputs_for_generation( + self, input_ids, max_length, attention_mask: 
Optional[jnp.DeviceArray] = None + ): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since GPTJ uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0) + ) + else: + position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +# append_call_sample_docstring( +# FlaxGPTJForCausalLM, +# _TOKENIZER_FOR_DOC, +# _CHECKPOINT_FOR_DOC, +# FlaxCausalLMOutput, +# _CONFIG_FOR_DOC, +# ) + + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = {} + + +class LLaMATokenizer(PreTrainedTokenizer): + """ + Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding. + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=False, + add_eos_token=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + super().__init__( + bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs + ) + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + + with tempfile.NamedTemporaryFile() as tfile: + with open_file(self.vocab_file, "rb") as fin: + tfile.write(fin.read()) + tfile.flush() + tfile.seek(0) + self.sp_model.Load(tfile.name) + """ Initialisation""" + self.add_special_tokens( + dict( + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + ) + ) + self.pad_token_id = self.unk_token_id + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + @property + def bos_token_id(self) -> Optional[int]: + return self.sp_model.bos_id() + + @property + def eos_token_id(self) -> Optional[int]: + return self.sp_model.eos_id() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token 
(str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def save_vocabulary( + self, save_directory, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if self.add_bos_token: + bos_token_ids = [self.bos_token_id] + else: + bos_token_ids = [] + + output = bos_token_ids + token_ids_0 + + if token_ids_1 is not None: + output = output + token_ids_1 + + if self.add_eos_token: + output = output + [self.eos_token_id] + + return output + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + Args: + token_ids_0 (`List[int]`): + List of IDs. 
+ token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] diff --git a/fortuna/model/model_manager/base.py b/fortuna/model/model_manager/base.py index 77f90773..0e8210c6 100755 --- a/fortuna/model/model_manager/base.py +++ b/fortuna/model/model_manager/base.py @@ -9,14 +9,15 @@ from flax import linen as nn from flax.core import FrozenDict -from flax.training.checkpoints import PyTree from jax._src.prng import PRNGKeyArray import jax.numpy as jnp +from optax._src.base import PyTree from fortuna.typing import ( InputData, Mutable, Params, + Shape, ) from fortuna.utils.random import WithRNG @@ -67,14 +68,14 @@ def apply( @abc.abstractmethod def init( - self, input_shape: Tuple[int, ...], rng: Optional[PRNGKeyArray] = None, **kwargs + self, input_shape: Shape, rng: Optional[PRNGKeyArray] = None, **kwargs ) -> Dict[str, Mapping]: """ Initialize random parameters and mutable objects. Parameters ---------- - input_shape : Tuple + input_shape : Shape The shape of the input variable. rng: Optional[PRNGKeyArray] A random number generator. diff --git a/fortuna/model/model_manager/classification.py b/fortuna/model/model_manager/classification.py index df4c85fe..f4566fd7 100644 --- a/fortuna/model/model_manager/classification.py +++ b/fortuna/model/model_manager/classification.py @@ -10,11 +10,11 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree import jax from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp +from optax._src.base import PyTree from fortuna.model.model_manager.base import ModelManager from fortuna.model.utils.random_features import RandomFeatureGaussianProcess diff --git a/fortuna/model/model_manager/regression.py b/fortuna/model/model_manager/regression.py index 2458c6d6..2fef5659 100644 --- a/fortuna/model/model_manager/regression.py +++ b/fortuna/model/model_manager/regression.py @@ -7,11 +7,11 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree import jax from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp +from optax._src.base import PyTree from fortuna.model.model_manager.base import ModelManager from fortuna.typing import ( @@ -65,6 +65,7 @@ def apply( lik_log_var_rngs = None if mutable is not None: + mutable = mutable.unfreeze() mutable["model"] = mutable.get("model") mutable["lik_log_var"] = mutable.get("lik_log_var") @@ -102,12 +103,12 @@ def apply_fn(p, x, m_mutable, llv_mutable): self._check_outputs(model_outputs, lik_log_var_outputs) aux = dict() - if m_mutable or llv_mutable: + if train and (m_mutable or llv_mutable): aux["mutable"] = dict() if m_mutable: aux["mutable"]["model"] = m_mutable if llv_mutable: - aux["mutable"]["lik_log_var"] = m_mutable + aux["mutable"]["lik_log_var"] = llv_mutable return jnp.concatenate( (model_outputs, lik_log_var_outputs), axis=-1 diff --git a/fortuna/model/model_manager/state.py b/fortuna/model/model_manager/state.py index adc02504..c34614b3 100644 --- a/fortuna/model/model_manager/state.py +++ b/fortuna/model/model_manager/state.py @@ -54,7 +54,7 @@ def init_from_dict(cls, d: Union[Dict, FrozenDict]) -> ModelManagerState: mutable = FrozenDict( { k: FrozenDict({_k: _v for _k, _v in v.items() if _k != "params"}) - for k, v in d.items() + 
for k, v in d.items() if k != "model_editor" } ) flag = 0 diff --git a/fortuna/model/model_manager/transformers/classification.py b/fortuna/model/model_manager/transformers/classification.py index 160b5941..1549ca0f 100644 --- a/fortuna/model/model_manager/transformers/classification.py +++ b/fortuna/model/model_manager/transformers/classification.py @@ -8,13 +8,14 @@ from flax import linen as nn from flax.core import FrozenDict -from flax.training.checkpoints import PyTree import jax from jax import ( numpy as jnp, random, + eval_shape ) from jax._src.prng import PRNGKeyArray +from optax._src.base import PyTree from fortuna.model.model_manager.classification import ( ClassificationModelManager, @@ -26,6 +27,8 @@ Mutable, Params, ) +from flax.traverse_util import flatten_dict +from flax.core.frozen_dict import unfreeze from fortuna.utils.data import get_inputs_from_shape from fortuna.utils.nested_dicts import nested_update @@ -65,8 +68,6 @@ def apply_fn(p, x): ) if hasattr(_outputs, "logits"): _outputs = _outputs.logits - if _outputs.ndim == 3: - _outputs = _outputs[:, -1] if isinstance(_outputs, tuple) and not has_aux: _outputs = _outputs[0] @@ -91,17 +92,20 @@ def apply_fn(p, x): def init( self, input_shape: Tuple[int, ...], rng: Optional[PRNGKeyArray] = None, **kwargs ) -> Dict[str, Mapping]: - assert self.model._is_initialized, ( - "At the moment Fortuna supports models from Hugging Face that are loaded via " - "`from_pretrained` method, which also takes care of model initialization." - ) + if rng is None: + rng = self.rng.get() + + if not self.model._is_initialized: + if not hasattr(self.model, "_params"): + raise ValueError("If the transformer model is not initialized, you must externally pass the model " + "parameters as attribute `_params` to `model`.") + + self._params_shape_tree = eval_shape(lambda: self.model._params) + self._required_params = set(flatten_dict(unfreeze(self._params_shape_tree)).keys()) + self.model._is_initialized = True + params = {"model": {"params": self.model.params}} if self.model_editor is not None: - if rng is None: - rng = self.rng.get() - output_shape = jax.eval_shape( - self.model, **get_inputs_from_shape(input_shape) - ).logits.shape rng, params_key, dropout_key = random.split(rng, 3) rngs = {"params": params_key, "dropout": dropout_key} @@ -109,8 +113,6 @@ def apply_fn(p, x): _outputs = self.model(**x, params=p) if hasattr(_outputs, "logits"): _outputs = _outputs.logits - if _outputs.ndim == 3: - _outputs = _outputs[:, -1] return _outputs params.update( diff --git a/fortuna/model_editor/__init__.py b/fortuna/model_editor/__init__.py new file mode 100644 index 00000000..9915d1b3 --- /dev/null +++ b/fortuna/model_editor/__init__.py @@ -0,0 +1 @@ +from fortuna.model_editor.probit import ProbitModelEditor diff --git a/fortuna/model_editor/base.py b/fortuna/model_editor/base.py index c24b8c65..78e2023b 100644 --- a/fortuna/model_editor/base.py +++ b/fortuna/model_editor/base.py @@ -12,7 +12,6 @@ from fortuna.typing import ( InputData, - Mutable, ) diff --git a/fortuna/model_editor/classification.py b/fortuna/model_editor/probit.py similarity index 63% rename from fortuna/model_editor/classification.py rename to fortuna/model_editor/probit.py index 34c506c3..ef8fc2a3 100644 --- a/fortuna/model_editor/classification.py +++ b/fortuna/model_editor/probit.py @@ -17,12 +17,16 @@ InputData, Params, ) -from fortuna.utils.probit import probit_scaling +from fortuna.utils.probit import sequential_probit_scaling -class ProbitClassificationModelEditor(ModelEditor): 
+class ProbitModelEditor(ModelEditor): freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None top_k: Optional[int] = None + memory: Optional[int] = None + n_final_tokens: Optional[int] = None + init_log_var: float = -5.0 + stop_gradient: bool = False @nn.compact def __call__( @@ -34,8 +38,10 @@ def __call__( x: Any, has_aux: bool, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]: - log_var = self.param("log_var", nn.initializers.zeros, (1,)) - outputs = probit_scaling( + log_var = self.param( + "log_var", nn.initializers.constant(self.init_log_var), (1,) + ) + outputs = sequential_probit_scaling( apply_fn, model_params, x, @@ -43,5 +49,8 @@ def __call__( has_aux=has_aux, freeze_fun=self.freeze_fun, top_k=self.top_k, + memory=self.memory, + n_final_tokens=self.n_final_tokens, + stop_gradient=self.stop_gradient, ) return outputs diff --git a/fortuna/output_calib_model/base.py b/fortuna/output_calib_model/base.py index 97f017cc..62c9c5cc 100644 --- a/fortuna/output_calib_model/base.py +++ b/fortuna/output_calib_model/base.py @@ -10,19 +10,15 @@ from fortuna.output_calib_model.config.base import Config from fortuna.output_calib_model.loss import Loss -from fortuna.output_calib_model.output_calib_mixin import ( - WithOutputCalibCheckpointingMixin, -) -from fortuna.output_calib_model.output_calib_model_calibrator import ( +from fortuna.training.train_state_repository import TrainStateRepository +from fortuna.output_calib_model.output_calibrator.base import ( JittedOutputCalibModelCalibrator, MultiDeviceOutputCalibModelCalibrator, OutputCalibModelCalibrator, ) -from fortuna.output_calib_model.output_calib_state_repository import ( - OutputCalibStateRepository, -) from fortuna.output_calib_model.state import OutputCalibState from fortuna.output_calibrator.output_calib_manager.state import OutputCalibManagerState +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.typing import ( Array, Outputs, @@ -34,7 +30,7 @@ from fortuna.utils.random import RandomNumberGenerator -class OutputCalibModel(WithOutputCalibCheckpointingMixin, abc.ABC): +class OutputCalibModel(WithCheckpointingMixin, abc.ABC): """ Abstract calibration model class. """ @@ -90,7 +86,7 @@ def _calibrate( early_stopping_patience=config.monitor.early_stopping_patience, ) - if config.checkpointer.restore_checkpoint_path is None: + if config.checkpointer.restore_checkpoint_dir is None: state = OutputCalibManagerState.init_from_dict( d=FrozenDict( output_calibrator=self.output_calib_manager.init( @@ -105,7 +101,7 @@ def _calibrate( ) else: state = self.restore_checkpoint( - config.checkpointer.restore_checkpoint_path, + config.checkpointer.restore_checkpoint_dir, optimizer=config.optimizer.method, ) @@ -123,7 +119,7 @@ def _calibrate( verbose=config.monitor.verbose, ) - self.predictive.state = OutputCalibStateRepository( + self.predictive.state = TrainStateRepository( config.checkpointer.save_checkpoint_dir if config.checkpointer.dump_state is True else None @@ -133,35 +129,33 @@ def _calibrate( ) return status - def load_state(self, checkpoint_path: Path) -> None: + def load_state(self, checkpoint_dir: Path) -> None: """ Load a calibration state from a checkpoint path. The checkpoint must be compatible with the calibration model. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to a checkpoint file or directory to restore. 
""" try: - self.restore_checkpoint(checkpoint_path) + self.restore_checkpoint(checkpoint_dir) except ValueError: raise ValueError( - f"No checkpoint was found in `checkpoint_path={checkpoint_path}`." + f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`." ) self.predictive.state = OutputCalibStateRepository( - checkpoint_dir=checkpoint_path + checkpoint_dir=checkpoint_dir ) - def save_state( - self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the calibration state as a checkpoint. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to file or directory where to save the current state. keep_top_n_checkpoints : int Number of past checkpoint files to keep. @@ -172,6 +166,6 @@ def save_state( ) return self.predictive.state.put( self.predictive.state.get(), - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, keep=keep_top_n_checkpoints, ) diff --git a/fortuna/output_calib_model/config/checkpointer.py b/fortuna/output_calib_model/config/checkpointer.py index 1a63b0a7..63946ddf 100644 --- a/fortuna/output_calib_model/config/checkpointer.py +++ b/fortuna/output_calib_model/config/checkpointer.py @@ -7,7 +7,7 @@ class Checkpointer: def __init__( self, save_checkpoint_dir: Optional[Path] = None, - restore_checkpoint_path: Optional[Path] = None, + restore_checkpoint_dir: Optional[Path] = None, save_every_n_steps: Optional[int] = None, keep_top_n_checkpoints: Optional[int] = 2, dump_state: bool = False, @@ -19,7 +19,7 @@ def __init__( ---------- save_checkpoint_dir: Optional[Path] = None Save directory location. - restore_checkpoint_path: Optional[Path] + restore_checkpoint_dir: Optional[Path] Path to checkpoint file or directory to restore. save_every_n_steps: int Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no @@ -32,6 +32,6 @@ def __init__( """ self.save_checkpoint_dir = save_checkpoint_dir self.save_every_n_steps = save_every_n_steps - self.restore_checkpoint_path = restore_checkpoint_path + self.restore_checkpoint_dir = restore_checkpoint_dir self.keep_top_n_checkpoints = keep_top_n_checkpoints self.dump_state = dump_state diff --git a/fortuna/output_calib_model/output_calib_mixin.py b/fortuna/output_calib_model/output_calib_mixin.py deleted file mode 100644 index a0551947..00000000 --- a/fortuna/output_calib_model/output_calib_mixin.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -from typing import Optional - -from flax.training import checkpoints - -from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.mixin import WithCheckpointingMixin -from fortuna.typing import ( - OptaxOptimizer, - Path, -) - - -class WithOutputCalibCheckpointingMixin(WithCheckpointingMixin): - def restore_checkpoint( - self, - restore_checkpoint_path: Path, - optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, - ) -> OutputCalibState: - if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile( - restore_checkpoint_path - ): - raise ValueError( - f"`restore_checkpoint_path={restore_checkpoint_path}` was not found." - ) - d = checkpoints.restore_checkpoint( - ckpt_dir=str(restore_checkpoint_path), - target=None, - step=None, - prefix=prefix, - parallel=True, - ) - if d is None: - raise ValueError( - f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." 
- ) - - return OutputCalibState.init_from_dict(d, optimizer, **kwargs) diff --git a/fortuna/output_calib_model/output_calib_state_repository.py b/fortuna/output_calib_model/output_calib_state_repository.py deleted file mode 100644 index 1612e88d..00000000 --- a/fortuna/output_calib_model/output_calib_state_repository.py +++ /dev/null @@ -1,10 +0,0 @@ -from fortuna.output_calib_model.output_calib_mixin import ( - WithOutputCalibCheckpointingMixin, -) -from fortuna.training.train_state_repository import TrainStateRepository - - -class OutputCalibStateRepository( - WithOutputCalibCheckpointingMixin, TrainStateRepository -): - pass diff --git a/fortuna/output_calib_model/output_calibrator/__init__.py b/fortuna/output_calib_model/output_calibrator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/output_calib_model/output_calib_model_calibrator.py b/fortuna/output_calib_model/output_calibrator/base.py similarity index 98% rename from fortuna/output_calib_model/output_calib_model_calibrator.py rename to fortuna/output_calib_model/output_calibrator/base.py index d18ea327..4ebaba36 100644 --- a/fortuna/output_calib_model/output_calib_model_calibrator.py +++ b/fortuna/output_calib_model/output_calibrator/base.py @@ -31,11 +31,9 @@ TargetsLoader, ) from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.mixin import ( - InputValidatorMixin, - WithCheckpointingMixin, - WithEarlyStoppingMixin, -) +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin +from fortuna.training.mixins.input_validator import InputValidatorMixin from fortuna.typing import ( Array, CalibMutable, @@ -442,7 +440,7 @@ def val_epoch_end( ) # early stopping improved = self.early_stopping_update(val_losses_and_metrics_current_epoch) - if improved and self.save_checkpoint_dir: + if improved and self.save_checkpoint_dir is not None: self.save_checkpoint(state, self.save_checkpoint_dir, force_save=True) return val_losses_and_metrics_current_epoch @@ -567,7 +565,7 @@ def save_checkpoint( save_checkpoint_dir: Path, keep: int = 1, force_save: bool = False, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: state = self.sync_mutable(state) state = jax.device_get(tree_map(lambda x: x[0], state)) diff --git a/fortuna/output_calibrator/output_calib_manager/base.py b/fortuna/output_calibrator/output_calib_manager/base.py index fcc4b50d..2ebff850 100644 --- a/fortuna/output_calibrator/output_calib_manager/base.py +++ b/fortuna/output_calibrator/output_calib_manager/base.py @@ -6,10 +6,10 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp +from optax._src.base import PyTree from fortuna.typing import ( Array, @@ -106,7 +106,9 @@ def init( rng, params_key, dropout_key = random.split(rng, 3) rngs = {"params": params_key, "dropout": dropout_key} return ( - self.output_calibrator.init(rngs, jnp.zeros((1, output_dim)), **kwargs) + FrozenDict( + self.output_calibrator.init(rngs, jnp.zeros((1, output_dim)), **kwargs) + ) if self.output_calibrator is not None else None ) diff --git a/fortuna/partitioner/__init__.py b/fortuna/partitioner/__init__.py new file mode 100644 index 00000000..d96500f1 --- /dev/null +++ b/fortuna/partitioner/__init__.py @@ -0,0 +1 @@ +from fortuna.partitioner.base import Partitioner diff --git a/fortuna/partitioner/base.py 
b/fortuna/partitioner/base.py new file mode 100644 index 00000000..0ed59269 --- /dev/null +++ b/fortuna/partitioner/base.py @@ -0,0 +1,27 @@ +from typing import ( + Dict, + Optional, + Tuple, +) + +from jax.sharding import PartitionSpec + +from fortuna.utils.mesh import get_mesh +from fortuna.utils.port import is_port_in_use + + +class Partitioner: + def __init__( + self, + axes_dims: Optional[Dict[str, int]] = None, + rules: Optional[Dict[str, Tuple[str, ...]]] = None, + ): + if axes_dims is None: + axes_dims = {"dp": 1, "fsdp": 1, "mp": -1} + if rules is None: + rules = {} + self.specs = { + k: PartitionSpec(*v) if v is not None else PartitionSpec(None) + for k, v in rules.items() + } + self.mesh = get_mesh(axes_dims) diff --git a/fortuna/partitioner/partition_manager/__init__.py b/fortuna/partitioner/partition_manager/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/partitioner/partition_manager/base.py b/fortuna/partitioner/partition_manager/base.py new file mode 100644 index 00000000..46a2c751 --- /dev/null +++ b/fortuna/partitioner/partition_manager/base.py @@ -0,0 +1,78 @@ +from typing import ( + Any, + Callable, + List, + Optional, +) + +from jax import ( + device_put, + eval_shape, + random, +) +from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import pjit +from jax.sharding import ( + NamedSharding, + PartitionSpec, +) +from jax.tree_util import ( + tree_map, + tree_map_with_path, +) + +from fortuna.partitioner.base import Partitioner +from fortuna.training.train_state import TrainState +from fortuna.utils.partition.base import match_partition_specs +from fortuna.utils.random import WithRNG + + +class PartitionManager(WithRNG): + def __init__(self, partitioner: Partitioner): + self.partitioner = partitioner + self._shapes_dtypes = None + self._shardings = None + + @property + def shapes_dtypes(self): + return self._shapes_dtypes + + @shapes_dtypes.setter + def shapes_dtypes(self, shapes_dtypes: TrainState): + self._shapes_dtypes = shapes_dtypes + partitions = match_partition_specs(self.partitioner.specs, self._shapes_dtypes) + self._shardings = tree_map( + lambda p: NamedSharding(mesh=self.partitioner.mesh, spec=p), partitions + ) + + @property + def shardings(self): + return self._shardings + + @shardings.setter + def shardings(self, shardings: Optional[TrainState]): + self._shardings = shardings + + def init_sharded_state(self, init_state_fn: Callable[[Any], TrainState], *args): + self.shapes_dtypes = eval_shape(init_state_fn, random.PRNGKey(0)) + + with self.partitioner.mesh: + return pjit( + init_state_fn, + in_shardings=PartitionSpec(), + out_shardings=self.shardings, + )(*args) + + def reshard( + self, state: TrainState, exclude: Optional[List[str]] = None + ) -> TrainState: + if self.shardings is not None: + if exclude is None: + exclude = [] + return tree_map_with_path( + lambda p, _v, s: device_put(_v, s) + if _v is not None and p[0].name not in exclude + else _v, + state, + self.shardings, + ) diff --git a/fortuna/prob_model/__init__.py b/fortuna/prob_model/__init__.py index fa19fc5c..d94c71ae 100644 --- a/fortuna/prob_model/__init__.py +++ b/fortuna/prob_model/__init__.py @@ -9,6 +9,7 @@ from fortuna.prob_model.fit_config.monitor import FitMonitor from fortuna.prob_model.fit_config.optimizer import FitOptimizer from fortuna.prob_model.fit_config.processor import FitProcessor +from fortuna.prob_model.fit_config.hyperparameters import FitHyperparameters from 
fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_approximator import ( DeepEnsemblePosteriorApproximator, ) diff --git a/fortuna/prob_model/base.py b/fortuna/prob_model/base.py index 9f9f4e94..0f01e3d4 100644 --- a/fortuna/prob_model/base.py +++ b/fortuna/prob_model/base.py @@ -6,17 +6,17 @@ Optional, ) -import jax +from jax import eval_shape import jax.numpy as jnp from fortuna.data.loader import DataLoader from fortuna.output_calib_model.state import OutputCalibState +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.calib_config.base import CalibConfig from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.prob_model_calibrator import ( - JittedProbModelOutputCalibrator, - MultiDeviceProbModelOutputCalibrator, ProbModelOutputCalibrator, + ShardedProbModelOutputCalibrator, ) from fortuna.typing import ( Array, @@ -24,7 +24,6 @@ Status, ) from fortuna.utils.data import check_data_loader_is_not_random -from fortuna.utils.device import select_trainer_given_devices from fortuna.utils.random import RandomNumberGenerator @@ -46,6 +45,7 @@ def __set_rng(self): self.joint.rng = self.rng self.posterior.rng = self.rng self.predictive.rng = self.rng + self.partition_manager.rng = self.rng def train( self, @@ -136,7 +136,7 @@ def _calibrate( "Pre-compute ensemble of outputs on the calibration data loader." ) - distribute = jax.local_devices()[0].platform != "cpu" + shard = not calib_config.processor.disable_jit ( calib_ensemble_outputs_loader, @@ -145,7 +145,7 @@ def _calibrate( inputs_loader=calib_data_loader.to_inputs_loader(), n_output_samples=calib_config.processor.n_posterior_samples, return_size=True, - distribute=distribute, + shard=shard, ) if calib_config.monitor.verbose: logging.info( @@ -156,22 +156,39 @@ def _calibrate( inputs_loader=val_data_loader.to_inputs_loader(), n_output_samples=calib_config.processor.n_posterior_samples, return_size=True, - distribute=distribute, + shard=shard, ) if val_data_loader is not None else (None, None) ) - trainer_cls = select_trainer_given_devices( - devices=calib_config.processor.devices, - base_trainer_cls=ProbModelOutputCalibrator, - jitted_trainer_cls=JittedProbModelOutputCalibrator, - multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator, - disable_jit=calib_config.processor.disable_jit, + output_calib_partition_manager = PartitionManager( + partitioner=self.partition_manager.partitioner + ) + + if calib_config.checkpointer.restore_checkpoint_dir is None: + calib_dict = self.posterior.state.extract_calib_keys() + state = OutputCalibState.init( + params=calib_dict["calib_params"], + mutable=calib_dict["calib_mutable"], + optimizer=calib_config.optimizer.method, + ) + else: + state = self.posterior.state.restore_checkpoint( + calib_config.checkpointer.restore_checkpoint_dir, + optimizer=calib_config.optimizer.method, + ) + output_calib_partition_manager.shapes_dtypes = eval_shape(lambda: state) + + trainer_cls = ( + ShardedProbModelOutputCalibrator + if not calib_config.processor.disable_jit + else ProbModelOutputCalibrator ) calibrator = trainer_cls( calib_outputs_loader=calib_ensemble_outputs_loader, + partition_manager=output_calib_partition_manager, val_outputs_loader=val_ensemble_outputs_loader, predict_fn=self.prob_output_layer.predict, uncertainty_fn=uncertainty_fn, @@ -185,20 +202,6 @@ def _calibrate( early_stopping_patience=calib_config.monitor.early_stopping_patience, ) - if calib_config.checkpointer.restore_checkpoint_path is None: - calib_dict = 
self.posterior.state.extract_calib_keys() - - state = OutputCalibState.init( - params=calib_dict["calib_params"], - mutable=calib_dict["calib_mutable"], - optimizer=calib_config.optimizer.method, - ) - else: - state = self.posterior.restore_checkpoint( - calib_config.checkpointer.restore_checkpoint_path, - optimizer=calib_config.optimizer.method, - ) - if calib_config.monitor.verbose: logging.info("Start calibration.") state, status = calibrator.train( @@ -225,7 +228,7 @@ def _calibrate( if calib_config.monitor.verbose: logging.info("Dump state to disk.") self.save_state( - checkpoint_path=calib_config.checkpointer.save_checkpoint_dir + checkpoint_dir=str(Path(calib_config.checkpointer.save_checkpoint_dir) / "calibrated") ) if calib_config.monitor.verbose: @@ -233,29 +236,42 @@ def _calibrate( return status - def load_state(self, checkpoint_path: Path) -> None: + def load_state( + self, + checkpoint_dir: Path, + keep_top_n_checkpoints: int = 2, + checkpoint_type: str = "last" + ) -> None: """ Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the probabilistic model. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to a checkpoint file or directory to restore. + keep_top_n_checkpoints : int + Number of past checkpoint files to keep. + checkpoint_type: str + Which checkpoint type to pass to the state. + There are two possible options: + + - "last": this is the state obtained at the end of training. + - "best": this is the best checkpoint with respect to the metric monitored by early stopping. Notice that + this might be available only if validation data is provided, and both checkpoint saving and early + stopping are enabled. """ - return self.posterior.load_state(checkpoint_path) + return self.posterior.load_state(checkpoint_dir, keep_top_n_checkpoints=keep_top_n_checkpoints, checkpoint_type=checkpoint_type) - def save_state( - self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the posterior distribution state as a checkpoint. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to file or directory where to save the current state. keep_top_n_checkpoints : int Number of past checkpoint files to keep. """ - return self.posterior.save_state(checkpoint_path, keep_top_n_checkpoints) + return self.posterior.save_state(checkpoint_dir, keep_top_n_checkpoints) diff --git a/fortuna/prob_model/calib_config/checkpointer.py b/fortuna/prob_model/calib_config/checkpointer.py index 8ce9c9c3..ee50b54c 100644 --- a/fortuna/prob_model/calib_config/checkpointer.py +++ b/fortuna/prob_model/calib_config/checkpointer.py @@ -7,7 +7,7 @@ class CalibCheckpointer: def __init__( self, save_checkpoint_dir: Optional[Path] = None, - restore_checkpoint_path: Optional[Path] = None, + restore_checkpoint_dir: Optional[Path] = None, save_every_n_steps: Optional[int] = None, keep_top_n_checkpoints: Optional[int] = 2, dump_state: bool = False, @@ -19,7 +19,7 @@ def __init__( ---------- save_checkpoint_dir: Optional[Path] = None Save directory location. - restore_checkpoint_path: Optional[Path] + restore_checkpoint_dir: Optional[Path] Path to checkpoint file or directory to restore. save_every_n_steps: int Number of training steps between checkpoints. 
To disable, set `every_n_train_steps` to None or 0 (no @@ -32,6 +32,6 @@ def __init__( """ self.save_checkpoint_dir = save_checkpoint_dir self.save_every_n_steps = save_every_n_steps - self.restore_checkpoint_path = restore_checkpoint_path + self.restore_checkpoint_dir = restore_checkpoint_dir self.keep_top_n_checkpoints = keep_top_n_checkpoints self.dump_state = dump_state diff --git a/fortuna/prob_model/classification.py b/fortuna/prob_model/classification.py index 8f157256..2ecc4296 100644 --- a/fortuna/prob_model/classification.py +++ b/fortuna/prob_model/classification.py @@ -20,6 +20,8 @@ from fortuna.model_editor.base import ModelEditor from fortuna.output_calibrator.classification import ClassificationTemperatureScaler from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager +from fortuna.partitioner.base import Partitioner +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.base import ProbModel from fortuna.prob_model.calib_config.base import CalibConfig from fortuna.prob_model.fit_config.base import FitConfig @@ -53,6 +55,7 @@ def __init__( posterior_approximator: PosteriorApproximator = SWAGPosteriorApproximator(), output_calibrator: Optional[nn.Module] = ClassificationTemperatureScaler(), model_editor: Optional[ModelEditor] = None, + partitioner: Partitioner = Partitioner(), seed: int = 0, ): r""" @@ -77,6 +80,8 @@ def __init__( calibration parameters. model_editor : ModelEditor A model_editor objects. It takes the forward pass and transforms the outputs. + partitioner : Partitioner + A partitioning object for data, fully sharded data model parallelization. seed: int A random seed. @@ -130,10 +135,14 @@ def __init__( self.model_manager, self.prob_output_layer, self.output_calib_manager ) self.joint = Joint(self.prior, self.likelihood) - + self.partition_manager = PartitionManager(partitioner) self.posterior = getattr( PosteriorApproximations, posterior_approximator.__str__() - ).value(joint=self.joint, posterior_approximator=posterior_approximator) + ).value( + joint=self.joint, + posterior_approximator=posterior_approximator, + partition_manager=self.partition_manager, + ) self.predictive = ClassificationPredictive(self.posterior) super().__init__(seed=seed) @@ -250,7 +259,7 @@ def train( calib_config: CalibConfig = CalibConfig(), **fit_kwargs, ) -> Dict[str, Status]: - self._check_output_dim(train_data_loader) + # self._check_output_dim(train_data_loader) return super().train( train_data_loader, val_data_loader, @@ -283,7 +292,7 @@ def calibrate( Status A calibration status object. It provides information about the calibration. 
""" - self._check_output_dim(calib_data_loader) + # self._check_output_dim(calib_data_loader) if val_data_loader is not None: self._check_output_dim(val_data_loader) return super()._calibrate( diff --git a/fortuna/prob_model/fit_config/checkpointer.py b/fortuna/prob_model/fit_config/checkpointer.py index 95ef2e9d..c4e810f1 100644 --- a/fortuna/prob_model/fit_config/checkpointer.py +++ b/fortuna/prob_model/fit_config/checkpointer.py @@ -7,11 +7,12 @@ class FitCheckpointer: def __init__( self, save_checkpoint_dir: Optional[Path] = None, - restore_checkpoint_path: Optional[Path] = None, + restore_checkpoint_dir: Optional[Path] = None, start_from_current_state: bool = False, save_every_n_steps: Optional[int] = None, keep_top_n_checkpoints: Optional[int] = 2, dump_state: bool = False, + checkpoint_type: str = "last", ): """ An object to configure saving and restoring of checkpoints during the posterior fitting. @@ -20,10 +21,10 @@ def __init__( ---------- save_checkpoint_dir: Optional[Path] Save directory location. - restore_checkpoint_path: Optional[Path] + restore_checkpoint_dir: Optional[Path] Path to checkpoint file or directory to restore. start_from_current_state: bool = False - If True, the optimization will start from the current state. If `restore_checkpoint_path` is given, then + If True, the optimization will start from the current state. If `restore_checkpoint_dir` is given, then `start_from_current_state` is ignored. save_every_n_steps: int Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no @@ -33,10 +34,26 @@ def __init__( dump_state: bool Dump the fitted posterior state as a checkpoint in `save_checkpoint_dir`. Any future call to the state will internally involve restoring it from memory. + checkpoint_type: str + Which checkpoint type to pass to the state. + There are two possible options: + + - "last": this is the state obtained at the end of training. + - "best": this is the best checkpoint with respect to the metric monitored by early stopping. Notice that + this might be available only if validation data is provided, and both checkpoint saving and early + stopping are enabled. """ self.save_checkpoint_dir = save_checkpoint_dir self.save_every_n_steps = save_every_n_steps - self.restore_checkpoint_path = restore_checkpoint_path + self.restore_checkpoint_dir = restore_checkpoint_dir self.start_from_current_state = start_from_current_state self.keep_top_n_checkpoints = keep_top_n_checkpoints self.dump_state = dump_state + + allowed_checkpoint_types = ["last", "best"] + if checkpoint_type not in allowed_checkpoint_types: + raise ValueError( + f"`checkpoint_type={checkpoint_type}` not recognised. " + f"Pleas select one of the following options: {allowed_checkpoint_types}." + ) + self.checkpoint_type = checkpoint_type diff --git a/fortuna/prob_model/joint/base.py b/fortuna/prob_model/joint/base.py index b113c548..3a34ec8a 100755 --- a/fortuna/prob_model/joint/base.py +++ b/fortuna/prob_model/joint/base.py @@ -6,6 +6,7 @@ ) from flax.core import FrozenDict +from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp @@ -154,7 +155,9 @@ def _batched_negative_log_joint_prob( return loss, aux return -outs - def init(self, input_shape: Shape, **kwargs) -> JointState: + def init( + self, input_shape: Shape, rng: Optional[PRNGKeyArray] = None, **kwargs + ) -> JointState: """ Initialize the state of the joint distribution. 
@@ -162,15 +165,19 @@ def init(self, input_shape: Shape, **kwargs) -> JointState: ---------- input_shape : Shape The shape of the input variable. + rng: Optional[PRNGKeyArray] + A random number generator key. Returns ------- A state of the joint distribution. """ + if rng is None: + rng = self.rng.get() + key1, key2 = random.split(rng) + oms = ModelManagerState.init_from_dict( - self.likelihood.model_manager.init( - input_shape, rng=self.rng.get(), **kwargs - ) + self.likelihood.model_manager.init(input_shape, rng=key1, **kwargs) ) inputs = get_inputs_from_shape(input_shape) outputs = self.likelihood.model_manager.apply( @@ -184,7 +191,7 @@ def init(self, input_shape: Shape, **kwargs) -> JointState: ocms = OutputCalibManagerState.init_from_dict( FrozenDict( output_calibrator=self.likelihood.output_calib_manager.init( - output_dim=output_dim + output_dim=output_dim, rng=key2 ) ) ) diff --git a/fortuna/prob_model/posterior/base.py b/fortuna/prob_model/posterior/base.py index 0676bb9f..65785fd0 100755 --- a/fortuna/prob_model/posterior/base.py +++ b/fortuna/prob_model/posterior/base.py @@ -1,5 +1,6 @@ import abc import logging +import pathlib from typing import ( Any, Dict, @@ -7,23 +8,26 @@ Tuple, Type, ) - +from fortuna.prob_model.posterior.map.map_state import MAPState from flax.core import FrozenDict from jax._src.prng import PRNGKeyArray +from orbax.checkpoint.checkpoint_manager import CheckpointManager from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint from fortuna.prob_model.joint.state import JointState -from fortuna.prob_model.posterior.posterior_mixin import WithPosteriorCheckpointingMixin from fortuna.prob_model.posterior.posterior_state_repository import ( PosteriorStateRepository, ) from fortuna.prob_model.posterior.state import PosteriorState from fortuna.typing import ( Path, + Shape, Status, ) +from fortuna.utils.checkpoint import get_checkpoint_manager from fortuna.utils.freeze import get_trainable_paths from fortuna.utils.nested_dicts import ( nested_get, @@ -46,10 +50,15 @@ def posterior_method_kwargs(self) -> Dict[str, Any]: return {} -class Posterior(WithRNG, WithPosteriorCheckpointingMixin): +class Posterior(WithRNG): state = None - def __init__(self, joint: Joint, posterior_approximator: PosteriorApproximator): + def __init__( + self, + joint: Joint, + posterior_approximator: PosteriorApproximator, + partition_manager: PartitionManager, + ): r""" Posterior distribution class. This refers to :math:`p(w|\mathcal{D}, \phi)`, where :math:`w` are the random model parameters, :math:`\mathcal{D}` is a training data set and :math:`\phi` are calibration parameters. @@ -60,35 +69,45 @@ def __init__(self, joint: Joint, posterior_approximator: PosteriorApproximator): A joint distribution object. posterior_approximator: PosteriorApproximator A posterior approximator. + partition_manager: PartitionManager + An object to manage partitions. 
""" super().__init__() self.joint = joint self.posterior_approximator = posterior_approximator + self.partition_manager = partition_manager def _restore_state_from_somewhere( self, fit_config: FitConfig, allowed_states: Optional[Tuple[Type[PosteriorState], ...]] = None, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, + _do_reshard: bool = True ) -> PosteriorState: - if fit_config.checkpointer.restore_checkpoint_path is not None: - state = self.restore_checkpoint( - restore_checkpoint_path=fit_config.checkpointer.restore_checkpoint_path, - optimizer=fit_config.optimizer.method, + if checkpoint_manager is not None: + repo = PosteriorStateRepository( + partition_manager=partition_manager, + checkpoint_manager=checkpoint_manager, ) - elif fit_config.checkpointer.start_from_current_state is not None: - state = self.state.get(optimizer=fit_config.optimizer.method) + state = repo.get(optimizer=fit_config.optimizer.method) + elif fit_config.checkpointer.start_from_current_state: + state = self.state.get(optimizer=fit_config.optimizer.method, _do_reshard=_do_reshard) if allowed_states is not None and not isinstance(state, allowed_states): raise ValueError( f"The type of the restored checkpoint must be within {allowed_states}. " - f"However, {fit_config.checkpointer.restore_checkpoint_path} pointed to a state " - f"with type {type(state)}." + f"However, the restored checkpoint has type {type(state)}." ) return state - def _init_joint_state(self, data_loader: DataLoader) -> JointState: - return self.joint.init(input_shape=data_loader.input_shape) + def _init_joint_state( + self, data_loader: Optional[DataLoader] = None, input_shape: Optional[Shape] = None, rng: Optional[PRNGKeyArray] = None + ) -> JointState: + if data_loader is None and input_shape is None: + raise ValueError("At least one between `data_loader` and `input_shape` must be provided.") + return self.joint.init(input_shape=input_shape or data_loader.input_shape, rng=rng) @staticmethod def _freeze_optimizer_in_state( @@ -165,33 +184,46 @@ def sample(self, rng: Optional[PRNGKeyArray] = None, *args, **kwargs) -> JointSt """ pass - def load_state(self, checkpoint_path: Path) -> None: + def load_state( + self, + checkpoint_dir: Path, + keep_top_n_checkpoints: int = 2, + checkpoint_type: str = "last" + ) -> None: """ Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the current probabilistic model. Parameters ---------- - checkpoint_path: Path + checkpoint_dir: Path Path to checkpoint file or directory to restore. + keep_top_n_checkpoints : int + Number of past checkpoint files to keep. + checkpoint_type: str + Which checkpoint type to pass to the state. + There are two possible options: + + - "last": this is the state obtained at the end of training. + - "best": this is the best checkpoint with respect to the metric monitored by early stopping. Notice that + this might be available only if validation data is provided, and both checkpoint saving and early + stopping are enabled. """ - try: - self.restore_checkpoint(checkpoint_path) - except ValueError: - raise ValueError( - f"No checkpoint was found in `checkpoint_path={checkpoint_path}`." 
- ) - self.state = PosteriorStateRepository(checkpoint_dir=checkpoint_path) + self.state = PosteriorStateRepository( + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager(checkpoint_dir=str(pathlib.Path(checkpoint_dir) / checkpoint_type), keep_top_n_checkpoints=keep_top_n_checkpoints) + ) + # currently, sharding is only supported with MAPState + if isinstance(self.state, MAPState): + self.partition_manager.shapes_dtypes = self.state.get_shapes_dtypes_checkpoint() - def save_state( - self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the state of the posterior distribution to a checkpoint directory. Parameters ---------- - checkpoint_path: Path + checkpoint_dir: Path Path to checkpoint file or directory to restore. keep_top_n_checkpoints: int Number of past checkpoint files to keep. @@ -201,9 +233,9 @@ def save_state( """No state available. You must first either fit the posterior distribution, or load a saved checkpoint.""" ) - return self.state.put( + self.state.put( self.state.get(), - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, keep=keep_top_n_checkpoints, ) @@ -216,10 +248,9 @@ def _check_fit_config(self, fit_config: FitConfig): "`save_checkpoint_dir` must be passed when `dump_state` is set to True." ) - @staticmethod - def _is_state_available_somewhere(fit_config: FitConfig) -> bool: + def _is_state_available_somewhere(self, fit_config: FitConfig) -> bool: return ( - fit_config.checkpointer.restore_checkpoint_path is not None + fit_config.checkpointer.restore_checkpoint_dir is not None or fit_config.checkpointer.start_from_current_state ) @@ -234,7 +265,7 @@ def _warn_frozen_params_start_from_random( logging.warning( "Parameters frozen via `fit_config.optimizer.freeze_fun` will not be updated. To start " "from sensible frozen parameters, you should configure " - "`fit_config.checkpointer.restore_checkpoint_path`, or " + "`fit_config.checkpointer.restore_checkpoint_dir`, or " "`fit_config.checkpointer.start_from_current_state`, or `map_fit_config`. " "Otherwise, " "a randomly initialized configuration of frozen parameters will be returned." 
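Note on the checkpointing API above: posterior state is now written and restored through an Orbax `CheckpointManager` rooted at a checkpoint directory, and `load_state` selects a sub-checkpoint via `checkpoint_type` ("last" or "best"). A minimal, hedged usage sketch follows; the directory name is a placeholder, and it assumes `prob_model` is an already fitted `ProbClassifier` whose `posterior` attribute exposes the methods defined above.

    # Sketch only, not part of this patch. `prob_model` is assumed to be a fitted
    # fortuna ProbClassifier; `ckpt_dir` is a placeholder path.
    ckpt_dir = "checkpoints/posterior"

    # Persist the current posterior state through the state repository.
    prob_model.posterior.save_state(
        checkpoint_dir=ckpt_dir,
        keep_top_n_checkpoints=1,
    )

    # Restore a checkpoint directory produced during training, e.g. one written by a
    # fit configured with `FitCheckpointer(save_checkpoint_dir=ckpt_dir, dump_state=True)`.
    # "last" is the end-of-training state; "best" is only available when validation
    # data, checkpoint saving and early stopping were all enabled.
    prob_model.posterior.load_state(
        checkpoint_dir=ckpt_dir,
        keep_top_n_checkpoints=2,
        checkpoint_type="last",
    )

Directories written during a fit with `FitCheckpointer(save_checkpoint_dir=...)` appear to follow the same `<dir>/<checkpoint_type>` layout (see the MAP and ADVI changes later in this patch), so such a directory can either be restored here or passed back through `restore_checkpoint_dir` on a subsequent fit.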
diff --git a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py b/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py index 6fe54e6f..12efda46 100755 --- a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py +++ b/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import os import pathlib from typing import ( List, @@ -9,15 +8,18 @@ Tuple, Type, ) - +from fortuna.utils.checkpoint import get_checkpoint_manager from flax.core import FrozenDict from jax import ( pure_callback, random, ) +from copy import deepcopy +from orbax.checkpoint import CheckpointManager from jax._src.prng import PRNGKeyArray - from fortuna.data.loader import DataLoader +from fortuna.prob_model.posterior.map.map_posterior import MAPPosterior +from fortuna.prob_model.posterior.map.map_approximator import MAPPosteriorApproximator from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint from fortuna.prob_model.joint.state import JointState @@ -26,27 +28,26 @@ from fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_approximator import ( DeepEnsemblePosteriorApproximator, ) +from fortuna.prob_model.posterior.state import PosteriorState from fortuna.prob_model.posterior.map.map_posterior import MAPState -from fortuna.prob_model.posterior.map.map_trainer import ( - JittedMAPTrainer, - MAPTrainer, - MultiDeviceMAPTrainer, -) +from fortuna.prob_model.posterior.map.map_trainer import ShardedMAPTrainer, MAPTrainer from fortuna.prob_model.posterior.posterior_multi_state_repository import ( PosteriorMultiStateRepository, ) +from fortuna.prob_model.posterior.posterior_state_repository import PosteriorStateRepository from fortuna.prob_model.posterior.run_preliminary_map import run_preliminary_map from fortuna.typing import ( Path, Status, ) from fortuna.utils.builtins import get_dynamic_scale_instance_from_model_dtype -from fortuna.utils.device import select_trainer_given_devices from fortuna.utils.freeze import get_trainable_paths from fortuna.utils.nested_dicts import ( nested_get, nested_set, ) +from fortuna.partitioner.partition_manager.base import PartitionManager + logger = logging.getLogger(__name__) @@ -56,6 +57,7 @@ def __init__( self, joint: Joint, posterior_approximator: DeepEnsemblePosteriorApproximator, + partition_manager: PartitionManager ): """ Deep ensemble approximate posterior class. @@ -67,7 +69,7 @@ def __init__( posterior_approximator: DeepEnsemble Deep ensemble posterior approximator. """ - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__(joint=joint, posterior_approximator=posterior_approximator, partition_manager=partition_manager) def __str__(self): return DEEP_ENSEMBLE_NAME @@ -80,101 +82,37 @@ def fit( map_fit_config: Optional[FitConfig] = None, **kwargs, ) -> List[Status]: - super()._checks_on_fit_start(fit_config, map_fit_config) - - status = dict() - - map_state = None - if map_fit_config is not None and fit_config.optimizer.freeze_fun is None: - logging.warning( - "It appears that you are trying to configure `map_fit_config`. " - "However, a preliminary run with MAP is supported only if " - "`fit_config.optimizer.freeze_fun` is given. " - "Since the latter was not given, `map_fit_config` will be ignored." 
+ def _fun(i: int): + fit_config_i = deepcopy(fit_config) + fit_config_i.checkpointer.save_checkpoint_dir = str(pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / str(i)) if fit_config.checkpointer.save_checkpoint_dir else None + fit_config_i.checkpointer.restore_checkpoint_dir = str(pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) / str(i)) if fit_config.checkpointer.restore_checkpoint_dir else None + map_posterior = MAPPosterior( + self.joint, posterior_approximator=MAPPosteriorApproximator(), partition_manager=self.partition_manager ) - elif not super()._is_state_available_somewhere( - fit_config - ) and super()._should_run_preliminary_map(fit_config, map_fit_config): - map_state, status["map"] = run_preliminary_map( - joint=self.joint, + map_posterior.rng = self.rng + if self.state is not None: + map_posterior.state = self.state.state[i] + + status = map_posterior.fit( + rng=map_posterior.rng.get(), train_data_loader=train_data_loader, val_data_loader=val_data_loader, - map_fit_config=map_fit_config, - rng=self.rng, + fit_config=fit_config_i, **kwargs, ) + return map_posterior.state, status - trainer_cls = select_trainer_given_devices( - devices=fit_config.processor.devices, - base_trainer_cls=MAPTrainer, - jitted_trainer_cls=JittedMAPTrainer, - multi_device_trainer_cls=MultiDeviceMAPTrainer, - disable_jit=fit_config.processor.disable_jit, - ) - - train_data_size = train_data_loader.size - val_data_size = val_data_loader.size if val_data_loader is not None else None - - def _fit(i): - if self._is_state_available_somewhere(fit_config): - _state = self._restore_state_from_somewhere( - i=i, - fit_config=fit_config, - allowed_states=(MAPState,), - ) - else: - _state = self._init_map_state(map_state, train_data_loader, fit_config) - - _state = self._freeze_optimizer_in_state(_state, fit_config) - - save_checkpoint_dir_i = ( - pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / str(i) - if fit_config.checkpointer.save_checkpoint_dir - else None - ) - trainer = trainer_cls( - predict_fn=self.joint.likelihood.prob_output_layer.predict, - save_checkpoint_dir=save_checkpoint_dir_i, - save_every_n_steps=fit_config.checkpointer.save_every_n_steps, - keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, - disable_training_metrics_computation=fit_config.monitor.disable_training_metrics_computation, - eval_every_n_epochs=fit_config.monitor.eval_every_n_epochs, - early_stopping_monitor=fit_config.monitor.early_stopping_monitor, - early_stopping_min_delta=fit_config.monitor.early_stopping_min_delta, - early_stopping_patience=fit_config.monitor.early_stopping_patience, - freeze_fun=fit_config.optimizer.freeze_fun, - ) - - return trainer.train( - rng=self.rng.get(), - state=_state, - loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, - training_dataset_size=train_data_size, - n_epochs=fit_config.optimizer.n_epochs, - metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, - validation_dataset_size=val_data_size, - verbose=fit_config.monitor.verbose, - callbacks=fit_config.callbacks, - max_grad_norm=fit_config.hyperparameters.max_grad_norm, - gradient_accumulation_steps=fit_config.hyperparameters.gradient_accumulation_steps, - ) - - if isinstance(self.state, PosteriorMultiStateRepository): - for i in range(self.posterior_approximator.ensemble_size): - self.state.state[i].checkpoint_dir = ( - pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / str(i) - if 
fit_config.checkpointer.save_checkpoint_dir is not None - and fit_config.checkpointer.dump_state - else None - ) - else: + if self.state is None: self.state = PosteriorMultiStateRepository( size=self.posterior_approximator.ensemble_size, - checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir - if fit_config.checkpointer.dump_state is True - else None, + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str( + pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) + ) if fit_config.checkpointer.save_checkpoint_dir is not None else None, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ), + checkpoint_type=fit_config.checkpointer.checkpoint_type ) status = [] @@ -182,10 +120,7 @@ def _fit(i): logging.info( f"Run {i+1} out of {self.posterior_approximator.ensemble_size}." ) - state, _status = _fit(i) - self.state.put( - state=state, i=i, keep=fit_config.checkpointer.keep_top_n_checkpoints - ) + self.state.state[i], _status = _fun(i) status.append(_status) logging.info("Fit completed.") return status @@ -205,21 +140,27 @@ def sample(self, rng: Optional[PRNGKeyArray] = None, **kwargs) -> JointState: calib_mutable=state.calib_mutable, ) - def load_state(self, checkpoint_dir: Path) -> None: - try: - self.restore_checkpoint(pathlib.Path(checkpoint_dir) / "0") - except ValueError: - raise ValueError( - f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`." - ) + def load_state( + self, + checkpoint_dir: Path, + keep_top_n_checkpoints: int = 2, + checkpoint_type: str = "last" + ) -> None: self.state = PosteriorMultiStateRepository( size=self.posterior_approximator.ensemble_size, - checkpoint_dir=checkpoint_dir, + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager(checkpoint_dir=checkpoint_dir), + checkpoint_type=checkpoint_type ) def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: + if self.state is None: + raise ValueError( + """No state available. 
You must first either fit the posterior distribution, or load a + saved checkpoint.""" + ) for i in range(self.posterior_approximator.ensemble_size): - self.state.put(state=self.state.get(i), i=i, keep=keep_top_n_checkpoints) + self.state.put(state=self.state.get(i), i=i, checkpoint_dir=checkpoint_dir, keep=keep_top_n_checkpoints) def _init_map_state( self, state: Optional[MAPState], data_loader: DataLoader, fit_config: FitConfig @@ -265,25 +206,26 @@ def _restore_state_from_somewhere( self, i: int, fit_config: FitConfig, - allowed_states: Optional[Tuple[Type[MAPState], ...]] = None, + allowed_states: Optional[Tuple[Type[PosteriorState], ...]] = None, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, ) -> MAPState: - if fit_config.checkpointer.restore_checkpoint_path is not None: - restore_checkpoint_path = pathlib.Path( - fit_config.checkpointer.restore_checkpoint_path - ) / str(i) - state = self.restore_checkpoint( - restore_checkpoint_path=restore_checkpoint_path, - optimizer=fit_config.optimizer.method, + if checkpoint_manager is not None: + repo = PosteriorStateRepository( + partition_manager=partition_manager, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str(pathlib.Path(getattr(checkpoint_manager, "directory")) / fit_config.checkpointer.checkpoint_type / str(i)), + keep_top_n_checkpoints=checkpoint_manager._options.max_to_keep if checkpoint_manager is not None else None + ), ) - elif fit_config.checkpointer.start_from_current_state is not None: + state = repo.get(optimizer=fit_config.optimizer.method) + elif fit_config.checkpointer.start_from_current_state: state = self.state.get(i=i, optimizer=fit_config.optimizer.method) if allowed_states is not None and not isinstance(state, allowed_states): raise ValueError( f"The type of the restored checkpoint must be within {allowed_states}. " - f"However, {fit_config.checkpointer.restore_checkpoint_path} pointed to a state " - f"with type {type(state)}." + f"However, the restored checkpoint has type {type(state)}." 
) - self._check_state(state) return state diff --git a/fortuna/prob_model/posterior/laplace/laplace_posterior.py b/fortuna/prob_model/posterior/laplace/laplace_posterior.py index ef852e6c..f50e0658 100755 --- a/fortuna/prob_model/posterior/laplace/laplace_posterior.py +++ b/fortuna/prob_model/posterior/laplace/laplace_posterior.py @@ -8,27 +8,26 @@ Tuple, Union, ) - +from fortuna.data.loader.base import ShardedPrefetchedLoader from flax.core import FrozenDict -from flax.training.common_utils import ( - shard, - shard_prng_key, -) -import jax +from jax.sharding import PartitionSpec +from jax.experimental.pjit import pjit from jax import ( devices, hessian, jit, lax, pmap, + random, vjp, ) +import jax.scipy as jsp from jax._src.prng import PRNGKeyArray from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.tree_util import tree_map import tqdm - +import pathlib from fortuna.data.loader import ( DataLoader, DeviceDimensionAugmentedLoader, @@ -59,6 +58,8 @@ Params, Status, ) +import pathlib +from fortuna.utils.checkpoint import get_checkpoint_manager from fortuna.utils.freeze import get_trainable_paths from fortuna.utils.nested_dicts import ( nested_get, @@ -67,6 +68,7 @@ ) from fortuna.utils.random import generate_random_normal_like_tree from fortuna.utils.strings import decode_encoded_tuple_of_lists_of_strings_to_array +from fortuna.partitioner.partition_manager.base import PartitionManager class LaplacePosterior(Posterior): @@ -74,6 +76,7 @@ def __init__( self, joint: Joint, posterior_approximator: LaplacePosteriorApproximator, + partition_manager: PartitionManager, ): """ Laplace approximation posterior class. @@ -85,7 +88,7 @@ def __init__( posterior_approximator: LaplacePosteriorApproximator A Laplace posterior approximator. """ - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__(joint=joint, posterior_approximator=posterior_approximator, partition_manager=partition_manager) if type(joint.prior) not in [DiagonalGaussianPrior, IsotropicGaussianPrior]: raise ValueError( """The Laplace posterior_approximation is not supported for this model. The prior distribution must be one of the @@ -234,14 +237,27 @@ def fit( status = dict() + checkpoint_restorer = ( + get_checkpoint_manager( + str( + pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.restore_checkpoint_dir is not None + else None + ) + if super()._is_state_available_somewhere(fit_config): state = super()._restore_state_from_somewhere( - fit_config=fit_config, allowed_states=(MAPState, LaplaceState) + fit_config=fit_config, allowed_states=(MAPState, LaplaceState), checkpoint_manager=checkpoint_restorer, _do_reshard=False ) elif super()._should_run_preliminary_map(fit_config, map_fit_config): state, status["map"] = run_preliminary_map( joint=self.joint, + partition_manager=self.partition_manager, train_data_loader=train_data_loader, val_data_loader=val_data_loader, map_fit_config=map_fit_config, @@ -252,7 +268,7 @@ def fit( raise ValueError( "The Laplace approximation must start from a preliminary run of MAP or an existing " "checkpoint or state. Please configure `map_fit_config`, or " - "`fit_config.checkpointer.restore_checkpoint_path`, " + "`fit_config.checkpointer.restore_checkpoint_dir`, " "or `fit_config.checkpointer.start_from_current_state`." 
) @@ -282,31 +298,33 @@ def fit( which_params=which_params, ) - if fit_config.checkpointer.save_checkpoint_dir: - self.save_checkpoint( - state, - save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, - keep=fit_config.checkpointer.keep_top_n_checkpoints, - force_save=True, - ) - self.state = PosteriorStateRepository( - fit_config.checkpointer.save_checkpoint_dir + partition_manager=None, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str( + pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.save_checkpoint_dir is not None + and fit_config.checkpointer.dump_state + else None, ) - self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) + self.state.replace(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) logging.info("Fit completed.") if ( val_data_loader is not None and self.posterior_approximator.tune_prior_log_variance ): - logging.info("Tuning the prior log-variance now") + logging.info("Tuning the prior log-variance now.") opt_prior_log_var = self.prior_log_variance_tuning( val_data_loader=val_data_loader, n_posterior_samples=5, - distribute=fit_config.processor.devices == -1, + shard=fit_config.processor.devices == -1, ) state = state.replace(prior_log_var=opt_prior_log_var) - self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) + self.state.replace(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) logging.info(f"Best prior log-variance found: {opt_prior_log_var}") return status @@ -317,7 +335,7 @@ def sample( ) -> JointState: if rng is None: rng = self.rng.get() - state: LaplaceState = self.state.get() + state = self.state.get() if kwargs.get("prior_log_var") is not None: state = state.replace(prior_log_var=kwargs.get("prior_log_var")) @@ -326,9 +344,9 @@ def sample( state._encoded_which_params ) mean, hess_lik_diag = nested_unpair( - state.params.unfreeze(), - which_params, - ("mean", "hess_lik_diag"), + d=state.params.unfreeze(), + key_paths=tuple(which_params), + labels=("mean", "hess_lik_diag"), ) std = self._compute_std( prior_log_var=state.prior_log_var, hess_lik_diag=hess_lik_diag @@ -337,7 +355,7 @@ def sample( noise = generate_random_normal_like_tree(rng, std) params = nested_set( d=mean, - key_paths=which_params, + key_paths=tuple(which_params), objs=tuple( [ tree_map( @@ -392,7 +410,7 @@ def _init_map_state( params=FrozenDict( nested_unpair( d=state.params.unfreeze(), - key_paths=which_params, + key_paths=tuple(which_params), labels=("mean", "hess_lik_diag"), )[0] ) @@ -423,30 +441,73 @@ def _batched_log_prob( prior_log_var: float, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, **kwargs, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]: - import jax.random as random - import jax.scipy as jsp - if rng is None: rng = self.rng.get() keys = random.split(rng, n_posterior_samples) - def _lik_log_batched_prob(key): - sample = self.sample(inputs=batch[0], rng=key, prior_log_var=prior_log_var) + def _lik_log_batched_prob(params, mutable, calib_params, calib_mutable): return self.joint.likelihood._batched_log_prob( - sample.params, + params, batch, - mutable=sample.mutable, - calib_params=sample.calib_params, - calib_mutable=sample.calib_mutable, + mutable=mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, **kwargs, ) + if shard and self.partition_manager.shardings is not None: 
+ _lik_log_batched_prob = pjit( + _lik_log_batched_prob, + in_shardings=( + self.partition_manager.shardings.params, + self.partition_manager.shardings.mutable, + self.partition_manager.shardings.calib_params, + self.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("dp", "fsdp")), + ) + else: + _lik_log_batched_prob = jit(_lik_log_batched_prob) + + def _fun(key): + sample = self.sample(inputs=batch[0], rng=key, prior_log_var=prior_log_var) + with self.partition_manager.partitioner.mesh: + return _lik_log_batched_prob( + sample.params, + sample.mutable, + sample.calib_params, + sample.calib_mutable, + ) + return jsp.special.logsumexp( - lax.map(_lik_log_batched_prob, keys), axis=0 + jnp.stack(list(map(_fun, keys))), axis=0 ) - jnp.log(n_posterior_samples) + def _log_prob( + self, + data_loader: DataLoader, + prior_log_var: float, + n_posterior_samples: int = 30, + rng: Optional[PRNGKeyArray] = None, + shard: bool = True, + **kwargs, + ) -> jnp.ndarray: + if rng is None: + rng = self.rng.get() + + if shard and self.partition_manager.shardings is not None: + data_loader = ShardedPrefetchedLoader( + loader=data_loader, partition_manager=self.partition_manager + ) + + def fun2(_data): + return self._batched_log_prob(_data, prior_log_var, n_posterior_samples, rng, shard, **kwargs) + + return jnp.concatenate([fun2(data) for data in data_loader], 0) + def prior_log_variance_tuning( self, val_data_loader: DataLoader, @@ -455,7 +516,7 @@ def prior_log_variance_tuning( min_prior_log_var: float = -3, max_prior_log_var: float = 3, grid_size: int = 20, - distribute: bool = False, + shard: bool = False, ) -> jnp.ndarray: if mode == "cv": return self._prior_log_variance_tuning_cv( @@ -464,7 +525,7 @@ def prior_log_variance_tuning( min_prior_log_var, max_prior_log_var, grid_size, - distribute, + shard, ) elif mode == "marginal_lik": raise NotImplementedError( @@ -480,30 +541,19 @@ def _prior_log_variance_tuning_cv( min_prior_log_var: float, max_prior_log_var: float, grid_size: int, - distribute: bool, + shard: bool, ) -> jnp.ndarray: best = None - candidates = list( - jnp.linspace(min_prior_log_var, max_prior_log_var, grid_size) - ) + [jnp.array(self.joint.prior.log_var)] - if distribute: - rng = shard_prng_key(jax.random.PRNGKey(0)) - val_data_loader = DeviceDimensionAugmentedLoader(val_data_loader) - candidates = [shard(c) for c in candidates] - fn = pmap(self._batched_log_prob, static_broadcasted_argnums=(2,)) - else: - fn = jit(self._batched_log_prob, static_argnums=(2,)) + candidates = jnp.concatenate((jnp.linspace(min_prior_log_var, max_prior_log_var, grid_size), jnp.array([self.joint.prior.log_var]))) for lpv in tqdm.tqdm(candidates, desc="Tuning prior log-var"): neg_log_prob = -jnp.sum( - jnp.concatenate( - [ - self.joint.likelihood._unshard_array( - fn(batch, lpv, n_posterior_samples, rng) - ) - for batch in val_data_loader - ], - 0, + self._log_prob( + data_loader=val_data_loader, + prior_log_var=lpv, + n_posterior_samples=n_posterior_samples, + rng=self.rng.get(), + shard=shard ) ) if best is None or neg_log_prob < best[-1]: @@ -511,3 +561,100 @@ def _prior_log_variance_tuning_cv( opt_prior_log_var = best[0].reshape() return opt_prior_log_var + + # def _batched_log_prob( + # self, + # batch, + # prior_log_var: float, + # n_posterior_samples: int = 30, + # rng: Optional[PRNGKeyArray] = None, + # distribute: bool = True, + # **kwargs, + # ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]: + # import jax.random as random + # import jax.scipy as jsp + # + # if rng is 
None: + # rng = self.rng.get() + # keys = random.split(rng, n_posterior_samples) + # + # def _eval_batched_log_prob(params, mutable, calib_params, calib_mutable): + # return self.joint.likelihood._batched_log_prob( + # params, + # batch, + # mutable=mutable, + # calib_params=calib_params, + # calib_mutable=calib_mutable, + # **kwargs, + # ) + # + # if distribute: + # _eval_batched_log_prob = pmap(_eval_batched_log_prob) + # else: + # _eval_batched_log_prob = jit(_eval_batched_log_prob) + # + # def fun(key): + # sample = self.sample(inputs=batch[0], rng=key, prior_log_var=prior_log_var) + # return _eval_batched_log_prob(sample.params, sample.mutable, sample.calib_params, sample.calib_mutable) + # + # return jsp.special.logsumexp( + # lax.map(fun, keys), axis=0 + # ) - jnp.log(n_posterior_samples) + # + # def prior_log_variance_tuning( + # self, + # val_data_loader: DataLoader, + # n_posterior_samples: int = 10, + # mode: str = "cv", + # min_prior_log_var: float = -3, + # max_prior_log_var: float = 3, + # grid_size: int = 20, + # distribute: bool = False, + # ) -> jnp.ndarray: + # if mode == "cv": + # return self._prior_log_variance_tuning_cv( + # val_data_loader, + # n_posterior_samples, + # min_prior_log_var, + # max_prior_log_var, + # grid_size, + # distribute, + # ) + # elif mode == "marginal_lik": + # raise NotImplementedError( + # f"Optimizing the prior log variance via marginal likelihood maximization is not yet available." + # ) + # else: + # raise ValueError(f"Unrecognized mode={mode} for prior log variance tuning.") + # + # def _prior_log_variance_tuning_cv( + # self, + # val_data_loader: DataLoader, + # n_posterior_samples: int, + # min_prior_log_var: float, + # max_prior_log_var: float, + # grid_size: int, + # distribute: bool, + # ) -> jnp.ndarray: + # best = None + # candidates = jnp.concatenate((jnp.linspace(min_prior_log_var, max_prior_log_var, grid_size), jnp.array([self.joint.prior.log_var]))) + # if distribute: + # val_data_loader = DeviceDimensionAugmentedLoader(val_data_loader) + # + # for lpv in tqdm.tqdm(candidates, desc="Tuning prior log-var"): + # neg_log_prob = -jnp.sum( + # jnp.concatenate( + # [ + # self.joint.likelihood._unshard_array( + # self._batched_log_prob(batch, lpv, n_posterior_samples, self.rng.get()) + # ) + # for batch in val_data_loader + # ], + # 0, + # ) + # ) + # if best is None or neg_log_prob < best[-1]: + # best = (lpv, neg_log_prob) + # + # opt_prior_log_var = best[0].reshape() + # return opt_prior_log_var diff --git a/fortuna/prob_model/posterior/map/map_posterior.py b/fortuna/prob_model/posterior/map/map_posterior.py index c86be8c1..87cdd1fd 100755 --- a/fortuna/prob_model/posterior/map/map_posterior.py +++ b/fortuna/prob_model/posterior/map/map_posterior.py @@ -1,9 +1,12 @@ import logging +import pathlib from typing import Optional +from jax import eval_shape from jax._src.prng import PRNGKeyArray from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint from fortuna.prob_model.joint.state import JointState @@ -12,16 +15,15 @@ from fortuna.prob_model.posterior.map.map_approximator import MAPPosteriorApproximator from fortuna.prob_model.posterior.map.map_state import MAPState from fortuna.prob_model.posterior.map.map_trainer import ( - JittedMAPTrainer, MAPTrainer, - MultiDeviceMAPTrainer, + ShardedMAPTrainer, ) from fortuna.prob_model.posterior.posterior_state_repository import ( 
PosteriorStateRepository, ) -from fortuna.typing import Status +from fortuna.typing import Status, Shape from fortuna.utils.builtins import get_dynamic_scale_instance_from_model_dtype -from fortuna.utils.device import select_trainer_given_devices +from fortuna.utils.checkpoint import get_checkpoint_manager logger = logging.getLogger(__name__) @@ -31,6 +33,7 @@ def __init__( self, joint: Joint, posterior_approximator: MAPPosteriorApproximator, + partition_manager: PartitionManager, ): """ Maximum-a-Posteriori (MAP) approximate posterior class. @@ -41,8 +44,14 @@ def __init__( A Joint distribution object. posterior_approximator: MAPPosteriorApproximator A MAP posterior approximator. + partition_manager: PartitionManager + An object to manage partitions. """ - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__( + joint=joint, + posterior_approximator=posterior_approximator, + partition_manager=partition_manager, + ) def __str__(self): return MAP_NAME @@ -57,16 +66,17 @@ def fit( ) -> Status: super()._checks_on_fit_start(fit_config, map_fit_config) - trainer_cls = select_trainer_given_devices( - devices=fit_config.processor.devices, - base_trainer_cls=MAPTrainer, - jitted_trainer_cls=JittedMAPTrainer, - multi_device_trainer_cls=MultiDeviceMAPTrainer, - disable_jit=fit_config.processor.disable_jit, + trainer_cls = ( + ShardedMAPTrainer if not fit_config.processor.disable_jit else MAPTrainer ) trainer = trainer_cls( predict_fn=self.joint.likelihood.prob_output_layer.predict, + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager( + fit_config.checkpointer.save_checkpoint_dir, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ), save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, save_every_n_steps=fit_config.checkpointer.save_every_n_steps, keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, @@ -78,17 +88,39 @@ def fit( freeze_fun=fit_config.optimizer.freeze_fun, ) - if super()._is_state_available_somewhere(fit_config): + checkpoint_restorer = ( + get_checkpoint_manager( + str( + pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.restore_checkpoint_dir is not None + else None + ) + + if self._is_state_available_somewhere(fit_config): state = self._restore_state_from_somewhere( fit_config=fit_config, allowed_states=(MAPState,), + partition_manager=self.partition_manager, + checkpoint_manager=checkpoint_restorer, ) + state = self._freeze_optimizer_in_state(state, fit_config) + self.partition_manager.shapes_dtypes = eval_shape(lambda: state) else: - state = self._init_state( - data_loader=train_data_loader, fit_config=fit_config - ) + input_shape = train_data_loader.input_shape - state = super()._freeze_optimizer_in_state(state, fit_config) + def init_state_fn(rng): + _state = self._init_state( + input_shape=input_shape, fit_config=fit_config, rng=rng + ) + return self._freeze_optimizer_in_state(_state, fit_config) + + state = self.partition_manager.init_sharded_state( + init_state_fn, self.rng.get() + ) self._check_state(state) logging.info("Run MAP.") @@ -96,11 +128,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, 
n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, @@ -110,11 +142,20 @@ def fit( gradient_accumulation_steps=fit_config.hyperparameters.gradient_accumulation_steps, ) self.state = PosteriorStateRepository( - fit_config.checkpointer.save_checkpoint_dir - if fit_config.checkpointer.dump_state is True - else None + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str( + pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.save_checkpoint_dir is not None + and fit_config.checkpointer.dump_state + else None, ) - self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) + if self.state.checkpoint_manager is None: + self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) logging.info("Fit completed.") return status @@ -127,9 +168,13 @@ def sample(self, rng: Optional[PRNGKeyArray] = None, **kwargs) -> JointState: calib_mutable=state.calib_mutable, ) - def _init_state(self, data_loader: DataLoader, fit_config: FitConfig) -> MAPState: - state = super()._init_joint_state(data_loader=data_loader) - + def _init_state( + self, + input_shape: Shape, + fit_config: FitConfig, + rng: Optional[PRNGKeyArray] = None, + ) -> MAPState: + state = super()._init_joint_state(input_shape=input_shape, rng=rng) return MAPState.init( params=state.params, mutable=state.mutable, diff --git a/fortuna/prob_model/posterior/map/map_trainer.py b/fortuna/prob_model/posterior/map/map_trainer.py index d1d7d955..a9d0d686 100644 --- a/fortuna/prob_model/posterior/map/map_trainer.py +++ b/fortuna/prob_model/posterior/map/map_trainer.py @@ -17,10 +17,9 @@ from fortuna.prob_model.posterior.map import * from fortuna.prob_model.posterior.map.map_state import MAPState from fortuna.prob_model.posterior.posterior_trainer import PosteriorTrainerABC -from fortuna.training.trainer import ( - JittedMixin, - MultiDeviceMixin, -) +from fortuna.training.mixins.jitted import JittedMixin +from fortuna.training.mixins.multi_device import MultiDeviceMixin +from fortuna.training.mixins.sharding import ShardingMixin from fortuna.typing import ( Array, Batch, @@ -111,3 +110,7 @@ class JittedMAPTrainer(JittedMixin, MAPTrainer): class MultiDeviceMAPTrainer(MultiDeviceMixin, MAPTrainer): pass + + +class ShardedMAPTrainer(ShardingMixin, MAPTrainer): + pass diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py index 8f06738e..5040a37c 100755 --- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py +++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py @@ -14,7 +14,6 @@ from jax._src.prng import PRNGKeyArray from jax.flatten_util import ravel_pytree import jax.numpy as jnp -import numpy as np from fortuna.data.loader import ( DataLoader, @@ -39,6 +38,8 @@ JittedADVITrainer, MultiDeviceADVITrainer, ) +import pathlib +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.posterior.posterior_state_repository import ( PosteriorStateRepository, ) @@ -59,6 +60,7 @@ nested_unpair, ) from fortuna.utils.strings import 
decode_encoded_tuple_of_lists_of_strings_to_array +from fortuna.utils.checkpoint import get_checkpoint_manager class ADVIPosterior(Posterior): @@ -66,6 +68,7 @@ def __init__( self, joint: Joint, posterior_approximator: ADVIPosteriorApproximator, + partition_manager: PartitionManager, ): """ Automatic Differentiation Variational Inference (ADVI) approximate posterior class. @@ -77,7 +80,7 @@ def __init__( posterior_approximator: ADVI An ADVI posterior approximator. """ - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__(joint=joint, posterior_approximator=posterior_approximator, partition_manager=partition_manager) self._base = None self._architecture = None self._unravel = None @@ -98,20 +101,34 @@ def fit( status = dict() + checkpoint_restorer = ( + get_checkpoint_manager( + str( + pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.restore_checkpoint_dir is not None + else None + ) + if super()._is_state_available_somewhere(fit_config): state = super()._restore_state_from_somewhere( - fit_config=fit_config, allowed_states=(MAPState, ADVIState) + fit_config=fit_config, allowed_states=(MAPState, ADVIState), checkpoint_manager=checkpoint_restorer, _do_reshard=False ) elif super()._should_run_preliminary_map(fit_config, map_fit_config): state, status["map"] = run_preliminary_map( joint=self.joint, + partition_manager=self.partition_manager, train_data_loader=train_data_loader, val_data_loader=val_data_loader, map_fit_config=map_fit_config, rng=self.rng, **kwargs, ) + self.partition_manager.shardings = None else: state = None @@ -143,6 +160,10 @@ def fit( trainer = trainer_cls( predict_fn=self.joint.likelihood.prob_output_layer.predict, + checkpoint_manager=get_checkpoint_manager( + fit_config.checkpointer.save_checkpoint_dir, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ), save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, save_every_n_steps=fit_config.checkpointer.save_every_n_steps, keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, @@ -172,11 +193,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, @@ -189,11 +210,20 @@ def fit( ) self.state = PosteriorStateRepository( - fit_config.checkpointer.save_checkpoint_dir - if fit_config.checkpointer.dump_state is True - else None + partition_manager=None, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str( + pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.save_checkpoint_dir is not None + and fit_config.checkpointer.dump_state + else None, ) - self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) + if self.state.checkpoint_manager is None: + self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) logging.info("Fit completed.") return status @@ 
-250,7 +280,7 @@ def sample( ravel_pytree( nested_unpair( d=state.params.unfreeze(), - key_paths=which_params, + key_paths=tuple(which_params), labels=("mean", "log_std"), )[1] )[0] @@ -260,7 +290,7 @@ def sample( FrozenDict( nested_unpair( d=state.params.unfreeze(), - key_paths=which_params, + key_paths=tuple(which_params), labels=("mean", "log_std"), )[0] if which_params @@ -298,7 +328,7 @@ def sample( ) means, log_stds = nested_unpair( d=state.params.unfreeze(), - key_paths=which_params, + key_paths=tuple(which_params), labels=("mean", "log_std"), ) rav_params = { @@ -338,7 +368,7 @@ def _init_map_state( ) means, log_stds = nested_unpair( d=state.params.unfreeze(), - key_paths=which_params, + key_paths=tuple(which_params), labels=("mean", "log_std"), ) means, log_stds = FrozenDict(means), FrozenDict(log_stds) @@ -412,7 +442,7 @@ def unravel(_rav): return FrozenDict( nested_set( d=params.unfreeze(), - key_paths=which_params, + key_paths=tuple(which_params), objs=sub_unravel(_rav), ) ) @@ -421,7 +451,7 @@ def unravel(_rav): rav_log_stds = ravel_pytree( nested_set( d={}, - key_paths=which_params, + key_paths=tuple(which_params), objs=tuple( [nested_get(log_stds, path) for path in which_params] ), diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py index 995276dd..0f97e42a 100644 --- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py +++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py @@ -8,7 +8,7 @@ from jax._src.prng import PRNGKeyArray from jax.flatten_util import ravel_pytree from jax.tree_util import tree_map - +import pathlib from fortuna.data.loader import DataLoader from fortuna.prob_model.posterior.normalizing_flow.advi import ADVI_NAME from fortuna.prob_model.posterior.normalizing_flow.advi.advi_state import ADVIState @@ -18,10 +18,8 @@ from fortuna.prob_model.posterior.normalizing_flow.normalizing_flow_trainer import ( NormalizingFlowTrainer, ) -from fortuna.training.trainer import ( - JittedMixin, - MultiDeviceMixin, -) +from fortuna.training.mixins.jitted import JittedMixin +from fortuna.training.mixins.multi_device import MultiDeviceMixin from fortuna.typing import ( Params, Path, @@ -47,21 +45,24 @@ def save_checkpoint( save_checkpoint_dir: Path, keep: int = 1, force_save: bool = False, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: state = state.replace( params=self._unravel_params(state.params), frozen_params=FrozenDict(), _encoded_which_params=self._encoded_which_params, ) - super().save_checkpoint(state, save_checkpoint_dir, keep, force_save, prefix) + super().save_checkpoint(state, save_checkpoint_dir, keep, force_save) def on_train_end(self, state: NormalizingFlowState) -> NormalizingFlowState: self.save_checkpoint( state, - save_checkpoint_dir=self.save_checkpoint_dir, + save_checkpoint_dir=str(pathlib.Path(self.save_checkpoint_dir) / "last") + if self.save_checkpoint_dir is not None + else None, keep=self.keep_top_n_checkpoints, force_save=True, + prefix="", ) state = state.replace( @@ -96,7 +97,7 @@ def _unravel_params( def on_train_start( self, state: NormalizingFlowState, - dataloaders: List[DataLoader], + data_loaders: List[DataLoader], rng: PRNGKeyArray, ) -> Tuple[NormalizingFlowState, List[DataLoader], PRNGKeyArray]: if self.freeze_fun is not None: @@ -152,7 +153,7 @@ def on_train_start( ), ) - return state, dataloaders, rng + return state, data_loaders, rng class JittedADVITrainer(JittedMixin, ADVITrainer): diff 
--git a/fortuna/prob_model/posterior/normalizing_flow/normalizing_flow_trainer.py b/fortuna/prob_model/posterior/normalizing_flow/normalizing_flow_trainer.py index be6a237e..2dd6fd02 100644 --- a/fortuna/prob_model/posterior/normalizing_flow/normalizing_flow_trainer.py +++ b/fortuna/prob_model/posterior/normalizing_flow/normalizing_flow_trainer.py @@ -32,8 +32,9 @@ CalibParams, Mutable, Params, - Path, ) +from fortuna.partitioner.partition_manager.base import PartitionManager +from orbax.checkpoint import CheckpointManager from fortuna.utils.strings import encode_tuple_of_lists_of_strings_to_numpy @@ -73,9 +74,14 @@ def __init__( which_params: Optional[Tuple[List[str]]], unravel: Callable, sub_unravel: Callable, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, **kwargs, ): - super(NormalizingFlowTrainer, self).__init__(**kwargs) + super(NormalizingFlowTrainer, self).__init__( + partition_manager=partition_manager, + checkpoint_manager=checkpoint_manager, + **kwargs) # base distribution self.sample_base = base.sample self.base_log_joint_prob = base.log_joint_prob diff --git a/fortuna/prob_model/posterior/posterior_mixin.py b/fortuna/prob_model/posterior/posterior_mixin.py index d73fd0b7..8390b10e 100644 --- a/fortuna/prob_model/posterior/posterior_mixin.py +++ b/fortuna/prob_model/posterior/posterior_mixin.py @@ -1,8 +1,8 @@ from typing import Optional from fortuna.prob_model.posterior.name_to_posterior_state import NameToPosteriorState -from fortuna.prob_model.posterior.state import PosteriorState -from fortuna.training.mixin import WithCheckpointingMixin +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.name_to_train_state import NameToTrainState from fortuna.typing import ( OptaxOptimizer, Path, @@ -12,15 +12,22 @@ class WithPosteriorCheckpointingMixin(WithCheckpointingMixin): def restore_checkpoint( self, - restore_checkpoint_path: Path, + restore_checkpoint_dir: Path, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - name_to_train_state: NameToPosteriorState = NameToPosteriorState, - **kwargs, - ) -> PosteriorState: + name_to_train_state: NameToTrainState = NameToPosteriorState, + ): return super().restore_checkpoint( - restore_checkpoint_path, - optimizer, - prefix, + restore_checkpoint_dir=restore_checkpoint_dir, + optimizer=optimizer, + name_to_train_state=name_to_train_state, + ) + + def get_shapes_dtypes_checkpoint( + self, + restore_checkpoint_dir: Optional[Path] = None, + name_to_train_state: NameToTrainState = NameToPosteriorState, + ): + return super().get_shapes_dtypes_checkpoint( + restore_checkpoint_dir=restore_checkpoint_dir, name_to_train_state=name_to_train_state, ) diff --git a/fortuna/prob_model/posterior/posterior_multi_state_repository.py b/fortuna/prob_model/posterior/posterior_multi_state_repository.py index 1cca1e86..733d305b 100644 --- a/fortuna/prob_model/posterior/posterior_multi_state_repository.py +++ b/fortuna/prob_model/posterior/posterior_multi_state_repository.py @@ -14,33 +14,51 @@ OptaxOptimizer, Path, ) +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.utils.checkpoint import get_checkpoint_manager +import pathlib +from orbax.checkpoint import CheckpointManager class PosteriorMultiStateRepository: - def __init__(self, size: int, checkpoint_dir: Optional[Path] = None): + def __init__( + self, + size: int, + partition_manager: Optional[PartitionManager] = None, + 
checkpoint_manager: Optional[CheckpointManager] = None, + checkpoint_type: Optional[str] = "last" + ): self.size = size - self.state = [ - PosteriorStateRepository( - checkpoint_dir=os.path.join(checkpoint_dir, str(i)) - if checkpoint_dir - else None + self.state = [] + for i in range(size): + if checkpoint_manager is not None: + path = pathlib.Path(checkpoint_manager.directory) / str(i) + if checkpoint_type is not None: + path = path / checkpoint_type + path = str(path) + else: + path = None + self.state.append( + PosteriorStateRepository( + partition_manager=partition_manager, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=path, + keep_top_n_checkpoints=checkpoint_manager._options.max_to_keep if checkpoint_manager is not None else None + ) + ) ) - for i in range(size) - ] def get( self, i: int = None, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", **kwargs, ) -> Union[List[PosteriorState], PosteriorState]: def _get(_i): return self.state[_i].get( - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, - prefix=prefix, **kwargs, ) @@ -55,13 +73,12 @@ def put( self, state: PosteriorState, i: int = None, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, keep: int = 1, - prefix: str = "checkpoint_", ) -> None: def _put(_i): return self.state[_i].put( - state=state, checkpoint_path=checkpoint_path, keep=keep, prefix=prefix + state=state, checkpoint_dir=checkpoint_dir, keep=keep ) if i is not None: @@ -73,17 +90,13 @@ def _put(_i): def pull( self, i: int = None, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, ) -> PosteriorState: def _pull(_i): return self.state[_i].pull( - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, - prefix=prefix, - **kwargs, ) if i is not None: @@ -97,20 +110,16 @@ def update( self, variables: Dict, i: int = None, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, keep: int = 1, - prefix: str = "checkpoint_", - **kwargs, ): def _update(_i): self.state[_i].update( variables=variables, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, keep=keep, - prefix=prefix, - **kwargs, ) if i is not None: @@ -123,13 +132,11 @@ def extract( self, keys: List[str], i: int = None, - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", - **kwargs, + checkpoint_dir: Optional[Path] = None, ) -> Union[Dict, List[Dict]]: def _extract(_i): return self.state[_i].extract( - keys=keys, checkpoint_path=checkpoint_path, prefix=prefix, **kwargs + keys=keys, checkpoint_dir=checkpoint_dir ) if i is not None: @@ -141,10 +148,8 @@ def _extract(_i): def extract_calib_keys( self, - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", - **kwargs, + checkpoint_dir: Optional[Path] = None, ) -> Dict: return self.extract( - ["calib_params", "calib_mutable"], 0, checkpoint_path, prefix, **kwargs + ["calib_params", "calib_mutable"], 0, checkpoint_dir ) diff --git a/fortuna/prob_model/posterior/posterior_state_repository.py b/fortuna/prob_model/posterior/posterior_state_repository.py index 1d907940..6496c3a9 100644 --- a/fortuna/prob_model/posterior/posterior_state_repository.py +++ b/fortuna/prob_model/posterior/posterior_state_repository.py @@ -11,10 
+11,6 @@ class PosteriorStateRepository(WithPosteriorCheckpointingMixin, TrainStateRepository): def extract_calib_keys( self, - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", - **kwargs, + checkpoint_dir: Optional[Path] = None, ) -> Dict: - return super().extract( - ["calib_params", "calib_mutable"], checkpoint_path, prefix, **kwargs - ) + return super().extract(["calib_params", "calib_mutable"], checkpoint_dir) diff --git a/fortuna/prob_model/posterior/run_preliminary_map.py b/fortuna/prob_model/posterior/run_preliminary_map.py index 45c0282c..64445e63 100644 --- a/fortuna/prob_model/posterior/run_preliminary_map.py +++ b/fortuna/prob_model/posterior/run_preliminary_map.py @@ -12,10 +12,12 @@ from fortuna.prob_model.posterior.map.map_state import MAPState from fortuna.typing import Status from fortuna.utils.random import RandomNumberGenerator +from fortuna.partitioner.partition_manager.base import PartitionManager def run_preliminary_map( joint: Joint, + partition_manager: PartitionManager, train_data_loader: DataLoader, val_data_loader: DataLoader, map_fit_config: Optional[FitConfig], @@ -24,7 +26,7 @@ def run_preliminary_map( ) -> Tuple[MAPState, Status]: logging.info("Do a preliminary run of MAP.") map_posterior = MAPPosterior( - joint, posterior_approximator=MAPPosteriorApproximator() + joint, posterior_approximator=MAPPosteriorApproximator(), partition_manager=partition_manager ) map_posterior.rng = rng status = map_posterior.fit( diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py index d8f25150..d7169789 100644 --- a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py @@ -13,6 +13,7 @@ MAPTrainer, MultiDeviceMAPTrainer, ) +import orbax from fortuna.prob_model.posterior.run_preliminary_map import run_preliminary_map from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import CYCLICAL_SGLD_NAME from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator import ( @@ -31,6 +32,7 @@ from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior_state_repository import ( SGMCMCPosteriorStateRepository, ) +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.typing import Status from fortuna.utils.device import select_trainer_given_devices from fortuna.utils.freeze import get_trainable_paths @@ -38,6 +40,7 @@ nested_get, nested_set, ) +from fortuna.utils.checkpoint import get_checkpoint_manager logger = logging.getLogger(__name__) @@ -47,6 +50,7 @@ def __init__( self, joint: Joint, posterior_approximator: CyclicalSGLDPosteriorApproximator, + partition_manager: PartitionManager, ): """ Cyclical Stochastic Gradient Langevin Dynamics (SGLD) approximate posterior class. @@ -58,7 +62,7 @@ def __init__( posterior_approximator: CyclicalSGLDPosteriorApproximator A cyclical SGLD posterior approximator. 
""" - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__(joint=joint, posterior_approximator=posterior_approximator, partition_manager=partition_manager) def __str__(self): return CYCLICAL_SGLD_NAME @@ -88,6 +92,7 @@ def fit( ) and super()._should_run_preliminary_map(fit_config, map_fit_config): map_state, status["map"] = run_preliminary_map( joint=self.joint, + partition_manager=self.partition_manager, train_data_loader=train_data_loader, val_data_loader=val_data_loader, map_fit_config=map_fit_config, @@ -114,14 +119,11 @@ def fit( disable_jit=fit_config.processor.disable_jit, ) - save_checkpoint_dir = ( - pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "c" - if fit_config.checkpointer.save_checkpoint_dir - else None - ) trainer = trainer_cls( predict_fn=self.joint.likelihood.prob_output_layer.predict, - save_checkpoint_dir=save_checkpoint_dir, + partition_manager=self.partition_manager, + checkpoint_manager=None, + save_checkpoint_dir=None, save_every_n_steps=fit_config.checkpointer.save_every_n_steps, keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, disable_training_metrics_computation=fit_config.monitor.disable_training_metrics_computation, @@ -132,8 +134,14 @@ def fit( freeze_fun=fit_config.optimizer.freeze_fun, ) + checkpoint_restorer = ( + get_checkpoint_manager(str(pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) / "chain"), keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints) + if fit_config.checkpointer.restore_checkpoint_dir is not None + else None + ) + if super()._is_state_available_somewhere(fit_config): - state = self._restore_state_from_somewhere(fit_config=fit_config) + state = self._restore_state_from_somewhere(fit_config=fit_config, checkpoint_manager=checkpoint_restorer, _do_reshard=False) else: state = self._init_map_state(map_state, train_data_loader, fit_config) @@ -154,11 +162,21 @@ def fit( self.state = SGMCMCPosteriorStateRepository( size=self.posterior_approximator.n_samples, - checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, + checkpoint_manager=get_checkpoint_manager( + str(pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "chain"), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) if fit_config.checkpointer.save_checkpoint_dir is not None else None, + checkpoint_type=None, which_params=which_params, all_params=state.params if which_params else None, ) + if fit_config.checkpointer.save_checkpoint_dir is not None and fit_config.optimizer.freeze_fun is not None: + all_params_checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) + all_params_checkpointer.save( + str(pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "all/0/default"), self.state._all_params + ) + cyclical_sampling_callback = CyclicalSGLDSamplingCallback( n_epochs=fit_config.optimizer.n_epochs, n_training_steps=len(train_data_loader), @@ -176,11 +194,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py 
b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py index c9dfa12e..30ad00e2 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py @@ -1,8 +1,6 @@ from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import ( SGMCMCSamplingCallback, ) -from fortuna.training.callback import Callback -from fortuna.training.train_state import TrainState from fortuna.training.train_state_repository import TrainStateRepository from fortuna.training.trainer import TrainerABC diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py index 4641a71d..71275b88 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py @@ -2,18 +2,14 @@ import pathlib from typing import Optional +import orbax.checkpoint from flax.core import FrozenDict from fortuna.data.loader import DataLoader from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint -from fortuna.prob_model.posterior.map.map_posterior import MAPPosterior +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_trainer import SGMCMCTrainer, JittedSGMCMCTrainer, MultiDeviceSGMCMCTrainer from fortuna.prob_model.posterior.map.map_state import MAPState -from fortuna.prob_model.posterior.map.map_trainer import ( - JittedMAPTrainer, - MAPTrainer, - MultiDeviceMAPTrainer, -) from fortuna.prob_model.posterior.run_preliminary_map import run_preliminary_map from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator import ( @@ -35,6 +31,8 @@ nested_get, nested_set, ) +from fortuna.utils.checkpoint import get_checkpoint_manager +from fortuna.partitioner.partition_manager.base import PartitionManager logger = logging.getLogger(__name__) @@ -44,6 +42,7 @@ def __init__( self, joint: Joint, posterior_approximator: SGHMCPosteriorApproximator, + partition_manager: PartitionManager, ): """ Stochastic Gradient Hamiltonian Monte Carlo approximate posterior class. @@ -55,7 +54,7 @@ def __init__( posterior_approximator: SGHMCPosteriorApproximator A SGHMC posterior approximator. 
""" - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__(joint=joint, posterior_approximator=posterior_approximator, partition_manager=partition_manager) def __str__(self): return SGHMC_NAME @@ -85,6 +84,7 @@ def fit( ) and super()._should_run_preliminary_map(fit_config, map_fit_config): map_state, status["map"] = run_preliminary_map( joint=self.joint, + partition_manager=self.partition_manager, train_data_loader=train_data_loader, val_data_loader=val_data_loader, map_fit_config=map_fit_config, @@ -105,20 +105,17 @@ def fit( trainer_cls = select_trainer_given_devices( devices=fit_config.processor.devices, - base_trainer_cls=MAPTrainer, - jitted_trainer_cls=JittedMAPTrainer, - multi_device_trainer_cls=MultiDeviceMAPTrainer, + base_trainer_cls=SGMCMCTrainer, + jitted_trainer_cls=JittedSGMCMCTrainer, + multi_device_trainer_cls=MultiDeviceSGMCMCTrainer, disable_jit=fit_config.processor.disable_jit, ) - save_checkpoint_dir = ( - pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "c" - if fit_config.checkpointer.save_checkpoint_dir - else None - ) trainer = trainer_cls( predict_fn=self.joint.likelihood.prob_output_layer.predict, - save_checkpoint_dir=save_checkpoint_dir, + partition_manager=self.partition_manager, + checkpoint_manager=None, + save_checkpoint_dir=None, save_every_n_steps=fit_config.checkpointer.save_every_n_steps, keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, disable_training_metrics_computation=fit_config.monitor.disable_training_metrics_computation, @@ -129,8 +126,14 @@ def fit( freeze_fun=fit_config.optimizer.freeze_fun, ) + checkpoint_restorer = ( + get_checkpoint_manager(str(pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) / "chain"), keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints) + if fit_config.checkpointer.restore_checkpoint_dir is not None + else None + ) + if super()._is_state_available_somewhere(fit_config): - state = self._restore_state_from_somewhere(fit_config=fit_config) + state = self._restore_state_from_somewhere(fit_config=fit_config, checkpoint_manager=checkpoint_restorer, _do_reshard=False) else: state = self._init_map_state(map_state, train_data_loader, fit_config) @@ -151,11 +154,21 @@ def fit( self.state = SGMCMCPosteriorStateRepository( size=self.posterior_approximator.n_samples, - checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, + checkpoint_manager=get_checkpoint_manager( + str(pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "chain"), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) if fit_config.checkpointer.save_checkpoint_dir is not None else None, + checkpoint_type=None, which_params=which_params, all_params=state.params if which_params else None, ) + if fit_config.checkpointer.save_checkpoint_dir is not None and fit_config.optimizer.freeze_fun is not None: + all_params_checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) + all_params_checkpointer.save( + str(pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "all/0/default"), self.state._all_params + ) + sghmc_sampling_callback = SGHMCSamplingCallback( n_epochs=fit_config.optimizer.n_epochs, n_training_steps=len(train_data_loader), @@ -172,11 +185,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, 
n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py index ac48e64c..a2b98d18 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py @@ -4,13 +4,14 @@ Tuple, Type, ) - +from flax.traverse_util import path_aware_map, flatten_dict from jax import ( pure_callback, random, ) +from flax.core import FrozenDict from jax._src.prng import PRNGKeyArray - +import orbax from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.state import JointState from fortuna.prob_model.posterior.base import Posterior @@ -18,8 +19,11 @@ from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior_state_repository import ( SGMCMCPosteriorStateRepository, ) +from fortuna.prob_model.posterior.posterior_state_repository import PosteriorStateRepository +from fortuna.utils.checkpoint import get_checkpoint_manager +from fortuna.partitioner.partition_manager.base import PartitionManager +from orbax.checkpoint import CheckpointManager from fortuna.typing import Path -from fortuna.utils.strings import decode_encoded_tuple_of_lists_of_strings_to_array class SGMCMCPosterior(Posterior): @@ -58,41 +62,59 @@ def sample( calib_mutable=state.calib_mutable, ) - def load_state(self, checkpoint_dir: Path) -> None: - try: - state = self.restore_checkpoint(pathlib.Path(checkpoint_dir) / "c") - except ValueError: - raise ValueError( - f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`." - ) - which_params = decode_encoded_tuple_of_lists_of_strings_to_array( - state._encoded_which_params - ) + def load_state( + self, + checkpoint_dir: Path, + keep_top_n_checkpoints: int = 2, + checkpoint_type: str = "last" + ) -> None: + path = pathlib.Path(checkpoint_dir) + all_params_path = path / "all/0" + all_params_checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) self.state = SGMCMCPosteriorStateRepository( size=self.posterior_approximator.n_samples, - checkpoint_dir=checkpoint_dir, - which_params=which_params, - all_params=state.params if which_params else None, + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager(checkpoint_dir=str(path / "chain")), + checkpoint_type=None, ) + if all_params_path.exists(): + self.state._which_params = tuple([list(p) for p in flatten_dict(path_aware_map(lambda p, v: p, self.state.get(0).params)).keys()]) + self.state._all_params = FrozenDict(all_params_checkpointer.restore(path / "all/0/default")) def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: + if self.state is None: + raise ValueError( + """No state available. 
You must first either fit the posterior distribution, or load a + saved checkpoint.""" + ) for i in range(self.posterior_approximator.n_samples): - self.state.put(state=self.state.get(i), i=i, keep=keep_top_n_checkpoints) + self.state.put(state=self.state.get(i), i=i, checkpoint_dir=checkpoint_dir, keep=keep_top_n_checkpoints) def _restore_state_from_somewhere( - self, - fit_config: FitConfig, - allowed_states: Optional[Tuple[Type[MAPState], ...]] = None, + self, + fit_config: FitConfig, + allowed_states: Optional[Tuple[Type[MAPState], ...]] = None, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, + _do_reshard: bool = True ) -> MAPState: - if fit_config.checkpointer.restore_checkpoint_path is not None: - restore_checkpoint_path = ( - pathlib.Path(fit_config.checkpointer.restore_checkpoint_path) / "c" - ) - state = self.restore_checkpoint( - restore_checkpoint_path=restore_checkpoint_path, - optimizer=fit_config.optimizer.method, + if checkpoint_manager is not None: + repo = PosteriorStateRepository( + partition_manager=None, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str(pathlib.Path(checkpoint_manager.directory) / str(checkpoint_manager.latest_step())), + keep_top_n_checkpoints=checkpoint_manager._options.max_to_keep + ), ) - elif fit_config.checkpointer.start_from_current_state is not None: + state = repo.get(optimizer=fit_config.optimizer.method, _do_reshard=_do_reshard) + + if fit_config.checkpointer.restore_checkpoint_dir is not None and fit_config.optimizer.freeze_fun is not None: + all_params_checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) + all_params_path = pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) / "all/0/default" + if all_params_path.exists(): + state = state.replace(params=FrozenDict(all_params_checkpointer.restore(all_params_path))) + + elif fit_config.checkpointer.start_from_current_state: state = self.state.get( i=self.state.size - 1, optimizer=fit_config.optimizer.method, @@ -101,8 +123,7 @@ def _restore_state_from_somewhere( if allowed_states is not None and not isinstance(state, allowed_states): raise ValueError( f"The type of the restored checkpoint must be within {allowed_states}. " - f"However, {fit_config.checkpointer.restore_checkpoint_path} pointed to a state " - f"with type {type(state)}." + f"However, the restored checkpoint has type {type(state)}." 
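
The `all/0/default` side-checkpoint used above stores the full parameter tree with a bare orbax `Checkpointer`. A minimal standalone sketch of that save/restore round trip; the toy parameters and the temporary directory are illustrative only, and only the `Checkpointer`/`PyTreeCheckpointHandler` usage mirrors the code in this hunk:

import pathlib
import tempfile

import jax.numpy as jnp
import orbax.checkpoint
from flax.core import FrozenDict

ckpt_root = pathlib.Path(tempfile.mkdtemp())
all_params_path = ckpt_root / "all" / "0" / "default"

params = FrozenDict({"dense": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros(2)}})

checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
checkpointer.save(str(all_params_path), params)  # writes the pytree under .../all/0/default
restored = FrozenDict(checkpointer.restore(str(all_params_path)))  # comes back as a nested dict

assert jnp.allclose(restored["dense"]["kernel"], params["dense"]["kernel"])
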
) self._check_state(state) diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py index be343127..d768079c 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py @@ -22,33 +22,35 @@ nested_get, nested_set, ) +from fortuna.partitioner.partition_manager.base import PartitionManager +from orbax.checkpoint import CheckpointManager class SGMCMCPosteriorStateRepository(PosteriorMultiStateRepository): def __init__( self, size: int, - checkpoint_dir: Optional[Path] = None, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, + checkpoint_type: Optional[str] = "last", all_params: Optional[Params] = None, which_params: Optional[Tuple[List[AnyKey], ...]] = None, ): - super().__init__(size=size, checkpoint_dir=checkpoint_dir) + super().__init__(size=size, checkpoint_manager=checkpoint_manager, partition_manager=partition_manager, checkpoint_type=checkpoint_type) self._all_params = all_params self._which_params = which_params def get( self, i: int = None, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", **kwargs, ) -> Union[List[PosteriorState], PosteriorState]: state = super().get( i=i, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, - prefix=prefix, **kwargs, ) return self._update_state(state, modify="add") @@ -57,33 +59,27 @@ def put( self, state: PosteriorState, i: int = None, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, keep: int = 1, - prefix: str = "checkpoint_", ) -> None: state = self._update_state(state, modify="remove") return super().put( state=state, i=i, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, keep=keep, - prefix=prefix, ) def pull( self, i: int = None, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, ) -> PosteriorState: state = super().pull( i=i, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, - prefix=prefix, - **kwargs, ) return self._update_state(state, modify="add") @@ -91,15 +87,12 @@ def extract( self, keys: List[str], i: int = None, - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", - **kwargs, + checkpoint_dir: Optional[Path] = None, ) -> Union[Dict, List[Dict]]: def _extract(_i): state = self.get( i=_i, - checkpoint_path=checkpoint_path, - prefix=prefix, + checkpoint_dir=checkpoint_dir, ) return {k: getattr(state, k) for k in keys} @@ -119,7 +112,7 @@ def _update_state( return state if isinstance(state, list): - return [_update_state(_state, modify=modify) for _state in state] + return [self._update_state(_state, modify=modify) for _state in state] if modify == "add": state = state.replace( @@ -154,6 +147,6 @@ def _update_state( step=state.step, ) else: - raise RuntimeError(f"Invalid update state method {method}.") + raise RuntimeError(f"Invalid update state method {modify}.") return state diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py index 7641ba84..bd7d506e 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py +++ 
b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py @@ -39,8 +39,10 @@ def training_step_end(self, state: TrainState) -> TrainState: self._current_step += 1 if self._do_sample(self._current_step, self._samples_count): + state = self._trainer._sync_state(state) + self._state_repository.put( - state=self._trainer._sync_state(state), + state=state, i=self._samples_count, keep=self._keep_top_n_checkpoints, ) diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_trainer.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_trainer.py new file mode 100644 index 00000000..600db2d8 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_trainer.py @@ -0,0 +1,27 @@ +from fortuna.prob_model.posterior.map.map_trainer import MAPTrainer +from typing import Dict, List +import jax.numpy as jnp +from fortuna.training.train_state import TrainState +from fortuna.training.mixins.jitted import JittedMixin +from fortuna.training.mixins.multi_device import MultiDeviceMixin + + +class SGMCMCTrainer(MAPTrainer): + def on_train_end(self, state: TrainState, mark_checkpoint_as_last: bool = False) -> TrainState: + return super().on_train_end(state, mark_checkpoint_as_last) + + def validation_epoch_end( + self, + validation_losses_and_metrics_current_epoch: List[Dict[str, jnp.ndarray]], + state: TrainState, + mark_checkpoint_as_best: bool = False + ) -> Dict[str, float]: + return super().validation_epoch_end(validation_losses_and_metrics_current_epoch, state, mark_checkpoint_as_best) + + +class JittedSGMCMCTrainer(JittedMixin, SGMCMCTrainer): + pass + + +class MultiDeviceSGMCMCTrainer(MultiDeviceMixin, SGMCMCTrainer): + pass diff --git a/fortuna/prob_model/posterior/sngp/sngp_posterior.py b/fortuna/prob_model/posterior/sngp/sngp_posterior.py index 2a40efb1..19b108c3 100755 --- a/fortuna/prob_model/posterior/sngp/sngp_posterior.py +++ b/fortuna/prob_model/posterior/sngp/sngp_posterior.py @@ -13,6 +13,8 @@ from fortuna.prob_model.posterior.state import PosteriorState from fortuna.typing import Status from fortuna.utils.nested_dicts import find_one_path_to_key +from fortuna.partitioner.partition_manager.base import PartitionManager + logger = logging.getLogger(__name__) @@ -22,6 +24,7 @@ def __init__( self, joint: Joint, posterior_approximator: SNGPPosteriorApproximator, + partition_manager: PartitionManager, ): """ Spectral-normalized Neural Gaussian Process (`SNGP `_) approximate posterior class. @@ -33,7 +36,7 @@ def __init__( posterior_approximator: SNGPPosteriorApproximator An SNGP posterior approximator. 
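
For orientation, a toy version of the `training_step_end` sampling pattern that `SGMCMCSamplingCallback` implements; every class and the thinning rule below are stand-ins, while the real callback delegates the decision to `_do_sample` and first synchronizes the state through the trainer:

class ToyRepository:
    def __init__(self):
        self.states = []

    def put(self, state, i):
        # mimic SGMCMCPosteriorStateRepository.put: store sample i
        if i < len(self.states):
            self.states[i] = state
        else:
            self.states.append(state)


class ToySamplingCallback:
    def __init__(self, n_samples, n_thinning, repository):
        self.n_samples = n_samples
        self.n_thinning = n_thinning
        self.repository = repository
        self._current_step = 0
        self._samples_count = 0

    def training_step_end(self, state):
        self._current_step += 1
        if self._samples_count < self.n_samples and self._current_step % self.n_thinning == 0:
            self.repository.put(state=state, i=self._samples_count)
            self._samples_count += 1
        return state


callback = ToySamplingCallback(n_samples=3, n_thinning=2, repository=ToyRepository())
for step in range(10):
    callback.training_step_end(state=step)
print(callback.repository.states)  # [1, 3, 5]: three thinned "samples"
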
""" - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__(joint=joint, posterior_approximator=posterior_approximator, partition_manager=partition_manager) def fit( self, diff --git a/fortuna/prob_model/posterior/swag/swag_posterior.py b/fortuna/prob_model/posterior/swag/swag_posterior.py index 1cb4bb30..01fdffe8 100755 --- a/fortuna/prob_model/posterior/swag/swag_posterior.py +++ b/fortuna/prob_model/posterior/swag/swag_posterior.py @@ -8,11 +8,12 @@ from jax._src.prng import PRNGKeyArray from jax.flatten_util import ravel_pytree import jax.numpy as jnp - +import pathlib from fortuna.data.loader import ( DataLoader, InputsLoader, ) +from fortuna.utils.checkpoint import get_checkpoint_manager from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint from fortuna.prob_model.joint.state import JointState @@ -26,6 +27,7 @@ from fortuna.prob_model.posterior.swag.swag_approximator import ( SWAGPosteriorApproximator, ) +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.posterior.swag.swag_state import SWAGState from fortuna.prob_model.posterior.swag.swag_trainer import ( JittedSWAGTrainer, @@ -46,7 +48,12 @@ class SWAGPosterior(Posterior): - def __init__(self, joint: Joint, posterior_approximator: SWAGPosteriorApproximator): + def __init__( + self, + joint: Joint, + posterior_approximator: SWAGPosteriorApproximator, + partition_manager: PartitionManager + ): """ SWAG approximate posterior class. @@ -57,7 +64,7 @@ def __init__(self, joint: Joint, posterior_approximator: SWAGPosteriorApproximat posterior_approximator: SWAGPosteriorApproximator A SWAG posterior approximator. """ - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__(joint=joint, posterior_approximator=posterior_approximator, partition_manager=partition_manager) def __str__(self): return SWAG_NAME @@ -94,14 +101,30 @@ def fit( status = dict() + checkpoint_restorer = ( + get_checkpoint_manager( + str( + pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.restore_checkpoint_dir is not None + else None + ) + if super()._is_state_available_somewhere(fit_config): state = super()._restore_state_from_somewhere( - fit_config=fit_config, allowed_states=(MAPState, SWAGState) + fit_config=fit_config, + allowed_states=(MAPState, SWAGState), + checkpoint_manager=checkpoint_restorer, + _do_reshard=False ) elif super()._should_run_preliminary_map(fit_config, map_fit_config): state, status["map"] = run_preliminary_map( joint=self.joint, + partition_manager=self.partition_manager, train_data_loader=train_data_loader, val_data_loader=val_data_loader, map_fit_config=map_fit_config, @@ -112,7 +135,7 @@ def fit( raise ValueError( "The SWAG approximation must start from a preliminary run of MAP or an existing " "checkpoint or state. Please configure `map_fit_config`, or " - "`fit_config.checkpointer.restore_checkpoint_path`, " + "`fit_config.checkpointer.restore_checkpoint_dir`, " "or `fit_config.checkpointer.start_from_current_state`." 
) @@ -139,6 +162,11 @@ def fit( ) trainer = trainer_cls( predict_fn=self.joint.likelihood.prob_output_layer.predict, + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager( + fit_config.checkpointer.save_checkpoint_dir, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ), save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, save_every_n_steps=fit_config.checkpointer.save_every_n_steps, keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, @@ -155,11 +183,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, @@ -169,11 +197,20 @@ def fit( ) self.state = PosteriorStateRepository( - fit_config.checkpointer.save_checkpoint_dir - if fit_config.checkpointer.dump_state is True - else None + partition_manager=None, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str( + pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.save_checkpoint_dir is not None + and fit_config.checkpointer.dump_state + else None, ) - self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) + if self.state.checkpoint_manager is None: + self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) logging.info("Fit completed.") return status @@ -245,7 +282,7 @@ def sample( params=FrozenDict( nested_set( d=state.params.unfreeze(), - key_paths=which_params, + key_paths=tuple(which_params), objs=tuple( self._get_sample( mean=state.mean, diff --git a/fortuna/prob_model/posterior/swag/swag_trainer.py b/fortuna/prob_model/posterior/swag/swag_trainer.py index 30e4fa95..c73027d0 100644 --- a/fortuna/prob_model/posterior/swag/swag_trainer.py +++ b/fortuna/prob_model/posterior/swag/swag_trainer.py @@ -16,27 +16,36 @@ from fortuna.prob_model.posterior.map.map_trainer import MAPTrainer from fortuna.prob_model.posterior.swag.swag_state import SWAGState from fortuna.training.callback import Callback -from fortuna.training.trainer import ( - JittedMixin, - MultiDeviceMixin, -) +from fortuna.training.mixins.jitted import JittedMixin +from fortuna.training.mixins.multi_device import MultiDeviceMixin from fortuna.typing import ( Array, Batch, Path, ) +import pathlib +from fortuna.partitioner.partition_manager.base import PartitionManager +from orbax.checkpoint import CheckpointManager from fortuna.utils.strings import encode_tuple_of_lists_of_strings_to_numpy class SWAGTrainer(MAPTrainer): - def __init__(self, *, which_params: Optional[Tuple[List[str]]], **kwargs): - super(SWAGTrainer, self).__init__(**kwargs) + def __init__(self, + *, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, + which_params: Optional[Tuple[List[str]]], + **kwargs + ): + super(SWAGTrainer, self).__init__(partition_manager=partition_manager, checkpoint_manager=checkpoint_manager, **kwargs) self._mean_rav_params = None self._mean_squared_rav_params = None self._deviation_rav_params = None 
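
The running statistics initialised here feed `_update_state_with_stats`, where the diagonal variance is `mean_squared - mean**2`. A self-contained sketch of that accumulation over raveled parameters; the three parameter snapshots are invented, and only the raveling and moment updates reflect the trainer's logic:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params_iterates = [
    {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)},
    {"w": jnp.array([1.5, 1.5]), "b": jnp.array(0.0)},
    {"w": jnp.array([2.0, 1.0]), "b": jnp.array(0.5)},
]

mean_rav = None
mean_sq_rav = None
unravel = None
for n, params in enumerate(params_iterates, start=1):
    rav, unravel = ravel_pytree(params)
    if mean_rav is None:
        mean_rav, mean_sq_rav = rav, rav**2
    else:
        # running first and second moments over the parameter iterates
        mean_rav = (1 - 1 / n) * mean_rav + rav / n
        mean_sq_rav = (1 - 1 / n) * mean_sq_rav + rav**2 / n

# diagonal variance, clamped at zero for numerical safety
var = jnp.maximum(mean_sq_rav - mean_rav**2, 0.0)
print(unravel(mean_rav), var)
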
self._encoded_which_params = encode_tuple_of_lists_of_strings_to_numpy( which_params ) + self.partition_manager = partition_manager + self.checkpoint_manager = checkpoint_manager def _update_state_with_stats(self, state: SWAGState) -> SWAGState: var = self._mean_squared_rav_params - self._mean_rav_params**2 @@ -98,15 +107,15 @@ def save_checkpoint( save_checkpoint_dir: Path, keep: int = 1, force_save: bool = False, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: state = self._update_state_with_stats(state) - super().save_checkpoint(state, save_checkpoint_dir, keep, force_save, prefix) + super().save_checkpoint(state, save_checkpoint_dir, keep, force_save) def on_train_end(self, state: SWAGState) -> SWAGState: self.save_checkpoint( state, - save_checkpoint_dir=self.save_checkpoint_dir, + save_checkpoint_dir=str(pathlib.Path(self.save_checkpoint_dir) / "last") if self.save_checkpoint_dir is not None else None, keep=self.keep_top_n_checkpoints, force_save=True, ) diff --git a/fortuna/prob_model/predictive/base.py b/fortuna/prob_model/predictive/base.py index ba7cee6e..91590bf9 100644 --- a/fortuna/prob_model/predictive/base.py +++ b/fortuna/prob_model/predictive/base.py @@ -1,5 +1,7 @@ import abc +import inspect from typing import ( + Any, Callable, Dict, List, @@ -8,24 +10,24 @@ Union, ) -from flax.training.common_utils import shard_prng_key from jax import ( jit, lax, - pmap, random, ) from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import pjit import jax.numpy as jnp import jax.scipy as jsp +from jax.sharding import PartitionSpec from jax.tree_util import tree_map from fortuna.data.loader import ( DataLoader, - DeviceDimensionAugmentedLoader, InputsLoader, TargetsLoader, ) +from fortuna.data.loader.base import ShardedPrefetchedLoader from fortuna.prob_model.posterior.base import Posterior from fortuna.typing import ( Array, @@ -133,7 +135,7 @@ def log_prob( data_loader: DataLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, **kwargs, ) -> jnp.ndarray: r""" @@ -156,8 +158,8 @@ def log_prob( that would be produced using the posterior distribution state. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
Returns ------- @@ -167,12 +169,12 @@ def log_prob( if rng is None: rng = self.rng.get() - return self._loop_fun_through_data_loader( + return self._loop_fun_through_loader( self._batched_log_prob, data_loader, n_posterior_samples, rng, - distribute, + shard, **kwargs, ) @@ -181,25 +183,49 @@ def _batched_log_prob( batch: Batch, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, **kwargs, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]: if rng is None: rng = self.rng.get() keys = random.split(rng, n_posterior_samples) - def _lik_log_batched_prob(key): - sample = self.posterior.sample(inputs=batch[0], rng=key) + def _lik_log_batched_prob(params, mutable, calib_params, calib_mutable): return self.likelihood._batched_log_prob( - sample.params, + params, batch, - mutable=sample.mutable, - calib_params=sample.calib_params, - calib_mutable=sample.calib_mutable, + mutable=mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, **kwargs, ) + if shard and self.posterior.partition_manager.shardings is not None: + _lik_log_batched_prob = pjit( + _lik_log_batched_prob, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + self.posterior.partition_manager.shardings.calib_params, + self.posterior.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("dp", "fsdp")), + ) + else: + _lik_log_batched_prob = jit(_lik_log_batched_prob) + + def _fun(key): + sample = self.posterior.sample(inputs=batch[0], rng=key) + with self.posterior.partition_manager.partitioner.mesh: + return _lik_log_batched_prob( + sample.params, + sample.mutable, + sample.calib_params, + sample.calib_mutable, + ) + return jsp.special.logsumexp( - lax.map(_lik_log_batched_prob, keys), axis=0 + jnp.stack(list(map(_fun, keys))), axis=0 ) - jnp.log(n_posterior_samples) def _batched_log_joint_prob( @@ -317,9 +343,9 @@ def sample( n_target_samples: int = 1, return_aux: Optional[List[str]] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, **kwargs, - ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]]: + ) -> Union[Tuple[Array, Dict[str, Array]], Array]: r""" Sample from an approximation of the predictive distribution for each input data point, that is @@ -333,6 +359,7 @@ def sample( Parameters ---------- + **kwargs inputs_loader : InputsLoader A loader of input data points. n_target_samples : int @@ -341,63 +368,27 @@ def sample( Return auxiliary objects. We currently support 'outputs'. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- - Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]] + Tuple[Array, Dict[str, Array]] | Array Samples for each input data point. Optionally, an auxiliary object is returned. 
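
The `pjit` calls in this file replace the earlier `pmap` path: arguments are sharded according to the partition manager and the compiled function runs under the partitioner's mesh. A standalone sketch of that pattern with the same axis names ("dp", "fsdp"); the toy model, mesh construction, and shapes are illustrative only:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

# one data-parallel axis over all devices, a trivial fsdp axis of size 1
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("dp", "fsdp"))

def log_prob(params, batch):
    inputs, targets = batch
    return -jnp.sum((inputs @ params["w"] - targets) ** 2, axis=-1)

sharded_log_prob = pjit(
    log_prob,
    in_shardings=(PartitionSpec(), PartitionSpec(("dp", "fsdp"))),  # replicate params, shard the batch
    out_shardings=PartitionSpec(("dp", "fsdp")),
)

n = 8 * jax.device_count()  # keep the batch divisible by the mesh size
params = {"w": jnp.ones((4, 1))}
batch = (jnp.ones((n, 4)), jnp.zeros((n, 1)))
with mesh:  # pjit resolves the PartitionSpecs against the active mesh
    out = sharded_log_prob(params, batch)
print(out.shape)  # (n,)

On a single device the mesh degenerates to shape (1, 1) and the call behaves like a plain jit, which is why the sharded path can be used unconditionally whenever shardings are available.
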
""" if not rng: rng = self.rng.get() - def fun(_inputs): - return self._batched_sample( - _inputs, n_target_samples, return_aux, rng, **kwargs - ) - - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - fun = pmap(fun) - if return_aux is None or len(return_aux) == 0: - return jnp.concatenate( - [ - self.likelihood._unshard_array(fun(inputs)) - for inputs in inputs_loader - ], - 1, - ) - else: - samples, aux_outputs = [], [] - for inputs in inputs_loader: - _samples, _aux = fun(inputs) - samples.append(self.likelihood._unshard_array(_samples)) - if "outputs" in _aux: - aux_outputs.append( - self.likelihood._unshard_array(_aux["outputs"]) - ) - samples = jnp.concatenate(samples, axis=0) - aux = dict() - if "outputs" in aux: - aux["outputs"] = jnp.concatenate(aux_outputs, axis=0) - return samples, aux - else: - fun = jit(fun) - if return_aux is None or len(return_aux) == 0: - return jnp.concatenate([fun(inputs) for inputs in inputs_loader], 1) - else: - samples, aux_outputs = [], [] - for inputs in inputs_loader: - _samples, _aux = fun(inputs) - samples.append(_samples) - if "outputs" in _aux: - aux_outputs.append(_aux["outputs"]) - samples = jnp.concatenate(samples, axis=0) - aux = dict() - if "outputs" in aux: - aux["outputs"] = jnp.concatenate(aux_outputs, axis=0) - return samples, aux + return self._loop_fun_through_loader( + self._batched_sample, + inputs_loader, + n_target_samples, + rng, + shard, + is_fun_ensembled=False, + return_aux=return_aux, + **kwargs, + ) def _batched_sample( self, @@ -405,8 +396,9 @@ def _batched_sample( n_target_samples: int = 1, return_aux: Optional[List[str]] = None, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, **kwargs, - ) -> jnp.ndarray: + ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: if return_aux is None: return_aux = [] @@ -414,33 +406,57 @@ def _batched_sample( rng = self.rng.get() keys = random.split(rng, n_target_samples) - def _sample(key): - key1, key2 = random.split(key, 2) - _post_sample = self.posterior.sample(inputs=inputs, rng=key1) - outs = self.likelihood._batched_sample( + def _sample(rng, params, mutable, calib_params, calib_mutable): + return self.likelihood._batched_sample( 1, - _post_sample.params, + params, inputs, - mutable=_post_sample.mutable, - calib_params=_post_sample.calib_params, - calib_mutable=_post_sample.calib_mutable, + mutable=mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, return_aux=return_aux, - rng=key2, + rng=rng, **kwargs, ) - if len(return_aux) > 0: - _samples, aux = outs - return _samples.squeeze(0), aux - return outs.squeeze(0) - return lax.map(_sample, keys) + if shard and self.posterior.partition_manager.shardings is not None: + _sample = pjit( + _sample, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + self.posterior.partition_manager.shardings.calib_params, + self.posterior.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("dp", "fsdp")) + if not len(return_aux) + else (PartitionSpec(("dp", "fsdp")), PartitionSpec()), + ) + + def _fun(key): + key1, key2 = random.split(key, 2) + with self.posterior.partition_manager.partitioner.mesh: + sample = self.posterior.sample(inputs=inputs, rng=key1) + return _sample( + key2, + sample.params, + sample.mutable, + sample.calib_params, + sample.calib_mutable, + ) + + samples = list(map(_fun, keys)) + if len(return_aux): + samples, aux = samples + return jnp.stack(samples), aux + return 
jnp.stack(samples) def sample_calibrated_outputs( self, inputs_loader: InputsLoader, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Sample parameters from the posterior distribution state and compute calibrated outputs. @@ -453,8 +469,8 @@ def sample_calibrated_outputs( Number of output samples to draw for each input. rng: Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -464,12 +480,13 @@ def sample_calibrated_outputs( if rng is None: rng = self.rng.get() - return self._loop_ensemble_fun_through_inputs_loader( + return self._loop_fun_through_loader( self._sample_batched_calibrated_outputs, inputs_loader, n_output_samples, rng, - distribute, + shard, + is_fun_ensembled=True, ) def _sample_batched_calibrated_outputs( @@ -477,39 +494,58 @@ def _sample_batched_calibrated_outputs( inputs: Array, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() keys = random.split(rng, n_output_samples) - def _sample(key): - sample = self.posterior.sample(inputs=inputs, rng=key) + def _apply_fn(params, mutable, calib_params, calib_mutable): return self.likelihood._get_batched_calibrated_outputs( - params=sample.params, + params=params, inputs=inputs, - mutable=sample.mutable, - calib_params=sample.calib_params, - calib_mutable=sample.calib_mutable, + mutable=mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, ) - return lax.map(_sample, keys) + if shard and self.posterior.partition_manager.shardings is not None: + _apply_fn = pjit( + _apply_fn, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + self.posterior.partition_manager.shardings.calib_params, + self.posterior.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("fsdp", "dp")), + ) + else: + _apply_fn = jit(_apply_fn) + + def _sample(key): + sample = self.posterior.sample(inputs=inputs, rng=key) + with self.posterior.partition_manager.partitioner.mesh: + return _apply_fn(sample.params, sample.mutable) + + return jnp.stack(list(map(_sample, keys))) def _sample_outputs( self, inputs_loader: InputsLoader, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() - return self._loop_fun_through_inputs_loader( + return self._loop_fun_through_loader( self._sample_batched_outputs, inputs_loader, n_output_samples, rng, - distribute, + shard, ) def _sample_batched_outputs( @@ -517,18 +553,35 @@ def _sample_batched_outputs( inputs: Array, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() keys = random.split(rng, n_output_samples) - def _sample(key): - sample = self.posterior.sample(inputs=inputs, rng=key) + def _apply_fn(params, mutable): return self.likelihood.model_manager.apply( - params=sample.params, inputs=inputs, mutable=sample.mutable + params=params, inputs=inputs, mutable=mutable ) - return lax.map(_sample, keys) + if shard and getattr(self.posterior.partition_manager, "shardings") is not None: + _apply_fn = 
pjit( + _apply_fn, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + ), + out_shardings=PartitionSpec(("fsdp", "dp")), + ) + else: + _apply_fn = jit(_apply_fn) + + def _sample(key): + sample = self.posterior.sample(inputs=inputs, rng=key) + with self.posterior.partition_manager.partitioner.mesh: + return _apply_fn(sample.params, sample.mutable) + + return jnp.stack(list(map(_sample, keys))) def _sample_outputs_loader( self, @@ -536,23 +589,12 @@ def _sample_outputs_loader( n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, return_size: bool = False, - distribute: bool = True, + shard: bool = True, ) -> Union[TargetsLoader, Tuple[TargetsLoader, int]]: if rng is None: rng = self.rng.get() keys = random.split(rng, n_output_samples) - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - - def _sample(key, _inputs): - sample = self.posterior.sample(inputs=_inputs, rng=key) - return self.likelihood.model_manager.apply( - params=sample.params, inputs=_inputs, mutable=sample.mutable - ) - - _sample = pmap(_sample) if distribute else jit(_sample) - iterable = [] size = 0 for inputs in inputs_loader: @@ -561,13 +603,18 @@ def _sample(key, _inputs): if not isinstance(inputs, dict) else inputs[list(inputs.keys())[0]].shape[0] ) - if distribute: - outputs = jnp.stack( - list(map(lambda key: _sample(shard_prng_key(key), inputs), keys)) + outputs = jnp.stack( + list( + map( + lambda key: self._sample_batched_outputs( + inputs=inputs, + rng=key, + shard=shard, + )[0], + keys, + ) ) - outputs = self._unshard_ensemble_arrays(outputs) - else: - outputs = lax.map(lambda key: _sample(key, inputs), keys) + ) iterable.append(outputs) iterable = TargetsLoader.from_iterable(iterable=iterable) if return_size: @@ -579,7 +626,7 @@ def mean( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive mean of the target variable, that is @@ -601,8 +648,8 @@ def mean( Number of samples to draw from the posterior distribution for each input. rng: Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
Returns ------- @@ -612,8 +659,8 @@ def mean( if rng is None: rng = self.rng.get() - return self._loop_fun_through_inputs_loader( - self._batched_mean, inputs_loader, n_posterior_samples, rng, distribute + return self._loop_fun_through_loader( + self._batched_mean, inputs_loader, n_posterior_samples, rng, shard ) def _batched_mean( @@ -621,25 +668,46 @@ def _batched_mean( inputs: Array, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() keys = random.split(rng, n_posterior_samples) - def fun(i, _curr_sum): - _sample = self.posterior.sample(inputs=inputs, rng=keys[i]) - _curr_sum += self.likelihood._batched_mean( - _sample.params, + def _lik_batched_mean(params, mutable, calib_params, calib_mutable): + return self.likelihood._batched_mean( + params, inputs, - _sample.mutable, - calib_params=_sample.calib_params, - calib_mutable=_sample.calib_mutable, + mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, ) - return _curr_sum - curr_sum = fun(0, 0.0) - curr_sum = lax.fori_loop(1, n_posterior_samples, fun, curr_sum) - return curr_sum / n_posterior_samples + if shard and self.posterior.partition_manager.shardings is not None: + _lik_batched_mean = pjit( + _lik_batched_mean, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + self.posterior.partition_manager.shardings.calib_params, + self.posterior.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("dp", "fsdp")), + ) + else: + _lik_batched_mean = jit(_lik_batched_mean) + + def fun(key): + _sample = self.posterior.sample(inputs=inputs, rng=key) + with self.posterior.partition_manager.partitioner.mesh: + return _lik_batched_mean( + params=_sample.params, + mutable=_sample.mutable, + calib_params=_sample.calib_params, + calib_mutable=_sample.calib_mutable, + ) + + return jnp.mean(list(map(fun, keys)), axis=0) @abc.abstractmethod def mode( @@ -648,7 +716,7 @@ def mode( n_posterior_samples: int = 30, means: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive mode of the target variable, that is @@ -671,8 +739,8 @@ def mode( An estimate of the predictive mean. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -686,7 +754,7 @@ def aleatoric_variance( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive aleatoric variance of the target variable, that is @@ -708,8 +776,8 @@ def aleatoric_variance( Number of samples to draw from the posterior distribution for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
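
`_batched_mean` now averages over explicitly drawn posterior samples rather than a `lax.fori_loop`. A toy version of that Monte-Carlo estimate, 1/S * sum_s E[y | x, params_s]; the linear likelihood and the perturbation-based "posterior" are stand-ins:

import jax.numpy as jnp
from jax import random

def likelihood_mean(params, inputs):
    # stand-in for likelihood._batched_mean: the conditional mean of a linear model
    return inputs @ params

def posterior_sample(key):
    # stand-in for posterior.sample: a perturbed point estimate
    return jnp.ones((3, 1)) + 0.1 * random.normal(key, (3, 1))

inputs = jnp.arange(6.0).reshape(2, 3)
keys = random.split(random.PRNGKey(0), 30)
means = jnp.stack([likelihood_mean(posterior_sample(k), inputs) for k in keys])
print(means.mean(axis=0))  # Monte-Carlo predictive mean, shape (2, 1)
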
Returns ------- @@ -724,7 +792,7 @@ def aleatoric_variance( inputs_loader, n_posterior_samples, rng, - distribute, + shard, ) def _batched_aleatoric_variance( @@ -757,7 +825,7 @@ def epistemic_variance( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive epistemic variance of the one-hot encoded target variable, that is @@ -779,8 +847,8 @@ def epistemic_variance( Number of samples to draw from the posterior distribution for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -790,12 +858,12 @@ def epistemic_variance( if rng is None: rng = self.rng.get() - return self._loop_fun_through_inputs_loader( + return self._loop_fun_through_loader( self._batched_epistemic_variance, inputs_loader, n_posterior_samples, rng, - distribute, + shard, ) def _batched_epistemic_variance( @@ -837,7 +905,7 @@ def variance( aleatoric_variances: Optional[jnp.ndarray] = None, epistemic_variances: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive variance of the target variable, that is @@ -864,8 +932,8 @@ def variance( An estimate of the epistemic predictive variance for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -880,7 +948,7 @@ def variance( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=key, - distribute=distribute, + shard=shard, ) if epistemic_variances is None: rng, key = random.split(rng) @@ -888,7 +956,7 @@ def variance( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=key, - distribute=distribute, + shard=shard, ) return aleatoric_variances + epistemic_variances @@ -898,7 +966,7 @@ def std( n_posterior_samples: int = 30, variances: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive standard deviation of the target variable, that is @@ -921,8 +989,8 @@ def std( An estimate of the predictive variance. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
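
`variance` returns the sum of the two terms of the usual decomposition Var[y|x] = E_s[Var[y|x, params_s]] + Var_s[E[y|x, params_s]]. A toy numerical check; the per-sample means and variances are fabricated:

import jax.numpy as jnp
from jax import random

# fabricated per-posterior-sample conditional means and variances: 30 samples, 8 inputs
per_sample_means = random.normal(random.PRNGKey(0), (30, 8))
per_sample_vars = jnp.full((30, 8), 0.25)

aleatoric = per_sample_vars.mean(axis=0)   # E_s[Var[y | x, params_s]]
epistemic = per_sample_means.var(axis=0)   # Var_s[E[y | x, params_s]]
total = aleatoric + epistemic              # what variance() returns
print(aleatoric[:3], epistemic[:3], total[:3])
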
Returns ------- @@ -934,84 +1002,44 @@ def std( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=rng, - distribute=distribute, + shard=shard, ) return jnp.sqrt(variances) - @staticmethod - def _unshard_ensemble_arrays(arr: Array) -> Array: - arr = arr.swapaxes(1, 2) - arr = arr.reshape((arr.shape[0] * arr.shape[1],) + arr.shape[2:]) - return arr.swapaxes(0, 1) - - def _loop_fun_through_inputs_loader( + def _loop_fun_through_loader( self, fun: Callable, - inputs_loader: InputsLoader, + loader: Union[InputsLoader, DataLoader, TargetsLoader], n_posterior_samples: int, rng: PRNGKeyArray, - distribute: bool = True, + shard: bool, + is_fun_ensembled: bool = False, + return_aux: Optional[List[str]] = None, **kwargs, - ) -> Array: - def fun2(_inputs): - return fun(_inputs, n_posterior_samples, rng, **kwargs) - - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - fun2 = pmap(fun2) - return jnp.concatenate( - [ - self.likelihood._unshard_array(fun2(inputs)) - for inputs in inputs_loader - ], - 0, + ) -> Union[tuple[Any, ...], Array]: + if shard and self.posterior.partition_manager.shardings is not None: + loader = ShardedPrefetchedLoader( + loader=loader, partition_manager=self.posterior.partition_manager ) - fun2 = jit(fun2) - return jnp.concatenate([fun2(inputs) for inputs in inputs_loader], 0) - def _loop_fun_through_data_loader( - self, - fun: Callable, - data_loader: DataLoader, - n_posterior_samples: int, - rng: PRNGKeyArray, - distribute: bool = True, - **kwargs, - ) -> Array: - def fun2(_batch): - return fun(_batch, n_posterior_samples, rng, **kwargs) - - if distribute: - data_loader = DeviceDimensionAugmentedLoader(data_loader) - fun2 = pmap(fun2) - return jnp.concatenate( - [self.likelihood._unshard_array(fun2(batch)) for batch in data_loader], - 0, - ) - fun2 = jit(fun2) - return jnp.concatenate([fun2(batch) for batch in data_loader], 0) + def fun2(_data): + if "return_aux" in inspect.getfullargspec(fun)[0]: + return fun( + _data, + n_posterior_samples, + rng, + shard, + return_aux=return_aux, + **kwargs, + ) + return fun(_data, n_posterior_samples, rng, shard, **kwargs) - def _loop_ensemble_fun_through_inputs_loader( - self, - fun: Callable, - inputs_loader: InputsLoader, - n_posterior_samples: int, - rng: PRNGKeyArray, - distribute: bool = True, - **kwargs, - ) -> Array: - def fun2(_inputs): - return fun(_inputs, n_posterior_samples, rng, **kwargs) - - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - fun2 = pmap(fun2) - return jnp.concatenate( + outs = [fun2(data) for data in loader] + if return_aux is not None: + return tuple( [ - self._unshard_ensemble_arrays(fun2(inputs)) - for inputs in inputs_loader - ], - 1, + tree_map(lambda v: jnp.concatenate(v, int(is_fun_ensembled)), out) + for out in zip(*outs) + ] ) - fun2 = jit(fun2) - return jnp.concatenate([fun2(inputs) for inputs in inputs_loader], 1) + return jnp.concatenate(outs, int(is_fun_ensembled)) diff --git a/fortuna/prob_model/predictive/regression.py b/fortuna/prob_model/predictive/regression.py index 7f9a67e4..96e31ffc 100644 --- a/fortuna/prob_model/predictive/regression.py +++ b/fortuna/prob_model/predictive/regression.py @@ -363,9 +363,6 @@ def quantile( if type(q) == list: q = jnp.array(q) samples = self.sample( - inputs_loader=inputs_loader, - n_target_samples=n_target_samples, - rng=rng, - distribute=distribute, + inputs_loader=inputs_loader, n_target_samples=n_target_samples, rng=rng ) return jnp.quantile(samples, q, axis=0) diff 
--git a/fortuna/prob_model/prob_model_calibrator.py b/fortuna/prob_model/prob_model_calibrator.py index 24cc1692..589840d7 100644 --- a/fortuna/prob_model/prob_model_calibrator.py +++ b/fortuna/prob_model/prob_model_calibrator.py @@ -6,19 +6,12 @@ Union, ) -from flax import jax_utils -import jax from jax._src.prng import PRNGKeyArray import jax.numpy as jnp -from jax.tree_util import tree_map -from fortuna.data import TargetsLoader from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.output_calibrator import ( - JittedMixin, - MultiDeviceMixin, - OutputCalibratorABC, -) +from fortuna.training.output_calibrator.base import OutputCalibratorABC +from fortuna.training.output_calibrator.mixins.sharding import ShardingMixin from fortuna.typing import ( Array, Batch, @@ -60,7 +53,7 @@ def training_loss_step( }, ) - def val_loss_step( + def validation_loss_step( self, state: OutputCalibState, batch: Batch, @@ -84,49 +77,5 @@ def __str__(self): return "calibration" -class ProbModelMultiDeviceMixin(MultiDeviceMixin): - @staticmethod - def _add_device_dim_to_outputs_loader( - outputs_loader: TargetsLoader, - ) -> TargetsLoader: - def _reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[1] % n_devices != 0: - raise ValueError( - f"The size of all output batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch of outputs with shape {batch.shape[1]} was found. " - f"Please set an appropriate batch size." - ) - shape = batch.shape - return ( - batch.swapaxes(0, 1) - .reshape(n_devices, shape[1] // n_devices, shape[0], shape[2]) - .swapaxes(1, 2) - ) - - class TargetsLoaderWrapper: - def __init__(self, outputs_loader: TargetsLoader): - self._outputs_loader = outputs_loader - - def __iter__(self): - outputs_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._outputs_loader - ) - outputs_loader = jax_utils.prefetch_to_device(outputs_loader, 2) - yield from outputs_loader - - return ( - TargetsLoaderWrapper(outputs_loader) - if outputs_loader is not None - else outputs_loader - ) - - -class JittedProbModelOutputCalibrator(JittedMixin, ProbModelOutputCalibrator): - pass - - -class MultiDeviceProbModelOutputCalibrator( - ProbModelMultiDeviceMixin, ProbModelOutputCalibrator -): +class ShardedProbModelOutputCalibrator(ShardingMixin, ProbModelOutputCalibrator): pass diff --git a/fortuna/prob_model/regression.py b/fortuna/prob_model/regression.py index f39cdbc0..58745a1d 100755 --- a/fortuna/prob_model/regression.py +++ b/fortuna/prob_model/regression.py @@ -12,6 +12,8 @@ from fortuna.model_editor.base import ModelEditor from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager from fortuna.output_calibrator.regression import RegressionTemperatureScaler +from fortuna.partitioner.base import Partitioner +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.base import ProbModel from fortuna.prob_model.calib_config.base import CalibConfig from fortuna.prob_model.fit_config.base import FitConfig @@ -39,6 +41,7 @@ def __init__( posterior_approximator: PosteriorApproximator = SWAGPosteriorApproximator(), output_calibrator: Optional[nn.Module] = RegressionTemperatureScaler(), model_editor: Optional[ModelEditor] = None, + partitioner: Partitioner = Partitioner(), seed: int = 0, ): r""" @@ -67,6 +70,8 @@ def __init__( calibration parameters. model_editor : ModelEditor A model_editor objects. 
It takes the forward pass and transforms the outputs. + partitioner : Partitioner + A partitioning object for data, fully sharded data model parallelization. seed: int A random seed. @@ -119,10 +124,14 @@ def __init__( self.model_manager, self.prob_output_layer, self.output_calib_manager ) self.joint = Joint(self.prior, self.likelihood) - + self.partition_manager = PartitionManager(partitioner) self.posterior = getattr( PosteriorApproximations, posterior_approximator.__str__() - ).value(joint=self.joint, posterior_approximator=posterior_approximator) + ).value( + joint=self.joint, + posterior_approximator=posterior_approximator, + partition_manager=self.partition_manager, + ) self.predictive = RegressionPredictive(self.posterior) super().__init__(seed=seed) @@ -141,6 +150,7 @@ def _check_output_dim(self, data_loader: DataLoader): outputs = self.model_manager.apply( params=s.params, inputs=np.zeros((1,) + input_shape), mutable=s.mutable ) + if outputs.shape[1] != 2 * output_dim: raise ValueError( f"""The outputs dimension of both `model` and `likelihood_log_variance_model` must be the same as diff --git a/fortuna/sagemaker/base.py b/fortuna/sagemaker/base.py index a7dfd289..9e88692c 100644 --- a/fortuna/sagemaker/base.py +++ b/fortuna/sagemaker/base.py @@ -98,13 +98,13 @@ def _run_training_job(cfg: DictConfig) -> None: rules=rules, ) - if "tuner" in cfg.sagemaker: - logger.info(f"Starting hyperparams optimization: {cfg.sagemaker.tuner}") + if "tuner" in cfg.hyperparams: + logger.info(f"Starting hyperparams optimization: {cfg.hyperparams.tuner}") estimator = HyperparameterTuner( estimator=estimator, base_tuning_job_name=base_job_name, metric_definitions=metrics, - **instantiate(cfg.sagemaker.tuner, _convert_="partial"), + **instantiate(cfg.hyperparams.tuner, _convert_="partial"), ) estimator.fit(inputs=channels, wait=False) diff --git a/fortuna/training/mixin.py b/fortuna/training/mixin.py deleted file mode 100755 index ce9772d6..00000000 --- a/fortuna/training/mixin.py +++ /dev/null @@ -1,163 +0,0 @@ -import logging -import os -from typing import ( - Dict, - Optional, -) - -from flax.training import checkpoints -from flax.training.early_stopping import EarlyStopping - -from fortuna.training.name_to_train_state import NameToTrainState -from fortuna.training.train_state import TrainState -from fortuna.typing import ( - OptaxOptimizer, - Path, -) - -logger = logging.getLogger(__name__) - - -class WithCheckpointingMixin: - def __init__( - self, - **kwargs, - ): - """ - Mixin class for all trainers that need checkpointing capabilities. This is a wrapper around functions in - `flax.training.checkpoints.*`. 
- """ - super(WithCheckpointingMixin, self).__init__(**kwargs) - - def save_checkpoint( - self, - state: TrainState, - save_checkpoint_dir: Path, - keep: int = 1, - force_save: bool = False, - prefix: str = "checkpoint_", - ) -> None: - if save_checkpoint_dir: - save_ckpt_fn = lambda state: checkpoints.save_checkpoint( - ckpt_dir=str(save_checkpoint_dir), - target=state, - step=state.step, - prefix=prefix, - keep=keep, - overwrite=force_save, - ) - if ( - hasattr(state, "grad_accumulated") - and state.grad_accumulated is not None - ): - # do not save grad accumulated in the ckpt - state = state.replace(grad_accumulated=None) - save_ckpt_fn(state) - else: - save_ckpt_fn(state) - - def restore_checkpoint( - self, - restore_checkpoint_path: Path, - optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - name_to_train_state: NameToTrainState = NameToTrainState, - **kwargs, - ) -> TrainState: - if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile( - restore_checkpoint_path - ): - raise ValueError( - f"`restore_checkpoint_path={restore_checkpoint_path}` was not found." - ) - d = checkpoints.restore_checkpoint( - ckpt_dir=str(restore_checkpoint_path), - target=None, - step=None, - prefix=prefix, - parallel=True, - ) - if d is None: - raise ValueError( - f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." - ) - name = "".join([chr(n) for n in d["encoded_name"].tolist()]) - return name_to_train_state[name].value.init_from_dict(d, optimizer, **kwargs) - - def get_path_latest_checkpoint( - self, checkpoint_dir: Path, prefix: str = "checkpoint_" - ) -> Optional[str]: - return checkpoints.latest_checkpoint(ckpt_dir=checkpoint_dir, prefix=prefix) - - -class WithEarlyStoppingMixin: - def __init__( - self, - *, - early_stopping_monitor: str = "val_loss", - early_stopping_min_delta: float = 0.0, - early_stopping_patience: Optional[int] = 0, - early_stopping_mode: str = "min", - early_stopping_verbose: bool = True, - **kwargs, - ): - super(WithEarlyStoppingMixin, self).__init__(**kwargs) - self.early_stopping_monitor = early_stopping_monitor - self.early_stopping_mode = early_stopping_mode - self.early_stopping_patience = early_stopping_patience - - if early_stopping_patience is None or early_stopping_patience <= 0: - if early_stopping_verbose: - logging.info( - f"Early stopping not enabled. Set `early_stopping_patience>=0` to enable it." - ) - elif self.early_stopping_mode is None or self.early_stopping_mode not in ( - "min", - "max", - ): - if early_stopping_verbose: - logging.warning( - f"`early_stopping_mode={early_stopping_mode}` is not a valid. Early stopping will be disabled." - ) - else: - self._early_stopping = EarlyStopping( - min_delta=early_stopping_min_delta, patience=early_stopping_patience - ) - if early_stopping_verbose: - logging.info( - "If validation data are provided, early stopping will be enabled." 
- ) - - @property - def is_early_stopping_active(self) -> bool: - return not ( - (self.early_stopping_patience is None or self.early_stopping_patience <= 0) - or ( - self.early_stopping_mode is None - or self.early_stopping_mode not in ("min", "max") - ) - ) - - def early_stopping_update( - self, validation_metrics: Dict[str, float] - ) -> Optional[bool]: - improved = None - if self.is_early_stopping_active: - early_stopping_monitor = validation_metrics[self.early_stopping_monitor] - if self.early_stopping_mode == "max": - early_stopping_monitor = -early_stopping_monitor - improved, self._early_stopping = self._early_stopping.update( - early_stopping_monitor - ) - return improved - - -class InputValidatorMixin: - def __init__(self, *args, **kwargs): - if len(args) > 0: - raise AttributeError("Cannot recognize inputs arguments: {}".format(args)) - if len(kwargs) > 0: - raise AttributeError( - "{} are not valid input arguments.".format(list(kwargs.keys())) - ) - super(InputValidatorMixin, self).__init__(*args, **kwargs) diff --git a/fortuna/training/mixins/checkpointing.py b/fortuna/training/mixins/checkpointing.py new file mode 100644 index 00000000..3a443b58 --- /dev/null +++ b/fortuna/training/mixins/checkpointing.py @@ -0,0 +1,164 @@ +import logging +from typing import Optional + +from flax.training.orbax_utils import save_args_from_target +from jax import ( + ShapeDtypeStruct, + local_devices, +) +from jax.sharding import SingleDeviceSharding +from jax.tree_util import ( + tree_map, + tree_map_with_path, +) +from orbax.checkpoint import ( + ArrayRestoreArgs, + CheckpointManager, +) + +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.name_to_train_state import NameToTrainState +from fortuna.training.train_state import TrainState +from fortuna.typing import ( + OptaxOptimizer, + Path, +) +from fortuna.utils.checkpoint import get_checkpoint_manager + +logger = logging.getLogger(__name__) + + +class WithCheckpointingMixin: + def __init__( + self, + *, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, + **kwargs, + ): + """ + Mixin class for all trainers that need checkpointing capabilities. This is a wrapper around functions in + `flax.training.checkpoints.*`. + + Parameters + ---------- + partition_manager: PartitionManager, + An object that manages partitions. 
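
The new mixin drives saving and restoring through an orbax `CheckpointManager` together with `save_args_from_target`, rather than `flax.training.checkpoints`. A standalone sketch of that flow; `get_checkpoint_manager` in this patch presumably wraps a construction like the one below, and the toy state and temporary directory are illustrative only:

import tempfile

import jax.numpy as jnp
import orbax.checkpoint
from flax.training.orbax_utils import save_args_from_target

directory = tempfile.mkdtemp()
manager = orbax.checkpoint.CheckpointManager(
    directory,
    orbax.checkpoint.PyTreeCheckpointer(),
    options=orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True),
)

state = {"step": jnp.array(0), "params": {"w": jnp.ones((2, 2))}}
manager.save(0, state, save_kwargs={"save_args": save_args_from_target(state)})
restored = manager.restore(manager.latest_step(), items=state)
print(restored["params"]["w"])
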
+ checkpoint_manager: CheckpointManager + A checkpoint manager + """ + super(WithCheckpointingMixin, self).__init__(**kwargs) + self.partition_manager = partition_manager + self.checkpoint_manager = checkpoint_manager + + def save_checkpoint( + self, + state: TrainState, + save_checkpoint_dir: Path, + keep: int = 1, + force_save: bool = False, + ) -> None: + checkpoint_manager = ( + get_checkpoint_manager( + checkpoint_dir=save_checkpoint_dir, keep_top_n_checkpoints=keep + ) + if save_checkpoint_dir is not None + else self.checkpoint_manager + ) + if checkpoint_manager is not None: + save_args = save_args_from_target(state) + + def save_ckpt_fn(_state): + return checkpoint_manager.save( + _state.step, + _state, + force=force_save, + save_kwargs={"save_args": save_args}, + ) + + if ( + hasattr(state, "grad_accumulated") + and state.grad_accumulated is not None + ): + # do not save grad accumulated in the ckpt + state = state.replace(grad_accumulated=None) + save_ckpt_fn(state) + + def restore_checkpoint( + self, + restore_checkpoint_dir: Path, + optimizer: Optional[OptaxOptimizer] = None, + name_to_train_state: NameToTrainState = NameToTrainState, + ) -> TrainState: + ref = self._get_ref(lazy=False) + restored = self.checkpoint_manager.restore( + self.checkpoint_manager.latest_step(), + items=ref, + restore_kwargs={"restore_args": ref}, + directory=restore_checkpoint_dir, + ) + if isinstance(restored, dict): + name = "".join([chr(n) for n in restored["encoded_name"]]) + restored = name_to_train_state[name].value.init_from_dict(restored) + + if optimizer is not None: + restored = restored.replace( + tx=optimizer, opt_state=optimizer.init(restored.params) + ) + + return restored + + def get_shapes_dtypes_checkpoint( + self, + restore_checkpoint_dir: Path, + name_to_train_state: NameToTrainState = NameToTrainState, + ): + ref = self._get_ref_without_shardings(lazy=True) + state = self.checkpoint_manager.restore( + self.checkpoint_manager.latest_step(), + items=ref, + restore_kwargs=dict(restore_args=ref), + directory=restore_checkpoint_dir, + ) + name = "".join([chr(n) for n in state["encoded_name"].get().tolist()]) + state = name_to_train_state[name].value.init_from_dict(state) + return tree_map(lambda v: _get_shapes_dtypes(v.get()), state) + + def _get_ref_from_shardings(self): + return tree_map_with_path( + lambda p, sharding, shape_dtype: ArrayRestoreArgsWithShape( + mesh=self.partition_manager.partitioner.mesh, + sharding=sharding, + dtype=shape_dtype.dtype, + shape=shape_dtype.shape, + ), + self.partition_manager.shardings, + self.partition_manager.shapes_dtypes, + ) + + def _get_ref_without_shardings(self, lazy): + return tree_map_with_path( + lambda p, v: ArrayRestoreArgs( + lazy=lazy, sharding=SingleDeviceSharding(device=local_devices()[0]) + ), + self.checkpoint_manager.structure(), + ) + + def _get_ref(self, lazy=False): + if ( + self.partition_manager is not None + and self.partition_manager.shardings is not None + and self.partition_manager.shapes_dtypes is not None + ): + return self._get_ref_from_shardings() + return self._get_ref_without_shardings(lazy=False) + + +class ArrayRestoreArgsWithShape(ArrayRestoreArgs): + def __init__(self, shape, *args, **kwargs): + super().__init__(*args, **kwargs) + self.shape = shape + + +def _get_shapes_dtypes(v): + return ShapeDtypeStruct(shape=v.shape, dtype=v.dtype) diff --git a/fortuna/training/mixins/early_stopping.py b/fortuna/training/mixins/early_stopping.py new file mode 100644 index 00000000..543e9258 --- /dev/null +++ 
b/fortuna/training/mixins/early_stopping.py
@@ -0,0 +1,71 @@
+import logging
+from typing import (
+    Dict,
+    Optional,
+)
+
+from flax.training.early_stopping import EarlyStopping
+
+logger = logging.getLogger(__name__)
+
+
+class WithEarlyStoppingMixin:
+    def __init__(
+        self,
+        *,
+        early_stopping_monitor: str = "val_loss",
+        early_stopping_min_delta: float = 0.0,
+        early_stopping_patience: Optional[int] = 0,
+        early_stopping_mode: str = "min",
+        early_stopping_verbose: bool = True,
+        **kwargs,
+    ):
+        super(WithEarlyStoppingMixin, self).__init__(**kwargs)
+        self.early_stopping_monitor = early_stopping_monitor
+        self.early_stopping_mode = early_stopping_mode
+        self.early_stopping_patience = early_stopping_patience
+
+        if early_stopping_patience is None or early_stopping_patience <= 0:
+            if early_stopping_verbose:
+                logger.info(
+                    "Early stopping not enabled. Set `early_stopping_patience > 0` to enable it."
+                )
+        elif self.early_stopping_mode is None or self.early_stopping_mode not in (
+            "min",
+            "max",
+        ):
+            if early_stopping_verbose:
+                logger.warning(
+                    f"`early_stopping_mode={early_stopping_mode}` is not valid. Early stopping will be disabled."
+                )
+        else:
+            self._early_stopping = EarlyStopping(
+                min_delta=early_stopping_min_delta, patience=early_stopping_patience
+            )
+            if early_stopping_verbose:
+                logger.info(
+                    "If validation data are provided, early stopping will be enabled."
+                )
+
+    @property
+    def is_early_stopping_active(self) -> bool:
+        return not (
+            (self.early_stopping_patience is None or self.early_stopping_patience <= 0)
+            or (
+                self.early_stopping_mode is None
+                or self.early_stopping_mode not in ("min", "max")
+            )
+        )
+
+    def early_stopping_update(
+        self, validation_metrics: Dict[str, float]
+    ) -> Optional[bool]:
+        improved = None
+        if self.is_early_stopping_active:
+            early_stopping_monitor = validation_metrics[self.early_stopping_monitor]
+            if self.early_stopping_mode == "max":
+                early_stopping_monitor = -early_stopping_monitor
+            improved, self._early_stopping = self._early_stopping.update(
+                early_stopping_monitor
+            )
+        return improved
diff --git a/fortuna/training/mixins/input_validator.py b/fortuna/training/mixins/input_validator.py
new file mode 100644
index 00000000..19df7fdd
--- /dev/null
+++ b/fortuna/training/mixins/input_validator.py
@@ -0,0 +1,9 @@
+class InputValidatorMixin:
+    def __init__(self, *args, **kwargs):
+        if len(args) > 0:
+            raise AttributeError("Unrecognized input arguments: {}".format(args))
+        if len(kwargs) > 0:
+            raise AttributeError(
+                "{} are not valid input arguments.".format(list(kwargs.keys()))
+            )
+        super(InputValidatorMixin, self).__init__(*args, **kwargs)
diff --git a/fortuna/training/mixins/jitted.py b/fortuna/training/mixins/jitted.py
new file mode 100644
index 00000000..a7323aeb
--- /dev/null
+++ b/fortuna/training/mixins/jitted.py
@@ -0,0 +1,54 @@
+from functools import partial
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Optional,
+    Tuple,
+)
+
+from flax import jax_utils
+from flax.core import FrozenDict
+import jax
+from jax._src.prng import PRNGKeyArray
+import jax.numpy as jnp
+from optax._src.base import PyTree
+
+from fortuna.training.train_state import TrainState
+from fortuna.typing import (
+    Array,
+    Batch,
+)
+
+
+class JittedMixin:
+    @partial(jax.jit, static_argnums=(0, 3, 5, 6, 7))
+    def training_step(
+        self,
+        state: TrainState,
+        batch: Batch,
+        loss_fun: Callable,
+        rng: PRNGKeyArray,
+        n_data: int,
+        unravel: Optional[Callable[[any], PyTree]] = None,
+        kwargs: FrozenDict[str, Any] =
FrozenDict(), + ) -> Tuple[TrainState, Dict[str, Any]]: + return super().training_step( + state, batch, loss_fun, rng, n_data, unravel, kwargs + ) + + @partial(jax.jit, static_argnums=(0, 3, 5, 6, 7, 8)) + def validation_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Dict[str, jnp.ndarray]: + return super().validation_step( + state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs + ) diff --git a/fortuna/training/mixins/multi_device.py b/fortuna/training/mixins/multi_device.py new file mode 100644 index 00000000..a38b4eef --- /dev/null +++ b/fortuna/training/mixins/multi_device.py @@ -0,0 +1,158 @@ +from functools import partial +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, +) + +from flax import jax_utils +from flax.core import FrozenDict +import jax +from jax import ( + lax, + random, +) +from jax._src.prng import PRNGKeyArray +import jax.numpy as jnp +from jax.tree_util import tree_map +from optax._src.base import PyTree + +from fortuna.data.loader import DataLoader +from fortuna.training.callback import Callback +from fortuna.training.train_state import TrainState +from fortuna.typing import ( + Array, + Batch, +) + + +class MultiDeviceMixin: + all_reduce_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.multi_device = True + + @staticmethod + def _add_device_dim_to_input_data_loader(data_loader: DataLoader) -> DataLoader: + def _reshape_input_batch(batch): + n_devices = jax.local_device_count() + if batch.shape[0] % n_devices != 0: + raise ValueError( + f"The size of all batches must be a multiple of {n_devices}, that is the number of " + f"available devices. Please set an appropriate batch size in the data loader." 
+ ) + single_input_shape = batch.shape[1:] + # reshape to (local_devices, device_batch_size, *single_input_shape) + return batch.reshape((n_devices, -1) + single_input_shape) + + class DataLoaderWrapper: + def __init__(self, data_loader): + self.data_loader = data_loader + + def __iter__(self): + data_loader = map( + lambda batch: tree_map(_reshape_input_batch, batch), + self.data_loader, + ) + data_loader = jax_utils.prefetch_to_device(data_loader, 2) + yield from data_loader + + return ( + DataLoaderWrapper(data_loader) if data_loader is not None else data_loader + ) + + @staticmethod + def _sync_mutable(state: TrainState) -> TrainState: + return ( + state.replace(mutable=MultiDeviceMixin.all_reduce_mean(state.mutable)) + if state.mutable is not None + else state + ) + + @staticmethod + def _sync_array(arr: jnp.ndarray) -> jnp.ndarray: + arr = lax.pmean(arr, axis_name="batch") + return arr + + def _sync_state(self, state: TrainState) -> TrainState: + state = self._sync_mutable(state) + return jax.device_get(tree_map(lambda x: x[0], state)) + + def on_train_start( + self, state: TrainState, data_loaders: List[DataLoader], rng: PRNGKeyArray + ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: + state, data_loaders, rng = super(MultiDeviceMixin, self).on_train_start( + state, data_loaders, rng + ) + state = jax_utils.replicate(state) + data_loaders = [ + self._add_device_dim_to_input_data_loader(dl) for dl in data_loaders + ] + model_key = random.split(rng, jax.local_device_count()) + return state, data_loaders, model_key + + def on_train_end(self, state: TrainState) -> TrainState: + state = super(MultiDeviceMixin, self).on_train_end(state) + return jax.device_get(tree_map(lambda x: x[0], state)) + + def training_step_start(self, rng: PRNGKeyArray, step: int) -> PRNGKeyArray: + step = step if isinstance(step, int) or step.ndim == 0 else step[0] + return jax.vmap(lambda r: random.fold_in(r, step))(rng) + + @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7)) + def training_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[TrainState, Dict[str, Any]]: + return super().training_step( + state, batch, loss_fun, rng, n_data, unravel, kwargs + ) + + def training_step_end( + self, + current_epoch: int, + state: TrainState, + aux: Dict[str, Any], + batch: Batch, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]], + callbacks: Optional[List[Callback]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[TrainState, Dict[str, jnp.ndarray]]: + state, training_losses_and_metrics = super( + MultiDeviceMixin, self + ).training_step_end( + current_epoch, state, aux, batch, metrics, callbacks, kwargs + ) + return state, tree_map(lambda x: x.mean(), training_losses_and_metrics) + + def on_validation_start(self, state: TrainState) -> TrainState: + state = super(MultiDeviceMixin, self).on_validation_start(state) + state = self._sync_mutable(state) + return state + + @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7, 8)) + def validation_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Dict[str, 
jnp.ndarray]: + validation_losses_and_metrics = super().validation_step( + state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs + ) + return lax.pmean(validation_losses_and_metrics, axis_name="batch") diff --git a/fortuna/training/mixins/sharding.py b/fortuna/training/mixins/sharding.py new file mode 100644 index 00000000..45246aaa --- /dev/null +++ b/fortuna/training/mixins/sharding.py @@ -0,0 +1,102 @@ +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, +) + +from flax.core import FrozenDict +from jax import eval_shape +from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import pjit +import jax.numpy as jnp +from jax.sharding import PartitionSpec +from optax._src.base import PyTree + +from fortuna.data.loader import DataLoader +from fortuna.data.loader.base import ShardedPrefetchedLoader +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.train_state import TrainState +from fortuna.typing import ( + Array, + Batch, +) + + +class ShardingMixin: + def __init__(self, *, partition_manager: PartitionManager, **kwargs): + super().__init__(partition_manager=partition_manager, **kwargs) + self.partition_manager = partition_manager + + def training_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[TrainState, Dict[str, Any]]: + fun=super().training_step + with self.partition_manager.partitioner.mesh: + return pjit( + fun, + static_argnums=(2, 4, 5, 6), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + PartitionSpec(), + ), + out_shardings=( + self.partition_manager.shardings, + PartitionSpec(), + ), + )(state, batch, loss_fun, rng, n_data, unravel, kwargs) + + def validation_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Dict[str, jnp.ndarray]: + with self.partition_manager.partitioner.mesh: + return pjit( + super().validation_step, + static_argnums=(2, 4, 5, 6, 7), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + PartitionSpec(), + ), + )(state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs) + + def on_train_start( + self, state: TrainState, data_loaders: List[DataLoader], rng: PRNGKeyArray + ) -> Tuple[TrainState, List[ShardedPrefetchedLoader], PRNGKeyArray]: + state, data_loaders, rng = super(ShardingMixin, self).on_train_start( + state, data_loaders, rng + ) + + if self.freeze_fun is not None: + self.partition_manager = PartitionManager( + partitioner=self.partition_manager.partitioner + ) + self.partition_manager.shapes_dtypes = eval_shape(lambda: state) + + data_loaders = [ + ShardedPrefetchedLoader( + loader=dl, + partition_manager=self.partition_manager, + partition_spec=PartitionSpec(("dp", "fsdp")), + ) + for dl in data_loaders + ] + return state, data_loaders, rng diff --git a/fortuna/training/output_calibrator/__init__.py b/fortuna/training/output_calibrator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/training/output_calibrator.py b/fortuna/training/output_calibrator/base.py similarity index 55% rename from fortuna/training/output_calibrator.py rename to 
fortuna/training/output_calibrator/base.py index ca6e7373..9bee26db 100644 --- a/fortuna/training/output_calibrator.py +++ b/fortuna/training/output_calibrator/base.py @@ -1,6 +1,5 @@ import abc import collections -from functools import partial import logging from typing import ( Any, @@ -12,11 +11,8 @@ Union, ) -from flax import jax_utils from flax.training.common_utils import stack_forest -import jax from jax import ( - lax, random, value_and_grad, ) @@ -31,11 +27,10 @@ TargetsLoader, ) from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.mixin import ( - InputValidatorMixin, - WithCheckpointingMixin, - WithEarlyStoppingMixin, -) +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin +from fortuna.training.mixins.input_validator import InputValidatorMixin from fortuna.typing import ( Array, Batch, @@ -58,6 +53,7 @@ def __init__( self, *args, calib_outputs_loader: TargetsLoader, + partition_manager: PartitionManager, predict_fn: Callable[[jnp.ndarray], jnp.ndarray], uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray], val_outputs_loader: Optional[TargetsLoader] = None, @@ -68,7 +64,9 @@ def __init__( eval_every_n_epochs: int = 1, **kwargs, ): - super(OutputCalibratorABC, self).__init__(*args, **kwargs) + super(OutputCalibratorABC, self).__init__( + *args, partition_manager=partition_manager, **kwargs + ) self._calib_outputs_loader = calib_outputs_loader self._val_outputs_loader = val_outputs_loader self.predict_fn = predict_fn @@ -78,7 +76,6 @@ def __init__( self.keep_top_n_checkpoints = keep_top_n_checkpoints self.disable_training_metrics_computation = disable_training_metrics_computation self.eval_every_n_epochs = eval_every_n_epochs - self.multi_device = False def train( self, @@ -96,7 +93,7 @@ def train( verbose: bool = True, ) -> Tuple[OutputCalibState, Status]: training_losses_and_metrics = collections.defaultdict(list) - val_losses_and_metrics = collections.defaultdict(list) + validation_losses_and_metrics = collections.defaultdict(list) state, data_loaders, outputs_loaders, rng = self.on_train_start( state, @@ -135,11 +132,11 @@ def train( # validation loop if self.should_perform_validation(val_data_loader, epoch): # performance evaluation on the whole validation dataset - state = self.on_val_start(state) + state = self.on_validation_start(state) ( - val_losses_and_metrics_current_epoch, - val_epoch_metrics_str, - ) = self._val_loop( + validation_losses_and_metrics_current_epoch, + validation_epoch_metrics_str, + ) = self._validation_loop( loss_fun=loss_fun, metrics=metrics, rng=rng, @@ -150,11 +147,13 @@ def train( verbose=verbose, ) if verbose: - logging.info(f"Epoch: {epoch + 1} | " + val_epoch_metrics_str) + logging.info( + f"Epoch: {epoch + 1} | " + validation_epoch_metrics_str + ) # keep track of training losses and metrics [granularity=epoch] and check for early stopping - for k in val_losses_and_metrics_current_epoch.keys(): - val_losses_and_metrics[k].append( - val_losses_and_metrics_current_epoch[k] + for k in validation_losses_and_metrics_current_epoch.keys(): + validation_losses_and_metrics[k].append( + validation_losses_and_metrics_current_epoch[k] ) # check for early stopping if self.is_early_stopping_active and self._early_stopping.should_stop: @@ -165,8 +164,10 @@ def train( training_status = { k: jnp.array(v) for k, v in training_losses_and_metrics.items() } - 
val_status = {k: jnp.array(v) for k, v in val_losses_and_metrics.items()} - status = dict(**training_status, **val_status) + validation_status = { + k: jnp.array(v) for k, v in validation_losses_and_metrics.items() + } + status = dict(**training_status, **validation_status) state = self.on_train_end(state) return state, status @@ -302,27 +303,14 @@ def training_step_end( if not self.disable_training_metrics_computation and metrics is not None: preds = self.predict_fn(aux["outputs"]) uncertainties = self.uncertainty_fn(aux["outputs"]) - if self.multi_device: - training_batch_metrics = self.compute_metrics( - preds.reshape((preds.shape[0] * preds.shape[1],) + preds.shape[2:]), - uncertainties.reshape( - (uncertainties.shape[0] * uncertainties.shape[1],) - + uncertainties.shape[2:] - ), - batch[1].reshape( - (batch[1].shape[0] * batch[1].shape[1],) + batch[1].shape[2:] - ), - metrics, - ) - else: - training_batch_metrics = self.compute_metrics( - preds, uncertainties, batch[1], metrics - ) + training_batch_metrics = self.compute_metrics( + preds, uncertainties, batch[1], metrics + ) for k, v in training_batch_metrics.items(): training_losses_and_metrics[k] = v return training_losses_and_metrics - def _val_loop( + def _validation_loop( self, loss_fun: Callable, metrics: Optional[ @@ -335,10 +323,10 @@ def _val_loop( val_dataset_size: int, verbose: bool = True, ) -> Tuple[Dict[str, float], str]: - val_losses_and_metrics_epoch_all_steps = [] - val_epoch_metrics_str = "" + validation_losses_and_metrics_epoch_all_steps = [] + validation_epoch_metrics_str = "" for batch, outputs in zip(val_data_loader, val_outputs_loader): - val_losses_and_metrics_current_batch = self.val_step( + validation_losses_and_metrics_current_batch = self.validation_step( state, batch, outputs, @@ -347,24 +335,24 @@ def _val_loop( val_dataset_size, metrics, ) - val_losses_and_metrics_epoch_all_steps.append( - val_losses_and_metrics_current_batch + validation_losses_and_metrics_epoch_all_steps.append( + validation_losses_and_metrics_current_batch ) # compute validation losses and metrics for the current epoch - val_losses_and_metrics_current_epoch = self.val_epoch_end( - val_losses_and_metrics_epoch_all_steps, state + validation_losses_and_metrics_current_epoch = self.validation_epoch_end( + validation_losses_and_metrics_epoch_all_steps, state ) # logging if verbose: - val_epoch_metrics_str = " | ".join( + validation_epoch_metrics_str = " | ".join( [ f"{m}: {round(float(v), 5)}" - for m, v in val_losses_and_metrics_current_epoch.items() + for m, v in validation_losses_and_metrics_current_epoch.items() ] ) - return val_losses_and_metrics_current_epoch, val_epoch_metrics_str + return validation_losses_and_metrics_current_epoch, validation_epoch_metrics_str - def val_step( + def validation_step( self, state: OutputCalibState, batch: Batch, @@ -376,12 +364,14 @@ def val_step( Tuple[Callable[[jnp.ndarray, jnp.ndarray, Array], Array], ...] 
] = None, ) -> Dict[str, jnp.ndarray]: - val_loss, aux = self.val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - val_metrics = self.val_metrics_step(aux, batch, metrics) - return {"val_loss": val_loss, **val_metrics} + validation_loss, aux = self.validation_loss_step( + state, batch, outputs, loss_fun, rng, n_data + ) + validation_metrics = self.validation_metrics_step(aux, batch, metrics) + return {"validation_loss": validation_loss, **validation_metrics} @abc.abstractmethod - def val_loss_step( + def validation_loss_step( self, state: OutputCalibState, batch: Batch, @@ -392,7 +382,7 @@ def val_loss_step( ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]: pass - def val_metrics_step( + def validation_metrics_step( self, aux: Dict[str, jnp.ndarray], batch: Batch, @@ -401,13 +391,13 @@ def val_metrics_step( ] = None, ) -> Dict[str, jnp.ndarray]: if metrics is not None: - val_metrics = self.compute_metrics( + validation_metrics = self.compute_metrics( self.predict_fn(aux["outputs"]), self.uncertainty_fn(aux["outputs"]), batch[1], metrics, ) - return {f"val_{m}": v for m, v in val_metrics.items()} + return {f"validation_{m}": v for m, v in validation_metrics.items()} else: return {} @@ -418,19 +408,21 @@ def training_epoch_end( training_losses_and_metrics_current_epoch ) - def val_epoch_end( + def validation_epoch_end( self, - val_losses_and_metrics_current_epoch: List[Dict[str, jnp.ndarray]], + validation_losses_and_metrics_current_epoch: List[Dict[str, jnp.ndarray]], state: OutputCalibState, ) -> Dict[str, float]: - val_losses_and_metrics_current_epoch = self._get_mean_losses_and_metrics( - val_losses_and_metrics_current_epoch + validation_losses_and_metrics_current_epoch = self._get_mean_losses_and_metrics( + validation_losses_and_metrics_current_epoch ) # early stopping - improved = self.early_stopping_update(val_losses_and_metrics_current_epoch) - if improved and self.save_checkpoint_dir: + improved = self.early_stopping_update( + validation_losses_and_metrics_current_epoch + ) + if improved and self.save_checkpoint_dir is not None: self.save_checkpoint(state, self.save_checkpoint_dir, force_save=True) - return val_losses_and_metrics_current_epoch + return validation_losses_and_metrics_current_epoch def _get_mean_losses_and_metrics( self, losses_and_metrics: List[Dict[str, jnp.ndarray]] @@ -472,7 +464,7 @@ def on_train_end(self, state: OutputCalibState) -> OutputCalibState: ) return state - def on_val_start(self, state: OutputCalibState) -> OutputCalibState: + def on_validation_start(self, state: OutputCalibState) -> OutputCalibState: return state def compute_metrics( @@ -488,214 +480,3 @@ def compute_metrics( for metric in metrics: metrics_vals[metric.__name__] = metric(preds, uncertainties, targets) return metrics_vals - - -class JittedMixin: - @partial(jax.jit, static_argnums=(0, 4, 6)) - def training_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Tuple[OutputCalibState, Dict[str, Any]]: - return super().training_step(state, batch, outputs, loss_fun, rng, n_data) - - @partial(jax.jit, static_argnums=(0, 4, 6)) - def val_loss_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Dict[str, jnp.ndarray]: - return super().val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - - -class MultiDeviceMixin: - all_reduce_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x") - - def __init__(self, *args, 
**kwargs): - super().__init__(*args, **kwargs) - self.multi_device = True - - @staticmethod - def _add_device_dim_to_data_loader(data_loader: DataLoader) -> DataLoader: - def _reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch with shape {batch.shape[0]} was found. " - f"Please set an appropriate batch size." - ) - return batch.reshape((n_devices, -1) + batch.shape[1:]) - - class DataLoaderWrapper: - def __init__(self, data_loader: DataLoader): - self._data_loader = data_loader - - def __iter__(self): - data_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._data_loader - ) - data_loader = jax_utils.prefetch_to_device(data_loader, 2) - yield from data_loader - - return ( - DataLoaderWrapper(data_loader) if data_loader is not None else data_loader - ) - - @staticmethod - def _add_device_dim_to_outputs_loader( - outputs_loader: TargetsLoader, - ) -> TargetsLoader: - def _reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all output batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch of outputs with shape {batch.shape[0]} was found. " - f"Please set an appropriate batch size." - ) - return batch.reshape((n_devices, -1) + batch.shape[1:]) - - class TargetsLoaderWrapper: - def __init__(self, outputs_loader: TargetsLoader): - self._outputs_loader = outputs_loader - - def __iter__(self): - outputs_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._outputs_loader - ) - outputs_loader = jax_utils.prefetch_to_device(outputs_loader, 2) - yield from outputs_loader - - return ( - TargetsLoaderWrapper(outputs_loader) - if outputs_loader is not None - else outputs_loader - ) - - @staticmethod - def sync_mutable(state: OutputCalibState) -> OutputCalibState: - return ( - state.replace(mutable=MultiDeviceMixin.all_reduce_mean(state.mutable)) - if state.mutable["output_calibrator"] is not None - else state - ) - - @staticmethod - def sync_gradients_and_loss( - grads: jnp.ndarray, loss: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - grad = lax.pmean(grads, axis_name="batch") - loss = lax.pmean(loss, axis_name="batch") - return grad, loss - - def save_checkpoint( - self, - state: OutputCalibState, - save_checkpoint_dir: Path, - keep: int = 1, - force_save: bool = False, - prefix: str = "checkpoint_", - ) -> None: - state = self.sync_mutable(state) - state = jax.device_get(tree_map(lambda x: x[0], state)) - return super(MultiDeviceMixin, self).save_checkpoint( - state, save_checkpoint_dir, keep, force_save, prefix - ) - - def on_train_start( - self, - state: OutputCalibState, - data_loaders: List[DataLoader], - outputs_loaders: List[TargetsLoader], - rng: PRNGKeyArray, - ) -> Tuple[OutputCalibState, List[DataLoader], List[TargetsLoader], PRNGKeyArray]: - state, data_loaders, outputs_loaders, rng = super( - MultiDeviceMixin, self - ).on_train_start(state, data_loaders, outputs_loaders, rng) - state = jax_utils.replicate(state) - data_loaders = [ - self._add_device_dim_to_data_loader(dl) if dl is not None else dl - for dl in data_loaders - ] - outputs_loaders = [ - self._add_device_dim_to_outputs_loader(ol) if ol is not None else ol - for ol in outputs_loaders - ] - model_key = random.split(rng, jax.local_device_count()) - return state, 
data_loaders, outputs_loaders, model_key - - def on_train_end(self, state: OutputCalibState) -> OutputCalibState: - state = super(MultiDeviceMixin, self).on_train_end(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 4, 6)) - def training_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Tuple[OutputCalibState, Dict[str, Any]]: - return super().training_step(state, batch, outputs, loss_fun, rng, n_data) - - def training_step_end( - self, - current_epoch: int, - state: OutputCalibState, - aux: Dict[str, Any], - batch: Batch, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]], - ) -> Dict[str, jnp.ndarray]: - training_losses_and_metrics = super(MultiDeviceMixin, self).training_step_end( - current_epoch, state, aux, batch, metrics - ) - return tree_map(lambda x: x.mean(), training_losses_and_metrics) - - def on_val_start(self, state: OutputCalibState) -> OutputCalibState: - state = super(MultiDeviceMixin, self).on_val_start(state) - if state.mutable["output_calibrator"] is not None: - state = self.sync_mutable(state) - return state - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 4, 6)) - def val_loss_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Dict[str, jnp.ndarray]: - val_losses = super().val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - return lax.pmean(val_losses, axis_name="batch") - - def val_metrics_step( - self, - aux: Dict[str, jnp.ndarray], - batch: Batch, - metrics: Optional[ - Tuple[Callable[[jnp.ndarray, jnp.ndarray, Array], Array], ...] 
- ] = None, - ) -> Dict[str, jnp.ndarray]: - outputs = aux["outputs"] - outputs = outputs.reshape(outputs.shape[0] * outputs.shape[1], -1) - targets = batch[1].reshape(batch[1].shape[0] * batch[1].shape[1], -1) - if metrics is not None: - val_metrics = self.compute_metrics( - self.predict_fn(outputs), self.uncertainty_fn(outputs), targets, metrics - ) - return {f"val_{m}": v for m, v in val_metrics.items()} - else: - return {} diff --git a/fortuna/training/output_calibrator/mixins/sharding.py b/fortuna/training/output_calibrator/mixins/sharding.py new file mode 100644 index 00000000..95e0e231 --- /dev/null +++ b/fortuna/training/output_calibrator/mixins/sharding.py @@ -0,0 +1,104 @@ +from typing import ( + Any, + Callable, + Dict, + List, + Tuple, +) + +from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import pjit +import jax.numpy as jnp +from jax.sharding import PartitionSpec + +from fortuna.data.loader import ( + DataLoader, + TargetsLoader, +) +from fortuna.data.loader.base import ShardedPrefetchedLoader +from fortuna.output_calib_model.state import OutputCalibState +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.typing import ( + Array, + Batch, +) + + +class ShardingMixin: + def __init__(self, *, partition_manager: PartitionManager, **kwargs): + super().__init__(partition_manager=partition_manager, **kwargs) + self.partition_manager = partition_manager + + def training_step( + self, + state: OutputCalibState, + batch: Batch, + outputs: Array, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + ) -> Tuple[OutputCalibState, Dict[str, Any]]: + with self.partition_manager.partitioner.mesh: + return pjit( + super().training_step, + static_argnums=(3, 5), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + outputs.sharding, + PartitionSpec(), + ), + )(state, batch, outputs, loss_fun, rng, n_data) + + def validation_loss_step( + self, + state: OutputCalibState, + batch: Batch, + outputs: Array, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + ) -> Dict[str, jnp.ndarray]: + with self.partition_manager.partitioner.mesh: + fun = super().validation_loss_step + return pjit( + fun, + static_argnums=(3, 5), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + outputs.sharding, + PartitionSpec(), + ), + )(state, batch, outputs, loss_fun, rng, n_data) + + def on_train_start( + self, + state: OutputCalibState, + data_loaders: List[DataLoader], + outputs_loaders: List[TargetsLoader], + rng: PRNGKeyArray, + ) -> Tuple[ + OutputCalibState, + List[ShardedPrefetchedLoader], + List[ShardedPrefetchedLoader], + PRNGKeyArray, + ]: + state, data_loaders, output_loaders, rng = super( + ShardingMixin, self + ).on_train_start(state, data_loaders, outputs_loaders, rng) + data_loaders = [ + ShardedPrefetchedLoader( + loader=data_loader, + partition_manager=self.partition_manager, + partition_spec=PartitionSpec(("dp", "fsdp")), + ) + for data_loader in data_loaders + ] + outputs_loaders = [ + ShardedPrefetchedLoader( + loader=output_loader, partition_manager=self.partition_manager + ) + for output_loader in output_loaders + ] + return state, data_loaders, outputs_loaders, rng diff --git a/fortuna/training/train_state_repository.py b/fortuna/training/train_state_repository.py index 25cd7fdd..1d2ceee8 100644 --- a/fortuna/training/train_state_repository.py +++ b/fortuna/training/train_state_repository.py @@ -1,5 +1,4 @@ -from copy import deepcopy -import os +from shutil 
import rmtree from typing import ( Dict, List, @@ -7,7 +6,11 @@ Union, ) -from fortuna.training.mixin import WithCheckpointingMixin +from jax import eval_shape +from orbax.checkpoint import CheckpointManager + +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.training.train_state import TrainState from fortuna.typing import ( OptaxOptimizer, @@ -16,95 +19,98 @@ class TrainStateRepository(WithCheckpointingMixin): - def __init__(self, checkpoint_dir: Optional[Path] = None): - super().__init__() - self.checkpoint_dir = checkpoint_dir - self.__state = None + def __init__( + self, + partition_manager: Optional[PartitionManager] = None, + checkpoint_manager: Optional[CheckpointManager] = None, + ): + super().__init__(partition_manager=partition_manager) + self.checkpoint_manager = checkpoint_manager + self._state = None def get( self, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, + _do_reshard: bool = True ) -> Union[Dict, TrainState]: - if not checkpoint_path and not self.checkpoint_dir and not self.__state: + if not checkpoint_dir and not self.checkpoint_manager and not self._state: raise ValueError("No state available.") - if checkpoint_path or self.checkpoint_dir: + if checkpoint_dir or self.checkpoint_manager: return self.restore_checkpoint( - restore_checkpoint_path=checkpoint_path or self.checkpoint_dir, - optimizer=optimizer, - prefix=prefix, - **kwargs, + restore_checkpoint_dir=checkpoint_dir, optimizer=optimizer ) if optimizer is not None: - self.__state = self.__state.replace( - tx=optimizer, opt_state=optimizer.init(self.__state.params) - ) - return deepcopy(self.__state) + if self.partition_manager is not None and _do_reshard: + state = self.partition_manager.reshard(self._state) + return state.replace(tx=optimizer, opt_state=optimizer.init(state.params)) + else: + self._state = self._state.replace(tx=optimizer, opt_state=optimizer.init(self._state.params)) + return self._state def put( self, state: TrainState, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, keep: int = 1, - prefix: str = "checkpoint_", ) -> None: - if checkpoint_path or self.checkpoint_dir: + if checkpoint_dir or self.checkpoint_manager: self.save_checkpoint( state=state, - save_checkpoint_dir=checkpoint_path or self.checkpoint_dir, + save_checkpoint_dir=checkpoint_dir, keep=keep, force_save=True, - prefix=prefix, ) else: - self.__state = state + self._state = state + + def remove( + self, + checkpoint_dir: Path = None, + ): + if checkpoint_dir or self.checkpoint_manager: + if checkpoint_dir is None: + step = self.checkpoint_manager.latest_step() + if step is not None: + self.checkpoint_manager.delete(step) + else: + rmtree(checkpoint_dir) def pull( self, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, ) -> TrainState: state = self.get( - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, - prefix=prefix, - **kwargs, ) - if checkpoint_path or self.checkpoint_dir: - os.remove( - checkpoint_path - or self.get_path_latest_checkpoint(self.checkpoint_dir, prefix=prefix) - ) + self.remove(checkpoint_dir) return state + def replace( + self, + state: TrainState, + checkpoint_dir: Optional[Path] = None, + keep: int = 1, + ): 
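+        # Drop the latest stored checkpoint (or the given checkpoint directory), if any,
+        # then save the given state in its place.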
+ self.remove(checkpoint_dir) + self.put(state, checkpoint_dir, keep=keep) + def update( self, variables: Dict, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, keep: int = 1, - prefix: str = "checkpoint_", - **kwargs, ): state = self.pull( - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, - prefix=prefix, - **kwargs, ) state = state.replace(**variables) - self.put(state, checkpoint_path=checkpoint_path, keep=keep, prefix=prefix) + self.put(state, checkpoint_dir=checkpoint_dir, keep=keep) - def extract( - self, - keys: List[str], - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", - **kwargs, - ) -> Dict: - state = self.get(checkpoint_path=checkpoint_path, prefix=prefix, **kwargs) + def extract(self, keys: List[str], checkpoint_dir: Optional[Path] = None) -> Dict: + state = self.get(checkpoint_dir=checkpoint_dir) return {k: getattr(state, k) for k in keys} diff --git a/fortuna/training/trainer.py b/fortuna/training/trainer.py index 5959bc2f..b4bfe21f 100755 --- a/fortuna/training/trainer.py +++ b/fortuna/training/trainer.py @@ -2,6 +2,7 @@ import collections from functools import partial import logging +import pathlib from typing import ( Any, Callable, @@ -12,12 +13,10 @@ Union, ) -from flax import jax_utils from flax.core import FrozenDict from flax.training.common_utils import stack_forest import jax from jax import ( - lax, random, value_and_grad, ) @@ -25,16 +24,16 @@ import jax.numpy as jnp from jax.tree_util import tree_map from optax._src.base import PyTree +from orbax.checkpoint import CheckpointManager from tqdm import trange from tqdm.std import tqdm as TqdmDecorator from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.training.callback import Callback -from fortuna.training.mixin import ( - InputValidatorMixin, - WithCheckpointingMixin, - WithEarlyStoppingMixin, -) +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin +from fortuna.training.mixins.input_validator import InputValidatorMixin from fortuna.training.train_state import TrainState from fortuna.typing import ( AnyKey, @@ -71,6 +70,8 @@ def __init__( self, *args, predict_fn: Callable[[jnp.ndarray], jnp.ndarray], + partition_manager: Optional[PartitionManager], + checkpoint_manager: Optional[CheckpointManager], uncertainty_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, save_checkpoint_dir: Optional[Path] = None, save_every_n_steps: Optional[int] = None, @@ -80,7 +81,12 @@ def __init__( freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None, **kwargs, ): - super(TrainerABC, self).__init__(*args, **kwargs) + super(TrainerABC, self).__init__( + *args, + partition_manager=partition_manager, + checkpoint_manager=checkpoint_manager, + **kwargs, + ) self.predict_fn = predict_fn self.uncertainty_fn = uncertainty_fn self.save_checkpoint_dir = save_checkpoint_dir @@ -340,6 +346,7 @@ def validation_epoch_end( self, validation_losses_and_metrics_current_epoch: List[Dict[str, jnp.ndarray]], state: TrainState, + mark_checkpoint_as_best: bool = True ) -> Dict[str, float]: validation_losses_and_metrics_current_epoch = self._get_mean_losses_and_metrics( validation_losses_and_metrics_current_epoch @@ -348,8 +355,15 @@ def validation_epoch_end( improved = self.early_stopping_update( validation_losses_and_metrics_current_epoch ) 
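The hunk below, together with the `on_train_end` change further down, introduces a split checkpoint layout: when `save_checkpoint_dir` is set, the checkpoint of the best validation epoch is written under a `best` subdirectory and the final state under a `last` subdirectory. A minimal sketch of the resulting paths, assuming a hypothetical base directory that is not part of the patch:

import pathlib

save_checkpoint_dir = "/tmp/fortuna_ckpts"  # hypothetical base directory
# written by validation_epoch_end whenever the monitored validation metric improves
best_dir = str(pathlib.Path(save_checkpoint_dir) / "best")
# written by on_train_end once training completes
last_dir = str(pathlib.Path(save_checkpoint_dir) / "last")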
-        if improved and self.save_checkpoint_dir:
-            self.save_checkpoint(state, self.save_checkpoint_dir, force_save=True)
+        if improved and self.save_checkpoint_dir is not None:
+            path = self.save_checkpoint_dir
+            if mark_checkpoint_as_best:
+                path = str(pathlib.Path(path) / "best")
+            self.save_checkpoint(
+                state,
+                path,
+                force_save=True,
+            )
         return validation_losses_and_metrics_current_epoch

     def train(
@@ -357,11 +371,11 @@
         rng: PRNGKeyArray,
         state: TrainState,
         loss_fun: Callable,
-        training_dataloader: DataLoader,
+        training_data_loader: DataLoader,
         training_dataset_size: int,
         n_epochs: int = 1,
         metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None,
-        validation_dataloader: Optional[DataLoader] = None,
+        validation_data_loader: Optional[DataLoader] = None,
         validation_dataset_size: Optional[int] = None,
         verbose: bool = True,
         unravel: Optional[Callable[[any], PyTree]] = None,
@@ -369,18 +383,18 @@
         **kwargs,
     ) -> Tuple[TrainState, Status]:
         training_kwargs = FrozenDict(kwargs)
-        if validation_dataloader:
+        if validation_data_loader:
             assert (
                 validation_dataset_size is not None
-            ), "`validation_dataset_size` is required when `validation_dataloader` is provided."
+            ), "`validation_dataset_size` is required when `validation_data_loader` is provided."

         training_losses_and_metrics = collections.defaultdict(list)
         validation_losses_and_metrics = collections.defaultdict(list)

-        state, dataloaders, rng = self.on_train_start(
-            state, [training_dataloader, validation_dataloader], rng
+        state, data_loaders, rng = self.on_train_start(
+            state, [training_data_loader, validation_data_loader], rng
         )
-        training_dataloader, validation_dataloader = dataloaders
+        training_data_loader, validation_data_loader = data_loaders

         progress_bar = trange(n_epochs, desc="Epoch")
         for epoch in progress_bar:
@@ -395,7 +409,7 @@
                 metrics,
                 rng,
                 state,
-                training_dataloader,
+                training_data_loader,
                 training_dataset_size,
                 training_kwargs,
                 verbose,
@@ -410,7 +424,7 @@
             )

             # validation loop
-            if self.should_perform_validation(validation_dataloader, epoch):
+            if self.should_perform_validation(validation_data_loader, epoch):
                 # performance evaluation on the whole validation dataset
                 state = self.on_validation_start(state)
                 (
@@ -422,7 +436,7 @@
                     rng=rng,
                     state=state,
                     training_kwargs=training_kwargs,
-                    validation_dataloader=validation_dataloader,
+                    validation_data_loader=validation_data_loader,
                     validation_dataset_size=validation_dataset_size,
                     verbose=verbose,
                     unravel=unravel,
@@ -460,7 +474,7 @@
         metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], jnp.ndarray], ...]],
         rng: PRNGKeyArray,
         state: TrainState,
-        training_dataloader: DataLoader,
+        training_data_loader: DataLoader,
         training_dataset_size: int,
         training_kwargs: FrozenDict[str, Any],
         verbose: bool,
@@ -476,7 +490,7 @@
         state = self.training_epoch_start(state, callbacks)
         # ensure to use a different key at each step
         model_key = self.training_step_start(rng, state.step)
-        for step, batch in enumerate(training_dataloader):
+        for step, batch in enumerate(training_data_loader):
             # forward and backward pass
             state, aux = self.training_step(
                 state,
@@ -539,14 +553,14 @@
         rng: PRNGKeyArray,
         state: TrainState,
         training_kwargs: FrozenDict[str, Any],
-        validation_dataloader: DataLoader,
+        validation_data_loader: DataLoader,
         validation_dataset_size: int,
         verbose: bool = True,
         unravel: Optional[Callable[[any], PyTree]] = None,
     ) -> Tuple[Dict[str, float], str]:
validation_losses_and_metrics_epoch_all_steps = [] validation_epoch_metrics_str = "" - for batch in validation_dataloader: + for batch in validation_data_loader: validation_losses_and_metrics_current_batch = self.validation_step( state, batch, @@ -590,10 +604,10 @@ def _get_mean_losses_and_metrics( return losses_and_metrics def should_perform_validation( - self, validation_dataloader: Optional[DataLoader], epoch: int + self, validation_data_loader: Optional[DataLoader], epoch: int ) -> bool: return ( - validation_dataloader is not None + validation_data_loader is not None and self.eval_every_n_epochs > 0 and epoch % self.eval_every_n_epochs == 0 ) @@ -605,7 +619,7 @@ def _sync_array(arr: jnp.ndarray) -> jnp.ndarray: def on_train_start( self, state: TrainState, - dataloaders: List[DataLoader], + data_loaders: List[DataLoader], rng: PRNGKeyArray, ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: if self.freeze_fun is not None: @@ -639,12 +653,18 @@ def on_train_start( ) ), ) - return state, dataloaders, rng + return state, data_loaders, rng - def on_train_end(self, state: TrainState) -> TrainState: + def on_train_end(self, state: TrainState, mark_checkpoint_as_last: bool = True) -> TrainState: + if self.save_checkpoint_dir is not None: + path = pathlib.Path(self.save_checkpoint_dir) + if mark_checkpoint_as_last: + path = str(path / "last") + else: + path = None self.save_checkpoint( state, - save_checkpoint_dir=self.save_checkpoint_dir, + save_checkpoint_dir=path, keep=self.keep_top_n_checkpoints, force_save=True, ) @@ -677,7 +697,7 @@ def compute_metrics( def training_step_start( self, rng: PRNGKeyArray, step: Union[int, jax.Array] ) -> PRNGKeyArray: - step = step if isinstance(step, int) or step.ndim == 0 else step[0] + # step = step if isinstance(step, int) or step.ndim == 0 else step[0] return random.fold_in(rng, step) def _sync_state(self, state: TrainState) -> TrainState: @@ -689,15 +709,16 @@ def save_checkpoint( save_checkpoint_dir: Path, keep: int = 1, force_save: bool = False, - prefix: str = "checkpoint_", ) -> None: if self.freeze_fun is not None: - state = state.replace( - params=self._get_all_params(state), frozen_params=None + return super().save_checkpoint( + self._sync_state( + state.replace( + params=self._get_all_params(state), frozen_params=None + ) + ), save_checkpoint_dir, keep, force_save ) - return super().save_checkpoint( - self._sync_state(state), save_checkpoint_dir, keep, force_save, prefix - ) + return super().save_checkpoint(state, save_checkpoint_dir, keep, force_save) def _get_all_params( self, state: TrainState, trainable_params: Optional[Params] = None @@ -712,162 +733,3 @@ def _get_all_params( ) ) return trainable_params if trainable_params is not None else state.params - - -class JittedMixin: - @partial(jax.jit, static_argnums=(0, 3, 5, 6, 7)) - def training_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Tuple[TrainState, Dict[str, Any]]: - return super().training_step( - state, batch, loss_fun, rng, n_data, unravel, kwargs - ) - - @partial(jax.jit, static_argnums=(0, 3, 5, 6, 7, 8)) - def validation_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - 
) -> Dict[str, jnp.ndarray]: - return super().validation_step( - state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs - ) - - -class MultiDeviceMixin: - all_reduce_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x") - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.multi_device = True - - @staticmethod - def _add_device_dim_to_input_dataloader(dataloader: DataLoader) -> DataLoader: - def _reshape_input_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all batches must be a multiple of {n_devices}, that is the number of " - f"available devices. Please set an appropriate batch size in the data loader." - ) - single_input_shape = batch.shape[1:] - # reshape to (local_devices, device_batch_size, *single_input_shape) - return batch.reshape((n_devices, -1) + single_input_shape) - - class DataLoaderWrapper: - def __init__(self, dataloader): - self.dataloader = dataloader - - def __iter__(self): - dataloader = map( - lambda batch: tree_map(_reshape_input_batch, batch), self.dataloader - ) - dataloader = jax_utils.prefetch_to_device(dataloader, 2) - yield from dataloader - - return DataLoaderWrapper(dataloader) if dataloader is not None else dataloader - - @staticmethod - def _sync_mutable(state: TrainState) -> TrainState: - return ( - state.replace(mutable=MultiDeviceMixin.all_reduce_mean(state.mutable)) - if state.mutable is not None - else state - ) - - @staticmethod - def _sync_array(arr: jnp.ndarray) -> jnp.ndarray: - arr = lax.pmean(arr, axis_name="batch") - return arr - - def _sync_state(self, state: TrainState) -> TrainState: - state = self._sync_mutable(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - def on_train_start( - self, state: TrainState, dataloaders: List[DataLoader], rng: PRNGKeyArray - ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: - state, dataloaders, rng = super(MultiDeviceMixin, self).on_train_start( - state, dataloaders, rng - ) - state = jax_utils.replicate(state) - dataloaders = [ - self._add_device_dim_to_input_dataloader(dl) for dl in dataloaders - ] - model_key = random.split(rng, jax.local_device_count()) - return state, dataloaders, model_key - - def on_train_end(self, state: TrainState) -> TrainState: - state = super(MultiDeviceMixin, self).on_train_end(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - def training_step_start(self, rng: PRNGKeyArray, step: int) -> PRNGKeyArray: - step = step if isinstance(step, int) or step.ndim == 0 else step[0] - return jax.vmap(lambda r: random.fold_in(r, step))(rng) - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7)) - def training_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Tuple[TrainState, Dict[str, Any]]: - return super().training_step( - state, batch, loss_fun, rng, n_data, unravel, kwargs - ) - - def training_step_end( - self, - current_epoch: int, - state: TrainState, - aux: Dict[str, Any], - batch: Batch, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]], - callbacks: Optional[List[Callback]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Tuple[TrainState, Dict[str, jnp.ndarray]]: - state, training_losses_and_metrics = super( - MultiDeviceMixin, self - ).training_step_end( - current_epoch, state, aux, batch, metrics, 
callbacks, kwargs - ) - return state, tree_map(lambda x: x.mean(), training_losses_and_metrics) - - def on_validation_start(self, state: TrainState) -> TrainState: - state = super(MultiDeviceMixin, self).on_validation_start(state) - state = self._sync_mutable(state) - return state - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7, 8)) - def validation_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Dict[str, jnp.ndarray]: - validation_losses_and_metrics = super().validation_step( - state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs - ) - return lax.pmean(validation_losses_and_metrics, axis_name="batch") diff --git a/fortuna/typing.py b/fortuna/typing.py index 62dd1f5c..e8f401b1 100755 --- a/fortuna/typing.py +++ b/fortuna/typing.py @@ -30,3 +30,4 @@ Predictions = jnp.ndarray AnyKey = Union[str, int] Shape = Union[Iterable[int], Dict[str, Iterable[int]]] +AxisDims = Dict diff --git a/fortuna/utils/checkpoint.py b/fortuna/utils/checkpoint.py new file mode 100644 index 00000000..ebedc6e2 --- /dev/null +++ b/fortuna/utils/checkpoint.py @@ -0,0 +1,22 @@ +from typing import Optional +from fortuna.typing import Path + +from orbax.checkpoint import ( + Checkpointer, + CheckpointManager, + CheckpointManagerOptions, + PyTreeCheckpointHandler, +) + + +def get_checkpoint_manager( + checkpoint_dir: Path, keep_top_n_checkpoints: Optional[int] = None +): + if checkpoint_dir is not None: + options = CheckpointManagerOptions( + create=True, max_to_keep=keep_top_n_checkpoints + ) + return CheckpointManager( + checkpoint_dir, Checkpointer(PyTreeCheckpointHandler()), options + ) + return None diff --git a/fortuna/utils/mesh.py b/fortuna/utils/mesh.py new file mode 100644 index 00000000..a0bb456e --- /dev/null +++ b/fortuna/utils/mesh.py @@ -0,0 +1,89 @@ +from typing import Dict + +from jax import local_device_count +from jax.experimental.mesh_utils import create_device_mesh +from jax.interpreters import pxla +from jax.lax import with_sharding_constraint +from jax.sharding import ( + Mesh, + PartitionSpec, +) +import numpy as np + +from fortuna.utils.partition.base import get_names_from_partition_spec + + +def get_mesh(axis_dims: Dict[str, int]): + keys = tuple(axis_dims.keys()) + dims = tuple(axis_dims.values()) + + allowed_keys = ("dp", "fsdp", "mp") + if set(keys) != set(allowed_keys): + raise ValueError( + f"`axis_dims` must contain exactly the following keys: {allowed_keys}." + ) + for v in dims: + if type(v) != int: + raise ValueError("All values in `axes_dims` must be integers or `-1`.") + if len(np.where(np.array(dims) == -1)[0]) > 1: + raise ValueError("At most one axis dimension can be `-1`.") + + n_devices = local_device_count() + + fixed_prod = np.prod([v for v in dims if v != -1]) + reminder = n_devices % fixed_prod + if fixed_prod > n_devices: + raise ValueError( + f"The product of the specified axes dimensions cannot be greater than {n_devices}, " + f"the number of available devices." + ) + if reminder != 0: + raise ValueError( + "The product of the axis dimensions must divide the number of available devices. " + f"However, {n_devices} were found, and {fixed_prod} to be the product of the specified axis " + f"dimensions." 
+ ) + + dims = tuple([dims[np.where(np.array(keys) == k)[0][0]] for k in allowed_keys]) + mesh_shape = np.arange(n_devices).reshape(dims).shape + physical_mesh = create_device_mesh(mesh_shape) + return Mesh(physical_mesh, allowed_keys) + + +def names_in_current_mesh(*names) -> bool: + """ + Check if the axis names in the current mesh contain the names provided. + + Parameters + ---------- + names: List[str] + Provided names. + + Returns + ------- + bool: + Whether the list of provided names is contained in the list of axis names from the current mesh. + """ + mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names + return set(names) <= set(mesh_axis_names) + + +def with_conditional_sharding_constraint(x, partition_specs): + """ + + Parameters + ---------- + x + partition_specs + + Returns + ------- + + """ + """ A smarter version of with_sharding_constraint that only applies the + constraint if the current mesh contains the axes in the partition specs. + """ + axis_names = get_names_from_partition_spec(partition_specs) + if names_in_current_mesh(*axis_names): + x = with_sharding_constraint(x, partition_specs) + return x diff --git a/fortuna/utils/nested_dicts.py b/fortuna/utils/nested_dicts.py index f778036a..754255ce 100644 --- a/fortuna/utils/nested_dicts.py +++ b/fortuna/utils/nested_dicts.py @@ -8,6 +8,13 @@ ) from flax.core import FrozenDict +from jax.tree_util import ( + DictKey, + FlattenedIndexKey, + GetAttrKey, + SequenceKey, + tree_map_with_path, +) from fortuna.typing import AnyKey @@ -213,3 +220,39 @@ def nested_update( else: updated_mapping[k] = v return updated_mapping + + +def path_to_string( + path: Tuple[Union[DictKey, SequenceKey, GetAttrKey, FlattenedIndexKey, AnyKey]], + separator: str = None, +) -> Union[str, Tuple[str]]: + """ + Transform a sequence of keys into a string. + + Parameters + ---------- + path: Tuple[Union[DictKey, SequenceKey, GetAttrKey, FlattenedIndexKey, AnyKey]] + A sequence of keys. + separator: str + A string to interpret as separator. + + Returns + ------- + Union[str, Tuple[str]] + A sequence of keys. + """ + keys = [] + for key in path: + if isinstance(key, SequenceKey): + keys.append(str(key.idx)) + elif isinstance(key, DictKey): + keys.append(str(key.key)) + elif isinstance(key, GetAttrKey): + keys.append(str(key.name)) + elif isinstance(key, FlattenedIndexKey): + keys.append(str(key.key)) + else: + keys.append(str(key)) + if separator is None: + return tuple(keys) + return separator.join(keys) diff --git a/fortuna/utils/partition/__init__.py b/fortuna/utils/partition/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/utils/partition/base.py b/fortuna/utils/partition/base.py new file mode 100644 index 00000000..a6da693e --- /dev/null +++ b/fortuna/utils/partition/base.py @@ -0,0 +1,69 @@ +import re +from typing import ( + Dict, + Tuple, +) + +from jax.sharding import PartitionSpec +from jax.tree_util import tree_map_with_path +import numpy as np +from optax._src.base import PyTree + +from fortuna.utils.nested_dicts import path_to_string + + +def named_tree_map(f, tree, *rest, is_leaf=None, separator=None): + return tree_map_with_path( + lambda string_path, x, *r: f( + path_to_string(string_path, separator=separator), x, *r + ), + tree, + *rest, + is_leaf=is_leaf, + ) + + +def match_partition_specs( + partition_specs: Dict[str, PartitionSpec], tree: PyTree +) -> PyTree: + """ + Match partition specifics to a tree structure. 
+ + Parameters + ---------- + partition_specs: Dict[str, PartitionSpec] + A mapping from regular-expression rules over `/`-joined parameter paths to partition specs. + tree: PyTree + A tree whose leaves expose a `shape` attribute. + + Returns + ------- + PyTree + A tree of partition specs. + """ + + def get_partition_spec(path, shape_leaf): + if len(shape_leaf.shape) == 0 or np.prod(shape_leaf.shape) == 1: + # do not partition scalar values + return PartitionSpec() + for rule, ps in partition_specs.items(): + if re.search(rule, path) is not None: + return ps + # raise ValueError(f"A partition rule for the following path was not found: `{path}`") + return PartitionSpec() + + return named_tree_map(get_partition_spec, tree, separator="/") + + +def get_names_from_partition_spec(partition_specs): + """Return axis names from partition specs.""" + names = set() + if isinstance(partition_specs, dict): + partition_specs = partition_specs.values() + for item in partition_specs: + if item is None: + continue + elif isinstance(item, str): + names.add(item) + else: + names.update(get_names_from_partition_spec(item)) + + return list(names) diff --git a/fortuna/utils/partition/default.py b/fortuna/utils/partition/default.py new file mode 100644 index 00000000..9787db31 --- /dev/null +++ b/fortuna/utils/partition/default.py @@ -0,0 +1,40 @@ +from fortuna.partitioner.base import Partitioner +from fortuna.typing import AxisDims + + +def get_default_partitioner(model_name_or_path: str, axes_dims: AxisDims): + names = ["EleutherAI/gpt-j-6b", "roberta-base"] + + if model_name_or_path.lower() == "eleutherai/gpt-j-6b": + return Partitioner( + axes_dims=axes_dims, + rules={ + 'transformer/wte/embedding': ('mp', 'fsdp'), + 'attn/(k_proj|q_proj|v_proj)/kernel': ('fsdp', 'mp'), + 'attn/out_proj/kernel': ('mp', 'fsdp'), + 'mlp/fc_in/kernel': ('fsdp', 'mp'), + 'mlp/fc_in/bias': ('mp',), + 'mlp/fc_out/kernel': ('mp', 'fsdp'), + 'lm_head/kernel': ('fsdp', 'mp'), + 'lm_head/bias': ('mp',), + } + ) + + if model_name_or_path == "roberta-base": + return Partitioner( + axes_dims=axes_dims, + rules={ + 'attention/self/(key|query|value)/kernel': ('fsdp', 'mp'), + 'attention/output/dense/kernel': ('mp', 'fsdp'), + 'intermediate/dense/kernel': ('fsdp', 'mp'), + 'intermediate/dense/bias': ('mp',), + 'output/dense/kernel': ('mp', 'fsdp'), + 'lm_head/decoder/kernel': ('fsdp', 'mp'), + 'lm_head/decoder/bias': ('mp',), + 'lm_head/dense/kernel': ('fsdp', 'mp'), + 'lm_head/dense/bias': ('mp',), + } + ) + + raise ValueError("`model_name_or_path` not recognized. "
+ f"Please choose one among the following options: {names}.") diff --git a/fortuna/utils/port.py b/fortuna/utils/port.py new file mode 100644 index 00000000..1bd99bef --- /dev/null +++ b/fortuna/utils/port.py @@ -0,0 +1,6 @@ +import socket + + +def is_port_in_use(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 diff --git a/fortuna/utils/probit.py b/fortuna/utils/probit.py index 72b58de8..4e6aaaad 100644 --- a/fortuna/utils/probit.py +++ b/fortuna/utils/probit.py @@ -1,5 +1,4 @@ from typing import ( - Any, Callable, Dict, Optional, @@ -7,7 +6,21 @@ Union, ) +from flax.core import FrozenDict +from jax import ( + ShapeDtypeStruct, + jit, + jvp, + lax, + pure_callback, + vjp, + vmap, +) import jax.numpy as jnp +from jax.tree_util import ( + tree_map, + tree_reduce, +) from fortuna.typing import ( AnyKey, @@ -15,7 +28,12 @@ InputData, Params, ) +from fortuna.utils.freeze import get_paths_with_label from fortuna.utils.grad import value_and_jacobian_squared_row_norm +from fortuna.utils.nested_dicts import ( + nested_get, + nested_set, +) def probit_scaling( @@ -39,3 +57,216 @@ def probit_scaling( if has_aux: return f, aux return f + + +def sequential_probit_scaling( + apply_fn: Callable[[Params, InputData], jnp.ndarray], + params: Params, + x: InputData, + log_var: Union[float, Array], + has_aux: bool = False, + freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None, + top_k: Optional[int] = None, + memory: Optional[int] = None, + n_final_tokens: Optional[int] = None, + stop_gradient: bool = False, +) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]: + params = params.unfreeze() + + params_paths = None + sub_params = None + if freeze_fun is not None: + params_paths = tuple( + get_paths_with_label( + params, freeze_fun, label=True, allowed_labels=[True, False] + ) + ) + sub_params = tuple([nested_get(d=params, keys=path) for path in params_paths]) + + def set_params(_p): + if params_paths is None: + return _p + return FrozenDict(nested_set(d=params, key_paths=params_paths, objs=_p)) + + def _apply_fn(_p, _x, tau): + _f = apply_fn(set_params(_p), _x) + if has_aux: + _f, _ = _f + _f = _f[0] + if _f.ndim > 1: + _f = _f[tau] + return _f + + f = apply_fn(params, x) + if has_aux: + f, aux = f + + if f.ndim > 3: + raise ValueError("The model outputs can be at most three dimensional.") + if f.ndim == 2: + f = f[:, None] + + n_outputs = f.shape[-1] + seq_length = f.shape[1] + if n_final_tokens is None: + n_final_tokens = seq_length + if n_final_tokens <= 0 or n_final_tokens > seq_length: + raise ValueError( + f"`n_final_tokens` must be greater than 0 and cannot be greater than {seq_length}." + ) + if memory is None: + memory = n_final_tokens + if memory <= 0 or memory > n_final_tokens: + raise ValueError( + f"`memory` must be greater than 0 and cannot be greater than {n_final_tokens}." 
+ ) + + block_size = top_k if top_k is not None else n_outputs + tot_size = memory * block_size + batch_size = f.shape[0] + + indices = None + if top_k is not None: + indices = vmap( + lambda _fx: vmap(lambda _fxtau: jnp.argsort(_fxtau)[-top_k:])(_fx) + )(f) + + x = x[:, None] if not isinstance(x, dict) else tree_map(lambda v: v[:, None], x) + + def compute_cov(new_tau, prev_tau): + new_tau -= 1 + prev_tau -= 1 + + @vmap + def _compute_cov(_x, idx): + new_idx = idx[new_tau] if idx is not None else None + prev_idx = idx[prev_tau] if idx is not None else None + size = n_outputs if idx is None else len(prev_idx) + + new_fun = ( + lambda p: _apply_fn(p, _x, new_tau)[new_idx] + if idx is not None + else _apply_fn(p, _x, new_tau) + ) + prev_fun = ( + lambda p: _apply_fn(p, _x, prev_tau)[prev_idx] + if idx is not None + else _apply_fn(p, _x, prev_tau) + ) + + J1J2T_op = lambda v: jvp( + new_fun, + (sub_params if params_paths is not None else params,), + vjp(prev_fun, sub_params if params_paths is not None else params)[1](v), + )[1] + + return vmap(J1J2T_op)(jnp.eye(size)).T + + return jnp.where( + prev_tau != -1, _compute_cov(x, indices), jnp.empty(block_size) + ) + + init_tau = seq_length - n_final_tokens + 1 + + @jit + def compute_P(new_tau, old_taus): + P = vmap(lambda tau: compute_cov(new_tau, tau), out_axes=2)(old_taus) + return P.reshape(P.shape[0], P.shape[1], P.shape[2] * P.shape[3]) + + @vmap + def get_diag(mat): + return jnp.diag(mat) + + def fun(carry, tau): + Jinv, old_taus = carry + S = compute_cov(tau, tau) + + P = compute_P(tau, old_taus) + M = jnp.matmul(P, Jinv) + C = S - jnp.matmul(M, P.swapaxes(1, 2)) + + M = M[:, :, block_size:] + Jinv = Jinv[:, block_size:, block_size:] + Cinv = jnp.linalg.inv(C) + MtCinv = jnp.matmul(M.swapaxes(1, 2), Cinv) + + Jinv = jnp.concatenate( + ( + jnp.concatenate((Jinv + jnp.matmul(MtCinv, M), -MtCinv), axis=2), + jnp.concatenate((-MtCinv.swapaxes(1, 2), Cinv), axis=2), + ), + axis=1, + ) + + old_taus = jnp.concatenate((old_taus[1:], jnp.array([tau]))) + return (Jinv, old_taus), get_diag(C) + + def get_diagCs(_params): + old_taus = jnp.concatenate( + ( + jnp.zeros(memory - 1, dtype="int32") - 1, + jnp.array([init_tau], dtype="int32"), + ) + ) + C = compute_cov(old_taus[-1], old_taus[-1]) + + if n_final_tokens > 1: + Jinv = jnp.linalg.inv(C) + Jinv = jnp.concatenate( + ( + jnp.zeros((batch_size, (memory - 1) * block_size, tot_size)), + jnp.concatenate( + ( + jnp.zeros( + (batch_size, block_size, (memory - 1) * block_size) + ), + Jinv, + ), + axis=2, + ), + ), + axis=1, + ) + + _, diagCs = lax.scan( + fun, (Jinv, old_taus), jnp.arange(init_tau + 1, seq_length + 1) + ) + diagCs = jnp.concatenate( + (get_diag(C)[:, None], diagCs.swapaxes(0, 1)), axis=1 + ) + else: + diagCs = get_diag(C)[:, None] + + return diagCs + + diagCs = get_diagCs(params if sub_params is None else sub_params) + if stop_gradient: + diagCs = lax.stop_gradient(diagCs) + + if top_k is not None: + scales = jnp.max(diagCs, axis=2, keepdims=True).repeat(n_outputs, axis=2) + scales = vmap( + lambda i: vmap( + lambda j: scales[i, j] + .at[indices[i, j]] + .set(diagCs[i, j, indices[i, j]]) + )(jnp.arange(seq_length - n_final_tokens, seq_length)) + )(jnp.arange(batch_size)) + else: + scales = diagCs + + f = jnp.concatenate( + ( + f[:, : seq_length - n_final_tokens], + f[:, seq_length - n_final_tokens :] + / (1 + jnp.pi / 8 * jnp.exp(log_var) * scales), + ), + axis=1, + ) + + if seq_length == 1: + f = f[:, 0] + + if has_aux: + return f, aux + return f diff --git a/poetry.lock b/poetry.lock 
index 4f6aa650..cc9ffade 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1298,14 +1298,14 @@ files = [ [[package]] name = "flax" -version = "0.6.10" +version = "0.6.11" description = "Flax: A neural network library for JAX designed for flexibility" category = "main" optional = false python-versions = "*" files = [ - {file = "flax-0.6.10-py3-none-any.whl", hash = "sha256:8dccc7b84b00ff6f59a36dc0e79f5919498cfeb009a41f8c07f68bf2513198db"}, - {file = "flax-0.6.10.tar.gz", hash = "sha256:e2174a0df7bb4921f29b2cbd33f55ddf6eed161d6df61809fe374a25e473fb2f"}, + {file = "flax-0.6.11-py3-none-any.whl", hash = "sha256:3ce6843ed47a35abfd86a7eb47db3934a156d08d6513dc8dcb58d461b0dd6f39"}, + {file = "flax-0.6.11.tar.gz", hash = "sha256:ecedf179ceb16c0b511982a293834bb13086168dce1dff697ac083efa818fc72"}, ] [package.dependencies] @@ -2080,16 +2080,17 @@ arrow = ">=0.15.0" [[package]] name = "jax" -version = "0.4.10" +version = "0.4.13" description = "Differentiate, compile, and transform Numpy code." category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jax-0.4.10.tar.gz", hash = "sha256:1bf0f2720f778f2937301a16a4d5cd3497f13a4d6c970c24a88918a81816a888"}, + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, ] [package.dependencies] +importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} ml_dtypes = ">=0.1.0" numpy = ">=1.21" opt_einsum = "*" @@ -2097,38 +2098,40 @@ scipy = ">=1.7" [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.9)"] -cpu = ["jaxlib (==0.4.10)"] -cuda = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-cudnn82 = ["jaxlib (==0.4.10+cuda11.cudnn82)"] -cuda11-cudnn86 = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-local = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-pip = ["jaxlib (==0.4.10+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.6)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] -cuda12-local = ["jaxlib (==0.4.10+cuda12.cudnn88)"] -cuda12-pip = ["jaxlib (==0.4.10+cuda12.cudnn88)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] -minimum-jaxlib = ["jaxlib (==0.4.7)"] -tpu = ["jaxlib (==0.4.10)", "libtpu-nightly (==0.1.dev20230511)", "requests"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] [[package]] name = "jaxlib" -version = "0.4.10" +version = "0.4.13" description 
= "XLA library for JAX" category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jaxlib-0.4.10-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:0814c478382e82b0f90aacd820fb898c4a9caa705a1f515c5fd0928198c814f3"}, - {file = "jaxlib-0.4.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:97e7b2b0f32debbb011556cb2dc82cdfb0087b618e302f92319475727408a64e"}, - {file = "jaxlib-0.4.10-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:f46edb93332285ab6f57b2843869183cbd495b4f35bea0fba25a3766a7429306"}, - {file = "jaxlib-0.4.10-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1b705e495149945defe478781865d403bd3994c11e326829aea7aafda0dfa639"}, - {file = "jaxlib-0.4.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:377757745d5e2097fccce71c31292973d544a36329b7ed85bf9c41837e107f74"}, - {file = "jaxlib-0.4.10-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a6349c98c3ffd879b390a3532390e8e49f084aa523c1553aa5c21374ca8b4ea9"}, - {file = "jaxlib-0.4.10-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:7dc9c89b2b07cf8c576d5fca433181f324fed52e51db60873d2b6d3e496588e2"}, - {file = "jaxlib-0.4.10-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:62f3d2bad0476bb6728d1be813894cf3421a3d31706a0208b1f57eec86d310d5"}, - {file = "jaxlib-0.4.10-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:6ea7ad6b520732994e25429768d6bd731d55c59c75ef6f9faa2f59e419fb0ada"}, - {file = "jaxlib-0.4.10-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:749a1135a452db1afb4e5de7770fc5dafebb310c35d9db077ed925fcab028471"}, - {file = "jaxlib-0.4.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c966b13467c41ff44ba1b3b7cdceb37a76a75f0420f454a8a51543f8bbaabe4a"}, - {file = "jaxlib-0.4.10-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:07557fcf1e4c7c60bbb48c4f4f426909fcf610a7bfa56cbb139719ba3900722d"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:532ebc4fb11386282ad63b83941d4557f4038c1144acf026f1f8565f64c7e9c0"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a259bb35429bfbd3b76e43019dfc8f7d6ea94bb217400b78f7d0824ce07a58ac"}, + {file = "jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:ea1bc9811ef7d73a15e3213115e88fe7f5d14b59d95027bea9fccc98e5a14af8"}, + {file = "jaxlib-0.4.13-cp310-cp310-win_amd64.whl", hash = "sha256:fde66a93e9be89d99e5792f677ed8e319667d6b2396865b1c52c1312844c47f9"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:49690fcdd26560515fd15399fc3a44777e0bfc5db5c48fe76ff7bc7228e8b2fb"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4e9e34e5d8a6556f62fead14aee0b1614c2c6296f0078d8e6139d6aff109649"}, + {file = "jaxlib-0.4.13-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:8000c0d15c107328e8f7b7b3ac91dd822f5c287a80231882b620503ed141fa89"}, + {file = "jaxlib-0.4.13-cp311-cp311-win_amd64.whl", hash = "sha256:19ae4c316b17a49342432c69f7f89f190b975333f3f9e9e175f686a651bc7347"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:522635d5e159401a386c79f1236c218c1f68fbb4ca6648115c3ad3c2c3f518ab"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:411334d903df07dc1ace8d52fc53c17f6bc1d55aff7f6e0e5cf61ec149f758a0"}, + {file = "jaxlib-0.4.13-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:839173b2e9593f5e9a6d3c42852cd15070fe80a939246efbb5cf40eec815de89"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:c230ef85712e608d0f048869766a5a63afeb2e72309943db0df9f959ab17307f"}, + 
{file = "jaxlib-0.4.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d19c05c15f962e098d49b45e2758aacf19330d192ec5395f9ef136f62db90edc"}, + {file = "jaxlib-0.4.13-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:b5c0a9737efd95fe18fd7715ce30dfce476546705ea8934aad6731777a9631a5"}, + {file = "jaxlib-0.4.13-cp39-cp39-win_amd64.whl", hash = "sha256:bebb4cf001f180dc431f9604daf930c2d9cc778e4dda26f401ac939b7bac912e"}, ] [package.dependencies] @@ -2136,6 +2139,10 @@ ml-dtypes = ">=0.1.0" numpy = ">=1.21" scipy = ">=1.7" +[package.extras] +cuda11-pip = ["nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-pip = ["nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] + [[package]] name = "jedi" version = "0.18.2" diff --git a/tests/fortuna/calib_model/test_calib.py b/tests/fortuna/calib_model/test_calib.py new file mode 100755 index 00000000..5294d0d7 --- /dev/null +++ b/tests/fortuna/calib_model/test_calib.py @@ -0,0 +1,238 @@ +import tempfile + +from flax import linen as nn +import numpy as np + +from fortuna.data.loader import DataLoader +from fortuna.metric.classification import accuracy +from fortuna.metric.regression import rmse +from fortuna.partitioner.base import Partitioner +from fortuna.calib_model import ( + Config, + Monitor, + Checkpointer, + Optimizer +) +from fortuna.calib_model.classification import CalibClassifier +from fortuna.calib_model.regression import CalibRegressor +from tests.make_data import make_array_random_data +from tests.make_model import MyModel + +OUTPUT_DIM = 2 +BATCH_SIZE = 8 +INPUT_SHAPE = (3,) +N_DATA = 16 + + +def accuracy2(preds, probs, targets): + return accuracy(preds, targets) + + +def rmse2(preds, probs, targets): + return rmse(preds, targets) + + +def make_data_loader( + task, + n_data=N_DATA, + input_shape=INPUT_SHAPE, + output_dim=OUTPUT_DIM, + batch_size=BATCH_SIZE, +): + x_train, y_train = make_array_random_data( + n_data=n_data, + shape_inputs=input_shape, + output_dim=output_dim, + output_type="continuous" if task == "regression" else "discrete", + ) + x_train /= np.max(x_train) + if task == "regression": + y_train /= np.max(y_train) + return DataLoader.from_array_data((x_train, y_train), batch_size=batch_size) + + +def config( + task, restore_dir, start_current, save_dir, dump_state, save_n_steps, freeze +): + return Config( + optimizer=Optimizer(n_epochs=3, freeze_fun=freeze), + monitor=Monitor(metrics=(accuracy2 if task == "classification" else rmse2,)), + checkpointer=Checkpointer( + start_from_current_state=start_current, + restore_checkpoint_dir=restore_dir, + save_checkpoint_dir=save_dir, + dump_state=dump_state, + save_every_n_steps=save_n_steps, + ), + ) + + +def calibrate( + task, + model, + calib_data_loader, + val_data_loader, + restore_dir=None, + start_current=False, + save_dir=None, + dump_state=False, + save_n_steps=None, + freeze=None, +): + model.calibrate( + calib_data_loader=calib_data_loader, + val_data_loader=val_data_loader, + config=config( + task, + restore_dir, + start_current, + save_dir, + dump_state, + save_n_steps, + freeze, + ), + ) + + +def define_calib_model(task, model_editor=None): + partitioner = Partitioner( + axes_dims={"mp": 2, "fsdp": 2, "dp": 2}, + 
rules={"l1/kernel": (None, "mp"), "bn1": ("mp",)}, + ) + + if task == "regression": + return CalibRegressor( + model=MyModel(OUTPUT_DIM), + likelihood_log_variance_model=MyModel(OUTPUT_DIM), + model_editor=model_editor, + partitioner=partitioner, + ) + else: + return CalibClassifier( + model=MyModel(OUTPUT_DIM), + model_editor=model_editor, + partitioner=partitioner, + ) + + +class ModelEditor(nn.Module): + @nn.compact + def __call__(self, apply_fn, model_params, x, has_aux: bool): + log_temp = self.param("log_temp", nn.initializers.zeros, (1,)) + f = apply_fn(model_params, x) + if has_aux: + f, aux = f + f += log_temp + if has_aux: + return f, aux + return f + + +def dryrun_task(task): + freeze_fun = lambda p, v: "trainable" if "l2" in p and "model" in p else "frozen" + + calib_data_loader = make_data_loader(task) + val_data_loader = make_data_loader(task) + + calib_model = define_calib_model(task) + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + ) + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + start_current=True, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + save_dir=tmp_dir, + dump_state=True, + ) + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + restore_dir=tmp_dir, + ) + + calib_model = define_calib_model(task) + calib_model.load_state(tmp_dir + "/last") + calib_model.predictive.log_prob(calib_data_loader) + + calib_model = define_calib_model(task) + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + freeze=freeze_fun, + ) + + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + start_current=True, + freeze=freeze_fun, + ) + + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + save_dir=tmp_dir + "/frozen", + dump_state=True, + freeze=freeze_fun, + ) + + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + restore_dir=tmp_dir + "/frozen", + freeze=freeze_fun, + ) + + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + start_current=True, + save_dir=tmp_dir + "/frozen/tmp", + save_n_steps=1, + freeze=freeze_fun, + ) + calib_model = define_calib_model(task) + calib_model.load_state(tmp_dir + "/frozen/tmp") + calib_model.predictive.log_prob(calib_data_loader) + + calib_model = define_calib_model(task, model_editor=ModelEditor()) + calibrate( + task, + calib_model, + calib_data_loader, + val_data_loader, + ) + + +def test_dryrun_classification(): + dryrun_task(task="classification") + + +def test_dryrun_regression(): + dryrun_task(task="regression") diff --git a/tests/fortuna/calib_model/test_calib_model.py b/tests/fortuna/calib_model/test_calib_model.py index 437e3d6c..5d83edfc 100644 --- a/tests/fortuna/calib_model/test_calib_model.py +++ b/tests/fortuna/calib_model/test_calib_model.py @@ -128,7 +128,7 @@ def __init__(self, *args, **kwargs): ) self.class_config_restore = lambda restore_dir: Config( optimizer=Optimizer(n_epochs=3), - checkpointer=Checkpointer(restore_checkpoint_path=restore_dir), + checkpointer=Checkpointer(restore_checkpoint_dir=restore_dir), ) self.reg_config_nodir_nodump = Config( optimizer=Optimizer(n_epochs=3), monitor=Monitor(metrics=(scaled_mse,)) @@ -150,7 +150,7 @@ def __init__(self, *args, **kwargs): ) self.reg_config_restore = lambda restore_dir: Config( optimizer=Optimizer(n_epochs=3), - checkpointer=Checkpointer(restore_checkpoint_path=restore_dir), + 
checkpointer=Checkpointer(restore_checkpoint_dir=restore_dir), ) def test_dryrun_reg(self): @@ -217,10 +217,10 @@ def test_dryrun_reg(self): ) # load state - calib_reg.load_state(checkpoint_path=tmp_dir) + calib_reg.load_state(checkpoint_dir=tmp_dir) # save state - calib_reg.save_state(checkpoint_path=tmp_dir) + calib_reg.save_state(checkpoint_dir=tmp_dir) # model_editor calib_reg = CalibRegressor( @@ -304,10 +304,10 @@ def test_dryrun_class(self): ) # load state - calib_class.load_state(checkpoint_path=tmp_dir) + calib_class.load_state(checkpoint_dir=tmp_dir) # save state - calib_class.save_state(checkpoint_path=tmp_dir) + calib_class.save_state(checkpoint_dir=tmp_dir) # model_editor calib_class = CalibClassifier(model=model, model_editor=ModelEditor()) diff --git a/tests/fortuna/calib_model/test_output_calib_model.py b/tests/fortuna/calib_model/test_output_calib_model.py index 96b7a3a8..4f180448 100644 --- a/tests/fortuna/calib_model/test_output_calib_model.py +++ b/tests/fortuna/calib_model/test_output_calib_model.py @@ -123,7 +123,7 @@ def __init__(self, *args, **kwargs): self.calib_config_restore = lambda directory, metric: Config( optimizer=Optimizer(n_epochs=3), monitor=Monitor(metrics=(metric,)), - checkpointer=Checkpointer(restore_checkpoint_path=directory), + checkpointer=Checkpointer(restore_checkpoint_dir=directory), ) def test_dryrun_reg_map(self): @@ -178,10 +178,10 @@ def test_dryrun_reg_map(self): ) # load state - calib_model.load_state(checkpoint_path=tmp_dir) + calib_model.load_state(checkpoint_dir=tmp_dir) # save state - calib_model.save_state(checkpoint_path=tmp_dir) + calib_model.save_state(checkpoint_dir=tmp_dir) def test_dryrun_class_map(self): with tempfile.TemporaryDirectory() as tmp_dir: @@ -234,7 +234,7 @@ def test_dryrun_class_map(self): ) # load state - calib_model.load_state(checkpoint_path=tmp_dir) + calib_model.load_state(checkpoint_dir=tmp_dir) # save state - calib_model.save_state(checkpoint_path=tmp_dir) + calib_model.save_state(checkpoint_dir=tmp_dir) diff --git a/tests/fortuna/prob_model/test_train.py b/tests/fortuna/prob_model/test_train.py index 05fda2f0..bbff0eef 100755 --- a/tests/fortuna/prob_model/test_train.py +++ b/tests/fortuna/prob_model/test_train.py @@ -1,6 +1,6 @@ import tempfile -import flax.linen as nn +from flax import linen as nn import jax.numpy as jnp import numpy as np import pytest @@ -8,9 +8,11 @@ from fortuna.data.loader import DataLoader from fortuna.metric.classification import accuracy from fortuna.metric.regression import rmse +from fortuna.partitioner.base import Partitioner from fortuna.prob_model import ( CalibConfig, CalibOptimizer, + CalibProcessor, FitConfig, FitMonitor, SNGPPosteriorApproximator, @@ -45,7 +47,7 @@ OUTPUT_DIM = 2 BATCH_SIZE = 8 INPUT_SHAPE = (3,) -N_DATA = 10 +N_DATA = 16 METHODS = { "map": MAPPosteriorApproximator(), @@ -81,14 +83,14 @@ def make_data_loader( def fit_config( - task, restore_path, start_current, save_dir, dump_state, save_n_steps, freeze + task, restore_dir, start_current, save_dir, dump_state, save_n_steps, freeze ): return FitConfig( optimizer=FitOptimizer(n_epochs=3, freeze_fun=freeze), monitor=FitMonitor(metrics=(accuracy if task == "classification" else rmse,)), checkpointer=FitCheckpointer( start_from_current_state=start_current, - restore_checkpoint_path=restore_path, + restore_checkpoint_dir=restore_dir, save_checkpoint_dir=save_dir, dump_state=dump_state, save_every_n_steps=save_n_steps, @@ -96,7 +98,10 @@ def fit_config( ) -calib_config = 
CalibConfig(optimizer=CalibOptimizer(n_epochs=3)) +calib_config = CalibConfig( + optimizer=CalibOptimizer(n_epochs=3), + processor=CalibProcessor(n_posterior_samples=2), +) def train( @@ -105,7 +110,7 @@ def train( train_data_loader, val_data_loader, calib_data_loader, - restore_path=None, + restore_dir=None, start_current=False, save_dir=None, dump_state=False, @@ -119,7 +124,7 @@ def train( calib_data_loader=calib_data_loader, fit_config=fit_config( task, - restore_path, + restore_dir, start_current, save_dir, dump_state, @@ -147,7 +152,7 @@ def train_and_sample( train_data_loader, val_data_loader, calib_data_loader, - restore_path=None, + restore_dir=None, start_current=False, save_dir=None, dump_state=False, @@ -161,7 +166,7 @@ def train_and_sample( train_data_loader, val_data_loader, calib_data_loader, - restore_path, + restore_dir, start_current, save_dir, dump_state, @@ -173,12 +178,18 @@ def train_and_sample( def define_prob_model(task, method, model_editor=None): + partitioner = Partitioner( + axes_dims={"mp": 1, "fsdp": 1, "dp": 1}, + rules={"l1/kernel": (None, "mp"), "bn1": ("mp",)}, + ) + if task == "regression": return ProbRegressor( model=MyModel(OUTPUT_DIM), likelihood_log_variance_model=MyModel(OUTPUT_DIM), posterior_approximator=METHODS[method], model_editor=model_editor, + partitioner=partitioner, ) else: return ProbClassifier( @@ -187,6 +198,7 @@ def define_prob_model(task, method, model_editor=None): else MyModelWithSpectralNorm(OUTPUT_DIM), posterior_approximator=METHODS[method], model_editor=model_editor, + partitioner=partitioner, ) @@ -213,13 +225,14 @@ def dryrun_task(task, method): prob_model = define_prob_model(task, method) map_fit_config = fit_config( task, - restore_path=None, + restore_dir=None, start_current=None, save_dir=None, dump_state=False, save_n_steps=None, freeze=None, ) + train_and_sample( task, method, @@ -229,6 +242,7 @@ def dryrun_task(task, method): calib_data_loader, map_fit_config=map_fit_config, ) + train_and_sample( task, method, @@ -252,13 +266,14 @@ def dryrun_task(task, method): with tempfile.TemporaryDirectory() as tmp_dir: map_fit_config = fit_config( task, - restore_path=None, + restore_dir=None, start_current=None, save_dir=None, dump_state=False, save_n_steps=None, freeze=None, ) + train_and_sample( task, method, @@ -270,6 +285,7 @@ def dryrun_task(task, method): save_dir=tmp_dir, dump_state=True, ) + train_and_sample( task, method, @@ -277,7 +293,7 @@ def dryrun_task(task, method): train_data_loader, val_data_loader, calib_data_loader, - restore_path=tmp_dir, + restore_dir=tmp_dir, ) prob_model = define_prob_model(task, method) @@ -285,6 +301,7 @@ def dryrun_task(task, method): sample(method, prob_model, train_data_loader) prob_model.predictive.log_prob(train_data_loader) + prob_model = define_prob_model(task, method) if method not in ["laplace", "swag"]: train_and_sample( task, @@ -296,6 +313,17 @@ def dryrun_task(task, method): freeze=freeze_fun, ) + train_and_sample( + task, + method, + prob_model, + train_data_loader, + val_data_loader, + calib_data_loader, + start_current=True, + freeze=freeze_fun, + ) + train_and_sample( task, method, @@ -303,21 +331,12 @@ def dryrun_task(task, method): train_data_loader, val_data_loader, calib_data_loader, - start_current=True, - freeze=freeze_fun, - ) - train_and_sample( - task, - method, - prob_model, - train_data_loader, - val_data_loader, - calib_data_loader, - save_dir=tmp_dir, + save_dir=tmp_dir + "2", dump_state=True, - restore_path=tmp_dir, + restore_dir=tmp_dir, freeze=freeze_fun, ) 
+ train_and_sample( task, method, @@ -325,11 +344,12 @@ def dryrun_task(task, method): train_data_loader, val_data_loader, calib_data_loader, - save_dir=tmp_dir, + save_dir=tmp_dir + "3", dump_state=True, - restore_path=tmp_dir, + restore_dir=tmp_dir + "2", freeze=freeze_fun, ) + train_and_sample( task, method, @@ -339,17 +359,21 @@ def dryrun_task(task, method): calib_data_loader, map_fit_config=fit_config( task, - restore_path=None, + restore_dir=None, start_current=None, save_dir=None, dump_state=False, save_n_steps=None, freeze=None, ), - save_dir=tmp_dir, + save_dir=tmp_dir + "4", dump_state=True, freeze=freeze_fun, ) + prob_model = define_prob_model(task, method) + prob_model.load_state(tmp_dir + "4") + sample(method, prob_model, train_data_loader) + prob_model.predictive.log_prob(train_data_loader) train_and_sample( task, @@ -359,14 +383,10 @@ def dryrun_task(task, method): val_data_loader, calib_data_loader, start_current=True, - save_dir=tmp_dir + "/tmp", + save_dir=tmp_dir + "5", save_n_steps=1, freeze=freeze_fun, ) - prob_model = define_prob_model(task, method) - prob_model.load_state(tmp_dir + "/tmp") - sample(method, prob_model, train_data_loader) - prob_model.predictive.log_prob(train_data_loader) prob_model = define_prob_model(task, method, model_editor=ModelEditor()) train_and_sample( diff --git a/tests/fortuna/test_mixin.py b/tests/fortuna/test_mixin.py index 642ab8c7..e834ec5d 100755 --- a/tests/fortuna/test_mixin.py +++ b/tests/fortuna/test_mixin.py @@ -7,10 +7,9 @@ from fortuna.prob_model.posterior.posterior_mixin import WithPosteriorCheckpointingMixin from fortuna.prob_model.posterior.state import PosteriorState -from fortuna.training.mixin import ( - InputValidatorMixin, - WithEarlyStoppingMixin, -) +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin +from fortuna.training.mixins.input_validator import InputValidatorMixin class FakeTrainerWithCheckpointing( diff --git a/tests/fortuna/test_predictive.py b/tests/fortuna/test_predictive.py index 9a6bd5ca..fab79cb1 100755 --- a/tests/fortuna/test_predictive.py +++ b/tests/fortuna/test_predictive.py @@ -97,8 +97,7 @@ def test_pred_stats(self): assert ensemble_log_probs.shape == (self.n_post_samples, self.n_inputs) sample = self.prob_class.predictive.sample( - self.class_inputs_loader, - n_target_samples=self.n_post_samples, + self.class_inputs_loader, n_target_samples=self.n_post_samples ) assert sample.shape == ( self.n_post_samples, @@ -106,8 +105,7 @@ def test_pred_stats(self): ) sample = self.prob_reg.predictive.sample( - self.reg_inputs_loader, - n_target_samples=self.n_post_samples, + self.reg_inputs_loader, n_target_samples=self.n_post_samples ) assert sample.shape == (self.n_post_samples, self.n_inputs, self.output_dim) diff --git a/tests/fortuna/test_trainer.py b/tests/fortuna/test_trainer.py index fd243915..a6b81fa1 100755 --- a/tests/fortuna/test_trainer.py +++ b/tests/fortuna/test_trainer.py @@ -568,7 +568,7 @@ def test_should_perform_validation(self): self.assertTrue(trainer.should_perform_validation({}, 10)) def test__validation_loop(self): - validation_dataloader = [ + validation_data_loader = [ [jnp.array([[0, 0.0, 0.0], [0, 0.0, 0]]), jnp.array([0.0, 0.0])], [jnp.array([[0.1, 0.0, 10], [0, 0.0, 0]]), jnp.array([1.0, 0.0])], ] @@ -580,7 +580,7 @@ def test__validation_loop(self): observed_validation_epoch_metrics_str, ) = trainer._validation_loop( state=None, - validation_dataloader=validation_dataloader, + 
validation_data_loader=validation_data_loader, validation_dataset_size=2, loss_fun=lambda x: x, rng=jax.random.PRNGKey(0), @@ -600,7 +600,7 @@ def test__validation_loop(self): observed_validation_epoch_metrics_str, ) = trainer._validation_loop( state=None, - validation_dataloader=validation_dataloader, + validation_data_loader=validation_data_loader, validation_dataset_size=2, loss_fun=lambda x: x, rng=jax.random.PRNGKey(0), @@ -618,7 +618,7 @@ def test__validation_loop(self): ) def test__training_loop(self): - training_dataloader = [ + training_data_loader = [ [jnp.array([[0, 0.0, 0.0], [0, 0.0, 0]]), jnp.array([0.0, 0.0])], [jnp.array([[0.1, 0.0, 10], [0, 0.0, 0]]), jnp.array([1.0, 0.0])], ] @@ -635,7 +635,7 @@ def test__training_loop(self): metrics=(accuracy,), rng=jax.random.PRNGKey(0), state=FakeTrainState(), - training_dataloader=training_dataloader, + training_data_loader=training_data_loader, training_dataset_size=2, training_kwargs=FrozenDict({}), unravel=None, diff --git a/tests/make_model.py b/tests/make_model.py index 1469ea63..dec82faf 100644 --- a/tests/make_model.py +++ b/tests/make_model.py @@ -1,3 +1,5 @@ +from functools import partial + import flax.linen as nn import jax.numpy as jnp @@ -7,6 +9,7 @@ class MyModel(nn.Module): output_dim: int dense: nn.Module = nn.Dense + dtype: str = "float32" @nn.compact def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray: @@ -14,8 +17,16 @@ def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray: dense = self.spectral_norm(self.dense, train=train) else: dense = self.dense + norm = partial( + nn.BatchNorm, + use_running_average=not train, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype, + ) x = x.reshape(x.shape[0], -1) - x = dense(2, name="l1")(x) + x = dense(4, name="l1")(x) + x = norm(name="bn1")(x) x = nn.Dropout(rate=0.9)(x, deterministic=not train) x = dense(self.output_dim, name="l2")(x) return x
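
The new `fortuna/utils/checkpoint.py` helper above wraps Orbax: given a checkpoint directory, `get_checkpoint_manager` returns a `CheckpointManager` backed by a `PyTreeCheckpointHandler` (creating the directory and keeping at most `keep_top_n_checkpoints` checkpoints), and it returns `None` when no directory is passed. Below is a minimal sketch of how it can be exercised on a plain PyTree; the toy state and step number are illustrative assumptions, not part of this patch.

import tempfile

import jax.numpy as jnp

from fortuna.utils.checkpoint import get_checkpoint_manager

# Toy state used only for illustration; real callers pass trainer/posterior states.
state = {"params": {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}}

with tempfile.TemporaryDirectory() as tmp_dir:
    manager = get_checkpoint_manager(tmp_dir, keep_top_n_checkpoints=2)
    manager.save(0, state)  # Orbax CheckpointManager.save(step, items)
    manager.wait_until_finished()
    restored = manager.restore(manager.latest_step())
    assert jnp.allclose(restored["params"]["w"], state["params"]["w"])

# Without a directory the helper returns None, so callers can skip checkpointing.
assert get_checkpoint_manager(None) is None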
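`get_mesh` in `fortuna/utils/mesh.py` builds a `Mesh` over the `dp`, `fsdp` and `mp` axes (a single `-1` entry absorbs whatever devices remain), while `match_partition_specs` in `fortuna/utils/partition/base.py` matches regular-expression rules against `/`-joined parameter paths and falls back to a replicated `PartitionSpec()` for scalar or unmatched leaves. The sketch below combines the two on a toy Flax module; the module, rules and axis sizes are assumptions chosen for illustration, mirroring the rule format used by `get_default_partitioner`.

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.sharding import PartitionSpec

from fortuna.utils.mesh import get_mesh
from fortuna.utils.partition.base import match_partition_specs


class TinyMLP(nn.Module):  # hypothetical toy model
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(8, name="l1")(x)
        return nn.Dense(2, name="l2")(x)


# Shapes only: eval_shape returns a tree of ShapeDtypeStruct leaves, which is
# exactly what `match_partition_specs` inspects via their `shape` attribute.
shapes = jax.eval_shape(TinyMLP().init, jax.random.PRNGKey(0), jnp.ones((1, 4)))

# Regex rules over "/"-joined paths such as "params/l1/kernel".
rules = {
    "l1/kernel": PartitionSpec("fsdp", "mp"),
    "l2/kernel": PartitionSpec("mp", None),
}
specs = match_partition_specs(rules, shapes)

# Keys must be exactly dp/fsdp/mp; `-1` lets one axis take the remaining devices.
mesh = get_mesh({"dp": -1, "fsdp": 1, "mp": 1})
print(mesh.shape, specs)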
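The `lax.scan` body (`fun`) inside the new `sequential_probit_scaling` in `fortuna/utils/probit.py` maintains the inverse of the running block matrix assembled from the Jacobian products returned by `compute_cov`, covering the most recent `memory` token blocks. Reading the code, each step appears to apply the standard block-matrix inversion identity below, where `S` is the block for the new token, `P` its cross-block with the retained ones, and `M`, `C` correspond to the variables of the same name; this is an interpretation of the code rather than a statement taken from the patch.

M = P J^{-1}, \qquad C = S - M P^{\top}, \qquad
\begin{pmatrix} J & P^{\top} \\ P & S \end{pmatrix}^{-1}
=
\begin{pmatrix}
J^{-1} + M^{\top} C^{-1} M & -\, M^{\top} C^{-1} \\
-\, C^{-1} M & C^{-1}
\end{pmatrix}.

Only `diag(C)` is emitted at each step, and the oldest `block_size` rows and columns of the stored inverse are dropped first, so its size never exceeds `memory * block_size`; the collected diagonals then rescale the final-token outputs as `f / (1 + pi/8 * exp(log_var) * diag(C))`.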