From 44449075ec5f82b58019893f2076795b2f6f9143 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Sun, 25 Jun 2023 18:25:53 +0200 Subject: [PATCH] enable model and data sharding - create partition manager object - make MAP compatible - migrate to Orbax checkpointing - refactor predictive --- .../transformers/masked_language_modeling.py | 10 +- .../prob_model_text_classification.py | 26 +- .../output_calib_model/output_calib_model.rst | 6 +- examples/scaling_up_bayesian_inference.pct.py | 2 +- fortuna/calib_model/base.py | 26 +- fortuna/calib_model/calib_mixin.py | 60 +- fortuna/calib_model/calib_model_calibrator.py | 8 +- fortuna/calib_model/config/checkpointer.py | 8 +- fortuna/data/dataset/huggingface_datasets.py | 4 +- fortuna/data/loader/base.py | 33 +- fortuna/data/loader/huggingface_loaders.py | 2 +- fortuna/data/loader/utils.py | 42 +- fortuna/likelihood/base.py | 10 +- fortuna/model/llama.py | 1268 +++++++++++++++++ fortuna/model/model_manager/base.py | 7 +- fortuna/model/model_manager/classification.py | 2 +- fortuna/model/model_manager/regression.py | 7 +- .../transformers/classification.py | 2 +- fortuna/output_calib_model/base.py | 34 +- .../output_calib_model/config/checkpointer.py | 6 +- .../output_calib_model/output_calib_mixin.py | 40 - .../output_calibrator/__init__.py | 0 .../base.py} | 12 +- .../output_calib_manager/base.py | 6 +- fortuna/partitioner/__init__.py | 0 fortuna/partitioner/base.py | 39 + .../partitioner/partition_manager/__init__.py | 0 fortuna/partitioner/partition_manager/base.py | 78 + fortuna/prob_model/base.py | 73 +- .../prob_model/calib_config/checkpointer.py | 6 +- fortuna/prob_model/classification.py | 13 +- fortuna/prob_model/fit_config/checkpointer.py | 25 +- fortuna/prob_model/joint/base.py | 17 +- fortuna/prob_model/posterior/base.py | 66 +- .../deep_ensemble/deep_ensemble_posterior.py | 14 +- .../posterior/laplace/laplace_posterior.py | 2 +- .../prob_model/posterior/map/map_posterior.py | 91 +- .../prob_model/posterior/map/map_trainer.py | 11 +- .../normalizing_flow/advi/advi_posterior.py | 4 +- .../normalizing_flow/advi/advi_trainer.py | 12 +- .../prob_model/posterior/posterior_mixin.py | 27 +- .../posterior_multi_state_repository.py | 36 +- .../posterior/posterior_state_repository.py | 8 +- .../cyclical_sgld/cyclical_sgld_posterior.py | 4 +- .../posterior/sgmcmc/sghmc/sghmc_posterior.py | 4 +- .../posterior/sgmcmc/sgmcmc_posterior.py | 10 +- .../sgmcmc_posterior_state_repository.py | 24 +- .../posterior/swag/swag_posterior.py | 6 +- .../prob_model/posterior/swag/swag_trainer.py | 8 +- fortuna/prob_model/predictive/base.py | 472 +++--- fortuna/prob_model/predictive/regression.py | 5 +- fortuna/prob_model/prob_model_calibrator.py | 59 +- fortuna/prob_model/regression.py | 14 +- fortuna/training/mixin.py | 163 --- fortuna/training/mixins/checkpointing.py | 160 +++ fortuna/training/mixins/early_stopping.py | 71 + fortuna/training/mixins/input_validator.py | 9 + fortuna/training/mixins/jitted.py | 54 + fortuna/training/mixins/multi_device.py | 158 ++ fortuna/training/mixins/sharding.py | 101 ++ .../training/output_calibrator/__init__.py | 0 .../base.py} | 333 +---- .../output_calibrator/mixins/sharding.py | 104 ++ fortuna/training/train_state_repository.py | 90 +- fortuna/training/trainer.py | 238 +--- fortuna/utils/checkpoint.py | 21 + fortuna/utils/mesh.py | 389 +---- fortuna/utils/nested_dicts.py | 43 + fortuna/utils/partition.py | 73 + fortuna/utils/port.py | 6 + poetry.lock | 339 ++++- tests/fortuna/calib_model/test_calib_model.py | 12 +- 
.../calib_model/test_output_calib_model.py | 10 +- tests/fortuna/prob_model/test_train.py | 45 +- tests/fortuna/test_mixin.py | 7 +- tests/fortuna/test_predictive.py | 6 +- tests/fortuna/test_trainer.py | 10 +- tests/make_model.py | 13 +- 78 files changed, 3479 insertions(+), 1725 deletions(-) create mode 100644 fortuna/model/llama.py delete mode 100644 fortuna/output_calib_model/output_calib_mixin.py create mode 100644 fortuna/output_calib_model/output_calibrator/__init__.py rename fortuna/output_calib_model/{output_calib_model_calibrator.py => output_calibrator/base.py} (98%) create mode 100644 fortuna/partitioner/__init__.py create mode 100644 fortuna/partitioner/base.py create mode 100644 fortuna/partitioner/partition_manager/__init__.py create mode 100644 fortuna/partitioner/partition_manager/base.py delete mode 100755 fortuna/training/mixin.py create mode 100644 fortuna/training/mixins/checkpointing.py create mode 100644 fortuna/training/mixins/early_stopping.py create mode 100644 fortuna/training/mixins/input_validator.py create mode 100644 fortuna/training/mixins/jitted.py create mode 100644 fortuna/training/mixins/multi_device.py create mode 100644 fortuna/training/mixins/sharding.py create mode 100644 fortuna/training/output_calibrator/__init__.py rename fortuna/training/{output_calibrator.py => output_calibrator/base.py} (55%) create mode 100644 fortuna/training/output_calibrator/mixins/sharding.py create mode 100644 fortuna/utils/checkpoint.py create mode 100644 fortuna/utils/partition.py create mode 100644 fortuna/utils/port.py diff --git a/benchmarks/transformers/masked_language_modeling.py b/benchmarks/transformers/masked_language_modeling.py index 02f3ec39..539b0156 100644 --- a/benchmarks/transformers/masked_language_modeling.py +++ b/benchmarks/transformers/masked_language_modeling.py @@ -193,13 +193,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: try: logger.info(list(pathlib.Path(args.restore_checkpoint_dir).rglob("*"))) - restore_checkpoint_path = unpack_model_tar( + restore_checkpoint_dir = unpack_model_tar( list(pathlib.Path(args.restore_checkpoint_dir).rglob("*"))[0] ) - logger.info(list(pathlib.Path(restore_checkpoint_path).rglob("*"))) + logger.info(list(pathlib.Path(restore_checkpoint_dir).rglob("*"))) except: logger.info("No checkpoint to restore") - restore_checkpoint_path = None + restore_checkpoint_dir = None tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) @@ -341,7 +341,7 @@ def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray: save_checkpoint_dir=args.save_checkpoint_dir, save_every_n_steps=args.save_every_n_steps, keep_top_n_checkpoints=args.keep_top_n_checkpoints, - restore_checkpoint_path=restore_checkpoint_path, + restore_checkpoint_dir=restore_checkpoint_dir, ), ) if args.last_layer_only and ( @@ -357,7 +357,7 @@ def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray: and args.last_layer_only else None, ) - if restore_checkpoint_path is not None: + if restore_checkpoint_dir is not None: fit_config.optimizer = last_layer_optimizer train_kwargs = {"fit_config": fit_config} else: diff --git a/benchmarks/transformers/prob_model_text_classification.py b/benchmarks/transformers/prob_model_text_classification.py index f07859ee..946e3deb 100644 --- a/benchmarks/transformers/prob_model_text_classification.py +++ b/benchmarks/transformers/prob_model_text_classification.py @@ -240,13 +240,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: try: 
logger.info(list(pathlib.Path(args.load_model_dir).rglob("*"))) - restore_checkpoint_path = unpack_model_tar( + restore_checkpoint_dir = unpack_model_tar( list(pathlib.Path(args.load_model_dir).rglob("*"))[0] ) - logger.info(list(pathlib.Path(restore_checkpoint_path).rglob("*"))) + logger.info(list(pathlib.Path(restore_checkpoint_dir).rglob("*"))) except: logger.info("No checkpoint to restore") - restore_checkpoint_path = None + restore_checkpoint_dir = None tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) @@ -400,11 +400,17 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: model_editor = None if args.enable_probit_model_editor: - probit_freeze_fun = lambda p, v: True if "classifier" in p else False if args.probit_last_layer_only else None + probit_freeze_fun = ( + lambda p, v: True + if "classifier" in p + else False + if args.probit_last_layer_only + else None + ) model_editor = ProbitModelEditor( freeze_fun=probit_freeze_fun, init_log_var=args.probit_init_log_var, - stop_gradient=args.probit_stop_gradient + stop_gradient=args.probit_stop_gradient, ) ### TRAINING @@ -438,7 +444,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: save_checkpoint_dir=args.output_data_dir, save_every_n_steps=args.save_every_n_steps, keep_top_n_checkpoints=args.keep_top_n_checkpoints, - restore_checkpoint_path=restore_checkpoint_path, + restore_checkpoint_dir=restore_checkpoint_dir, ), callbacks=[ ResetCovarianceCallback( @@ -469,7 +475,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: last_layer_optimizer = FitOptimizer( method=optimizer, n_epochs=args.num_train_epochs, freeze_fun=freeze_fun ) - if restore_checkpoint_path is not None: + if restore_checkpoint_dir is not None: fit_config.optimizer = last_layer_optimizer train_kwargs = {"fit_config": fit_config} else: @@ -494,11 +500,11 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: calib_data_loader=None, **train_kwargs, ) - elif restore_checkpoint_path is not None: - prob_model.load_state(restore_checkpoint_path) + elif restore_checkpoint_dir is not None: + prob_model.load_state(restore_checkpoint_dir) else: raise ValueError( - "Either restore_checkpoint_path or num_train_epochs > 0 should be specified." + "Either restore_checkpoint_dir or num_train_epochs > 0 should be specified." ) if args.enable_probit_model_editor: diff --git a/docs/source/references/output_calib_model/output_calib_model.rst b/docs/source/references/output_calib_model/output_calib_model.rst index 51334698..165a4019 100644 --- a/docs/source/references/output_calib_model/output_calib_model.rst +++ b/docs/source/references/output_calib_model/output_calib_model.rst @@ -8,19 +8,19 @@ Please find their references below. .. automodule:: fortuna.output_calib_model.classification :members: - :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint + :exclude-members: save_checkpoint, restore_checkpoint .. _output_calib_regressor: .. automodule:: fortuna.output_calib_model.regression :members: - :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint + :exclude-members: save_checkpoint, restore_checkpoint .. _output_calib_base: .. automodule:: fortuna.output_calib_model.base :members: - :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint + :exclude-members: save_checkpoint, restore_checkpoint .. 
toctree:: :maxdepth: 1 diff --git a/examples/scaling_up_bayesian_inference.pct.py b/examples/scaling_up_bayesian_inference.pct.py index 4cdb2952..3c4b2e3b 100644 --- a/examples/scaling_up_bayesian_inference.pct.py +++ b/examples/scaling_up_bayesian_inference.pct.py @@ -89,7 +89,7 @@ def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray: # We are ready to call `prob_model.train`, which will perform posterior inference under-the-hood. In order to do Bayesian inference on the last layer only and freeze the other parameters, all we need to do is to pass a function `freeze_fun` to the optimizer configuration object, deciding which parameters should be "frozen" and which should be "trainable". # -# In addition, we configure `map_fit_config` to make a preliminary run with MAP, and set the frozen parameters to a meaningful value. Alternatively, if any of these is available, you can also either restore an existing checkpoint by configuring `FitCheckpointer.restore_checkpoint_path`, or start from a current state by setting `FitCheckpointer.start_from_current_state` to `True`. +# In addition, we configure `map_fit_config` to make a preliminary run with MAP, and set the frozen parameters to a meaningful value. Alternatively, if any of these is available, you can also either restore an existing checkpoint by configuring `FitCheckpointer.restore_checkpoint_dir`, or start from a current state by setting `FitCheckpointer.start_from_current_state` to `True`. from fortuna.prob_model import FitConfig, FitOptimizer diff --git a/fortuna/calib_model/base.py b/fortuna/calib_model/base.py index 3c5a379a..f3176442 100644 --- a/fortuna/calib_model/base.py +++ b/fortuna/calib_model/base.py @@ -137,11 +137,11 @@ def _calibrate( rng=self.rng.get(), state=state, loss_fun=loss, - training_dataloader=calib_data_loader, + training_data_loader=calib_data_loader, training_dataset_size=n_calib_data, n_epochs=config.optimizer.n_epochs, metrics=config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=n_val_data, verbose=config.monitor.verbose, callbacks=config.callbacks, @@ -158,30 +158,28 @@ def _calibrate( logging.info("Calibration completed.") return status - def load_state(self, checkpoint_path: Path) -> None: + def load_state(self, checkpoint_dir: Path) -> None: """ Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the probabilistic model. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to a checkpoint file or directory to restore. """ try: - self.restore_checkpoint(checkpoint_path) + self.restore_checkpoint(checkpoint_dir) except ValueError: raise ValueError( - f"No checkpoint was found in `checkpoint_path={checkpoint_path}`." + f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`." 
) - self.predictive.state = CalibStateRepository(checkpoint_dir=checkpoint_path) + self.predictive.state = CalibStateRepository(checkpoint_dir=checkpoint_dir) - def save_state( - self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: return self.predictive.state.put( self.predictive.state.get(), - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, keep=keep_top_n_checkpoints, ) @@ -224,7 +222,7 @@ def _init(self, data_loader: DataLoader, config: Config): ) def _init_state(self, calib_data_loader: DataLoader, config: Config) -> CalibState: - if config.checkpointer.restore_checkpoint_path is None: + if config.checkpointer.restore_checkpoint_dir is None: if config.checkpointer.start_from_current_state: state = self.predictive.state.get(optimizer=config.optimizer.method) else: @@ -233,10 +231,10 @@ def _init_state(self, calib_data_loader: DataLoader, config: Config) -> CalibSta if config.checkpointer.start_from_current_state: logging.warning( "`config.checkpointer.start_from_current_state` will be ignored since " - "`config.checkpointer.restore_checkpoint_path` is given." + "`config.checkpointer.restore_checkpoint_dir` is given." ) state = self.restore_checkpoint( - restore_checkpoint_path=config.checkpointer.restore_checkpoint_path, + restore_checkpoint_dir=config.checkpointer.restore_checkpoint_dir, optimizer=config.optimizer.method, ) return state diff --git a/fortuna/calib_model/calib_mixin.py b/fortuna/calib_model/calib_mixin.py index 029e991e..d2577061 100644 --- a/fortuna/calib_model/calib_mixin.py +++ b/fortuna/calib_model/calib_mixin.py @@ -1,40 +1,42 @@ import os from typing import Optional -from flax.training import checkpoints - from fortuna.calib_model.state import CalibState -from fortuna.training.mixin import WithCheckpointingMixin +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.typing import ( OptaxOptimizer, Path, ) +# from flax.training import checkpoints -class WithCalibCheckpointingMixin(WithCheckpointingMixin): - def restore_checkpoint( - self, - restore_checkpoint_path: Path, - optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, - ) -> CalibState: - if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile( - restore_checkpoint_path - ): - raise ValueError( - f"`restore_checkpoint_path={restore_checkpoint_path}` was not found." - ) - d = checkpoints.restore_checkpoint( - ckpt_dir=str(restore_checkpoint_path), - target=None, - step=None, - prefix=prefix, - parallel=True, - ) - if d is None: - raise ValueError( - f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." - ) - return CalibState.init_from_dict(d, optimizer, **kwargs) + +class WithCalibCheckpointingMixin(WithCheckpointingMixin): + pass + # def restore_checkpoint( + # self, + # restore_checkpoint_dir: Path, + # optimizer: Optional[OptaxOptimizer] = None, + # prefix: str = "", + # **kwargs, + # ) -> CalibState: + # if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile( + # restore_checkpoint_dir + # ): + # raise ValueError( + # f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found." 
+ # ) + # d = checkpoints.restore_checkpoint( + # ckpt_dir=str(restore_checkpoint_dir), + # target=None, + # step=None, + # prefix=prefix, + # parallel=True, + # ) + # if d is None: + # raise ValueError( + # f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`." + # ) + # + # return CalibState.init_from_dict(d, optimizer, **kwargs) diff --git a/fortuna/calib_model/calib_model_calibrator.py b/fortuna/calib_model/calib_model_calibrator.py index 81c58aa3..5425119c 100644 --- a/fortuna/calib_model/calib_model_calibrator.py +++ b/fortuna/calib_model/calib_model_calibrator.py @@ -13,11 +13,9 @@ from optax._src.base import PyTree from fortuna.calib_model.state import CalibState -from fortuna.training.trainer import ( - JittedMixin, - MultiDeviceMixin, - TrainerABC, -) +from fortuna.training.mixins.jitted import JittedMixin +from fortuna.training.mixins.multi_device import MultiDeviceMixin +from fortuna.training.trainer import TrainerABC from fortuna.typing import ( Array, Batch, diff --git a/fortuna/calib_model/config/checkpointer.py b/fortuna/calib_model/config/checkpointer.py index 645e776d..18ed6bae 100644 --- a/fortuna/calib_model/config/checkpointer.py +++ b/fortuna/calib_model/config/checkpointer.py @@ -7,7 +7,7 @@ class Checkpointer: def __init__( self, save_checkpoint_dir: Optional[Path] = None, - restore_checkpoint_path: Optional[Path] = None, + restore_checkpoint_dir: Optional[Path] = None, start_from_current_state: bool = False, save_every_n_steps: Optional[int] = None, keep_top_n_checkpoints: Optional[int] = 2, @@ -20,10 +20,10 @@ def __init__( ---------- save_checkpoint_dir: Optional[Path] = None Save directory location. - restore_checkpoint_path: Optional[Path] + restore_checkpoint_dir: Optional[Path] Path to checkpoint file or directory to restore. start_from_current_state: bool = False - If True, the optimization will start from the current state. If `restore_checkpoint_path` is given, then + If True, the optimization will start from the current state. If `restore_checkpoint_dir` is given, then `start_from_current_state` is ignored. save_every_n_steps: int Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no @@ -36,7 +36,7 @@ def __init__( """ self.save_checkpoint_dir = save_checkpoint_dir self.save_every_n_steps = save_every_n_steps - self.restore_checkpoint_path = restore_checkpoint_path + self.restore_checkpoint_dir = restore_checkpoint_dir self.start_from_current_state = start_from_current_state self.keep_top_n_checkpoints = keep_top_n_checkpoints self.dump_state = dump_state diff --git a/fortuna/data/dataset/huggingface_datasets.py b/fortuna/data/dataset/huggingface_datasets.py index 7d556459..60521779 100644 --- a/fortuna/data/dataset/huggingface_datasets.py +++ b/fortuna/data/dataset/huggingface_datasets.py @@ -112,12 +112,12 @@ def get_data_loader( drop_last: bool if True, the last batch (which potentially is smaller then the default batch size) is dropped. verbose: bool - Whether to show a progress bar while iterating over the dataloader or not. + Whether to show a progress bar while iterating over the data_loader or not. 
Returns ------- HuggingFaceDataLoader - The dataloader + The data_loader """ iterable = IterableData.from_callable( lambda *args, **kwargs: self._get_data_loader( diff --git a/fortuna/data/loader/base.py b/fortuna/data/loader/base.py index 2f7aed43..d34799ce 100644 --- a/fortuna/data/loader/base.py +++ b/fortuna/data/loader/base.py @@ -9,13 +9,19 @@ Tuple, Type, TypeVar, + Union, ) from flax import jax_utils import jax +from jax.sharding import PartitionSpec from jax.tree_util import tree_map -from fortuna.data.loader.utils import IterableData +from fortuna.data.loader.utils import ( + IterableData, + prefetch_to_mesh, +) +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.typing import ( Array, Batch, @@ -185,7 +191,7 @@ def from_tensorflow_data_loader(cls: Type[T], tf_data_loader) -> T: T A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`. """ - return cls(iterable=IterableData.from_tf_dataloader(tf_data_loader)) + return cls(iterable=IterableData.from_tf_data_loader(tf_data_loader)) @classmethod def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T: @@ -203,7 +209,7 @@ def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T: T A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`. """ - return cls(iterable=IterableData.from_torch_dataloader(torch_data_loader)) + return cls(iterable=IterableData.from_torch_data_loader(torch_data_loader)) @classmethod def from_inputs_loaders( @@ -545,3 +551,24 @@ def __iter__(self, *args, **kwargs): loader = map(lambda batch: tree_map(self._reshape_inputs, batch), self._loader) loader = jax_utils.prefetch_to_device(loader, 2) yield from loader + + +class ShardedPrefetchedLoader: + def __init__( + self, + loader, + partition_manager: Optional[PartitionManager] = None, + partition_spec: Optional[PartitionSpec] = None, + ): + self._loader = loader + self.partition_manager = partition_manager + self.partition_spec = partition_spec + + def __iter__(self, *args, **kwargs): + loader = prefetch_to_mesh( + iter(self._loader), + 2, + self.partition_manager.partitioner.mesh, + self.partition_spec, + ) + yield from loader diff --git a/fortuna/data/loader/huggingface_loaders.py b/fortuna/data/loader/huggingface_loaders.py index 3a89801c..39e44461 100644 --- a/fortuna/data/loader/huggingface_loaders.py +++ b/fortuna/data/loader/huggingface_loaders.py @@ -35,7 +35,7 @@ def __init__( Parameters ---------- iterable : Union[Iterable[Dict[str, Array]], Iterable[Tuple[Dict[str, Array],Array]]] - A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_dataloader`. + A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_data_loader`. 
num_unique_labels: int Number of unique target labels in the task (classification only) num_inputs: Optional[int] diff --git a/fortuna/data/loader/utils.py b/fortuna/data/loader/utils.py index 6bba7211..28730813 100644 --- a/fortuna/data/loader/utils.py +++ b/fortuna/data/loader/utils.py @@ -1,6 +1,8 @@ from __future__ import annotations +import collections from copy import deepcopy +import itertools from itertools import zip_longest from typing import ( Iterable, @@ -9,6 +11,12 @@ Union, ) +import jax +from jax.sharding import ( + Mesh, + NamedSharding, + PartitionSpec, +) import numpy as np from fortuna.typing import ( @@ -44,9 +52,9 @@ def _inner(): return cls(_inner) @classmethod - def from_tf_dataloader(cls, tf_dataloader) -> IterableData: + def from_tf_data_loader(cls, tf_data_loader) -> IterableData: def _inner(): - for batch_inputs, batch_targets in tf_dataloader: + for batch_inputs, batch_targets in tf_data_loader: if not isinstance(batch_inputs, dict): batch_inputs = batch_inputs.numpy() else: @@ -57,9 +65,9 @@ def _inner(): return cls(_inner) @classmethod - def from_torch_dataloader(cls, torch_dataloader) -> IterableData: + def from_torch_data_loader(cls, torch_data_loader) -> IterableData: def _inner(): - for batch_inputs, batch_targets in torch_dataloader: + for batch_inputs, batch_targets in torch_data_loader: if not isinstance(batch_inputs, dict): batch_inputs = batch_inputs.numpy() else: @@ -218,3 +226,29 @@ def _prefetch(self): if not self._ready: self._batch = self._generator.__next__() self._ready = True + + +def prefetch_to_mesh(iterator, size: int, mesh: Mesh, xs_spec): + queue = collections.deque() + + def _prefetch(xs): + return jax.device_put( + xs, + NamedSharding( + mesh, + xs_spec + if xs_spec is not None + else xs.sharding.spec + if hasattr(xs, "sharding") + else PartitionSpec(), + ), + ) + + def enqueue(n): # Enqueues *up to* `n` elements from the iterator. + for data in itertools.islice(iterator, n): + queue.append(jax.tree_util.tree_map(_prefetch, data)) + + enqueue(size) # Fill up the buffer. 
+ while queue: + yield queue.popleft() + enqueue(1) diff --git a/fortuna/likelihood/base.py b/fortuna/likelihood/base.py index 118439f3..0f1e60b9 100644 --- a/fortuna/likelihood/base.py +++ b/fortuna/likelihood/base.py @@ -8,6 +8,7 @@ Union, ) +from flax.core import FrozenDict from jax import ( jit, pmap, @@ -214,9 +215,10 @@ def _batched_log_joint_prob( mutable=mutable, rng=rng, ) - if "mutable" in return_aux: + if train and mutable is not None: outputs, aux = outs - mutable = aux["mutable"] + if mutable in return_aux: + mutable = aux["mutable"] else: outputs = outs @@ -238,11 +240,11 @@ def _batched_log_joint_prob( and "calib_mutable" in return_aux ): outputs, aux["calib_mutable"] = outs - aux["calib_mutable"] = dict(output_calibrator=aux["calib_mutable"]) + aux["calib_mutable"] = FrozenDict(output_calibrator=aux["calib_mutable"]) else: outputs = outs if "calib_mutable" in return_aux: - aux["calib_mutable"] = dict(output_calibrator=None) + aux["calib_mutable"] = FrozenDict(output_calibrator=None) log_joint_prob = jnp.sum( self.prob_output_layer.log_prob(outputs, targets, train=train, **kwargs) diff --git a/fortuna/model/llama.py b/fortuna/model/llama.py new file mode 100644 index 00000000..49e2427d --- /dev/null +++ b/fortuna/model/llama.py @@ -0,0 +1,1268 @@ +from functools import partial +import json +import os +from shutil import copyfile +import tempfile +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + Union, +) + +from flax.core.frozen_dict import ( + FrozenDict, + freeze, + unfreeze, +) +import flax.linen as nn +from flax.linen import ( + combine_masks, + make_causal_mask, + partitioning as nn_partitioning, +) +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import ( + flatten_dict, + unflatten_dict, +) +import jax +from jax import lax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as PS +import numpy as np +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxCausalLMOutput, +) +from transformers.modeling_flax_utils import FlaxPreTrainedModel +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + +LLAMA_STANDARD_CONFIGS = { + "3b": { + "vocab_size": 32000, + "hidden_size": 3200, + "intermediate_size": 8640, + "num_hidden_layers": 26, + "num_attention_heads": 32, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, + "7b": { + "vocab_size": 32000, + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, + "13b": { + "vocab_size": 32000, + "hidden_size": 5120, + "intermediate_size": 13824, + "num_hidden_layers": 40, + "num_attention_heads": 40, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, + "30b": { + "vocab_size": 32000, + "hidden_size": 6656, + "intermediate_size": 17920, + "num_hidden_layers": 60, + "num_attention_heads": 52, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, + "65b": { + "vocab_size": 32000, + "hidden_size": 8192, + 
"intermediate_size": 22016, + "num_hidden_layers": 80, + "num_attention_heads": 64, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-5, + "use_cache": True, + "tie_word_embeddings": False, + }, + "debug": { # A small model for debugging + "vocab_size": 32000, + "hidden_size": 128, + "intermediate_size": 256, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, + }, +} + + +class LLaMAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~LLaMAModel`] or [`~TFLLaMAModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_sequence_length (`int`, *optional*, defaults to 2048): + Max sequence length for model (for RoPE computation) + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + ```python + >>> from transformers import LLaMAModel, LLaMAConfig + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LLaMAConfig() + >>> # Initializing a model from the llama-7b style configuration + >>> model = LLaMAModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "llama" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + max_sequence_length=2048, + rms_norm_eps=1e-6, + initializer_range=0.02, + use_cache=True, + # pad_token_id=-1, + bos_token_id=0, + eos_token_id=1, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + tie_word_embeddings=False, + remat_block="nothing_saveable", + remat_attention="", + remat_mlp="", + fcm_min_ratio=0.0, + fcm_max_ratio=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.remat_block = remat_block + self.remat_attention = remat_attention + self.remat_mlp = remat_mlp + self.fcm_min_ratio = fcm_min_ratio + self.fcm_max_ratio = fcm_max_ratio + super().__init__( + # pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class RMSNorm(nn.Module): + dim: int + eps: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.weight = self.param( + "kernel", + nn.initializers.ones, + (self.dim,), + self.param_dtype, + ) + + def _norm(self, x: jnp.ndarray) -> jnp.ndarray: + return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) + output = self._norm(x).astype(self.dtype) + weight = jnp.asarray(self.weight, self.dtype) + return output * weight + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32 +) -> jnp.ndarray: + freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) + t = np.arange(end) # type: ignore + freqs = np.outer(t, freqs).astype(dtype) # type: ignore + sin, cos = np.sin(freqs), np.cos(freqs) + freqs_cis = np.complex64(cos + 1j * sin) + return jnp.asarray(freqs_cis) + + +def apply_rotary_emb( + xq: jnp.ndarray, + xk: jnp.ndarray, + freqs_cis: jnp.ndarray, + dtype: jnp.dtype = jnp.float32, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) + reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) + + xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) + xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) + + # add head dim + freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:])) + + xq_out = xq_ * freqs_cis + xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape( + *xq_out.shape[:-1], -1 + ) + + xk_out = xk_ * freqs_cis + xk_out 
= jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape( + *xk_out.shape[:-1], -1 + ) + + return xq_out.astype(dtype), xk_out.astype(dtype) + + +class FlaxLLaMAAttention(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + config = self.config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.wq = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wk = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wv = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wo = nn.Dense( + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + + self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) + + self.causal_mask = make_causal_mask( + jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool" + ) + + self.freqs_cis = precompute_freqs_cis( + self.head_dim, + config.max_sequence_length * 2, + dtype=self.dtype, + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape( + hidden_states.shape[:2] + (self.num_heads, self.head_dim) + ) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable( + "cache", "cached_key", jnp.zeros, key.shape, key.dtype + ) + cached_value = self.variable( + "cache", "cached_value", jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32) + ) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + position_ids, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + fcm_mask=None, + ): + xq, xk, xv = ( + self.wq(hidden_states), + self.wk(hidden_states), + self.wv(hidden_states), + ) + + xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp")) + xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp")) + xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp")) + + xq = self._split_heads(xq) + xk = self._split_heads(xk) + xv = self._split_heads(xv) + + freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype) + + query_length, key_length = xq.shape[1], xk.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length), + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to( + causal_mask, (batch_size,) + causal_mask.shape[1:] + ) + + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) + + dropout_rng = None + if not deterministic and self.config.attn_pdrop > 0.0: + dropout_rng = self.make_rng("dropout") + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
+ if self.has_variable("cache", "cached_key") or init_cache: + xk, xv, attention_mask = self._concatenate_to_cache( + xk, xv, xq, attention_mask + ) + + # transform boolean mask into float mask + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype( + self.dtype + ), + ) + + # usual dot product attention + attn_weights = dot_product_attention_weights( + xq, + xk, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attn_pdrop, + deterministic=deterministic, + dtype=jnp.promote_types(self.dtype, jnp.float32), + precision=self.precision, + ) + attn_weights = with_sharding_constraint( + attn_weights, PS(("dp", "fsdp"), "mp", None, None) + ) + + attn_output = jnp.einsum( + "...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision + ) + attn_output = self._merge_heads(attn_output) + attn_output = self.wo(attn_output) + attn_output = self.resid_dropout(attn_output, deterministic=deterministic) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxLLaMAMLP(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self) -> None: + config = self.config + + self.w1 = nn.Dense( + config.intermediate_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.w2 = nn.Dense( + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.w3 = nn.Dense( + config.intermediate_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.dropout = nn.Dropout(rate=self.config.resid_pdrop) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + x = self.w2(nn.silu(self.w1(x)) * self.w3(x)) + x = self.dropout(x, deterministic=deterministic) + return x + + +class FlaxLLaMABlock(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self) -> None: + attention_module = FlaxLLaMAAttention + mlp_module = FlaxLLaMAMLP + if self.config.remat_attention != "": + attention_module = remat( + FlaxLLaMAAttention, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_attention), + ) + if self.config.remat_mlp != "": + mlp_module = remat( + FlaxLLaMAMLP, + static_argnums=(1,), + policy=get_gradient_checkpoint_policy(self.config.remat_mlp), + ) + + self.attention = attention_module( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.feed_forward = mlp_module( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.attention_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.ffn_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + def __call__( + self, + 
hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + fcm_mask: Optional[jnp.ndarray] = None, + ): + attn_outputs = self.attention( + self.attention_norm(hidden_states), + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + attn_output = attn_outputs[0] + hidden_states = hidden_states + attn_output + + feed_forward_hidden_states = self.feed_forward( + self.ffn_norm(hidden_states), + deterministic, + ) + hidden_states = hidden_states + feed_forward_hidden_states + + return (hidden_states,) + attn_outputs[1:] + + +class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LLaMAConfig + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: LLaMAConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__( + config, + module, + input_shape=input_shape, + seed=seed, + dtype=dtype, + _do_init=_do_init, + ) + + def init_weights( + self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None + ) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape + ) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, position_ids, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + input_ids, + attention_mask, + position_ids, + return_dict=False, + init_cache=True, + ) + return init_variables["cache"] + + @add_start_docstrings_to_model_forward("") + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + params: dict = None, + past_key_values: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError( + "Make sure to provide `position_ids` when passing `past_key_values`." + ) + + position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxLLaMABlockCollection(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + block = FlaxLLaMABlock + if self.config.remat_block != "": + block = remat( + FlaxLLaMABlock, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_block), + ) + self.blocks = [ + block( + self.config, + name=str(i), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if not deterministic and self.config.fcm_max_ratio > 0: + # Apply forgetful causal mask + batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] + fcm_ratio = jax.random.uniform( + self.make_rng("fcm"), + shape=(batch_size, 1, 1, 1), + minval=self.config.fcm_min_ratio, + maxval=self.config.fcm_max_ratio, + ) + fcm_mask = ( + jax.random.uniform( + self.make_rng("fcm"), shape=(batch_size, 1, seq_length, seq_length) + ) + > fcm_ratio + ) + fcm_mask = fcm_mask.at[:, :, :, 0].set(True) + fcm_mask = fcm_mask.astype("bool") + else: + fcm_mask = None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxGPTJModule` will filter them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + + +class FlaxLLaMAModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.embed_dim = self.config.hidden_size + + self.wte = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.dropout = nn.Dropout(rate=self.config.embd_pdrop) + self.h = FlaxLLaMABlockCollection( + self.config, + dtype=self.dtype, + 
param_dtype=self.param_dtype, + precision=self.precision, + ) + self.ln_f = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + input_embeds = self.wte(input_ids.astype("i4")) + + hidden_states = self.dropout(input_embeds, deterministic=deterministic) + + outputs = self.h( + hidden_states, + attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +@add_start_docstrings("", "") +class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel): + module_class = FlaxLLaMAModule + + +# append_call_sample_docstring( +# FlaxLLaMAModel, +# _TOKENIZER_FOR_DOC, +# _CHECKPOINT_FOR_DOC, +# FlaxCausalLMOutput, +# _CONFIG_FOR_DOC, +# ) + + +class FlaxLLaMAForCausalLMModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + precision=self.precision, + ) + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + batch_size, seq_length = input_ids.shape + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + position_ids = jnp.broadcast_to( + jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), + (batch_size, seq_length), + ) + outputs = self.transformer( + input_ids, + attention_mask, + position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T + lm_logits = self.lm_head.apply( + {"params": {"kernel": shared_kernel}}, hidden_states + ) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput( + logits=lm_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("", "") +class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel): + module_class = FlaxLLaMAForCausalLMModule + + def prepare_inputs_for_generation( + self, input_ids, max_length, attention_mask: 
Optional[jnp.DeviceArray] = None + ): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since GPTJ uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0) + ) + else: + position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +# append_call_sample_docstring( +# FlaxGPTJForCausalLM, +# _TOKENIZER_FOR_DOC, +# _CHECKPOINT_FOR_DOC, +# FlaxCausalLMOutput, +# _CONFIG_FOR_DOC, +# ) + + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = {} + + +class LLaMATokenizer(PreTrainedTokenizer): + """ + Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding. + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=False, + add_eos_token=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + super().__init__( + bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs + ) + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + + with tempfile.NamedTemporaryFile() as tfile: + with open_file(self.vocab_file, "rb") as fin: + tfile.write(fin.read()) + tfile.flush() + tfile.seek(0) + self.sp_model.Load(tfile.name) + """ Initialisation""" + self.add_special_tokens( + dict( + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + ) + ) + self.pad_token_id = self.unk_token_id + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + @property + def bos_token_id(self) -> Optional[int]: + return self.sp_model.bos_id() + + @property + def eos_token_id(self) -> Optional[int]: + return self.sp_model.eos_id() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token 
(str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) into a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def save_vocabulary( + self, save_directory, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if self.add_bos_token: + bos_token_ids = [self.bos_token_id] + else: + bos_token_ids = [] + + output = bos_token_ids + token_ids_0 + + if token_ids_1 is not None: + output = output + token_ids_1 + + if self.add_eos_token: + output = output + [self.eos_token_id] + + return output + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. LLaMA does not make + use of token type ids, therefore a list of zeros is returned. + Args: + token_ids_0 (`List[int]`): + List of IDs. 
+ token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] diff --git a/fortuna/model/model_manager/base.py b/fortuna/model/model_manager/base.py index 77f90773..0e8210c6 100755 --- a/fortuna/model/model_manager/base.py +++ b/fortuna/model/model_manager/base.py @@ -9,14 +9,15 @@ from flax import linen as nn from flax.core import FrozenDict -from flax.training.checkpoints import PyTree from jax._src.prng import PRNGKeyArray import jax.numpy as jnp +from optax._src.base import PyTree from fortuna.typing import ( InputData, Mutable, Params, + Shape, ) from fortuna.utils.random import WithRNG @@ -67,14 +68,14 @@ def apply( @abc.abstractmethod def init( - self, input_shape: Tuple[int, ...], rng: Optional[PRNGKeyArray] = None, **kwargs + self, input_shape: Shape, rng: Optional[PRNGKeyArray] = None, **kwargs ) -> Dict[str, Mapping]: """ Initialize random parameters and mutable objects. Parameters ---------- - input_shape : Tuple + input_shape : Shape The shape of the input variable. rng: Optional[PRNGKeyArray] A random number generator. diff --git a/fortuna/model/model_manager/classification.py b/fortuna/model/model_manager/classification.py index df4c85fe..f4566fd7 100644 --- a/fortuna/model/model_manager/classification.py +++ b/fortuna/model/model_manager/classification.py @@ -10,11 +10,11 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree import jax from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp +from optax._src.base import PyTree from fortuna.model.model_manager.base import ModelManager from fortuna.model.utils.random_features import RandomFeatureGaussianProcess diff --git a/fortuna/model/model_manager/regression.py b/fortuna/model/model_manager/regression.py index 2458c6d6..2fef5659 100644 --- a/fortuna/model/model_manager/regression.py +++ b/fortuna/model/model_manager/regression.py @@ -7,11 +7,11 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree import jax from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp +from optax._src.base import PyTree from fortuna.model.model_manager.base import ModelManager from fortuna.typing import ( @@ -65,6 +65,7 @@ def apply( lik_log_var_rngs = None if mutable is not None: + mutable = mutable.unfreeze() mutable["model"] = mutable.get("model") mutable["lik_log_var"] = mutable.get("lik_log_var") @@ -102,12 +103,12 @@ def apply_fn(p, x, m_mutable, llv_mutable): self._check_outputs(model_outputs, lik_log_var_outputs) aux = dict() - if m_mutable or llv_mutable: + if train and (m_mutable or llv_mutable): aux["mutable"] = dict() if m_mutable: aux["mutable"]["model"] = m_mutable if llv_mutable: - aux["mutable"]["lik_log_var"] = m_mutable + aux["mutable"]["lik_log_var"] = llv_mutable return jnp.concatenate( (model_outputs, lik_log_var_outputs), axis=-1 diff --git a/fortuna/model/model_manager/transformers/classification.py b/fortuna/model/model_manager/transformers/classification.py index 6b793433..4941b2a1 100644 --- a/fortuna/model/model_manager/transformers/classification.py +++ b/fortuna/model/model_manager/transformers/classification.py @@ -8,13 +8,13 @@ from flax import linen as nn from flax.core import FrozenDict -from flax.training.checkpoints import 
PyTree import jax from jax import ( numpy as jnp, random, ) from jax._src.prng import PRNGKeyArray +from optax._src.base import PyTree from fortuna.model.model_manager.classification import ( ClassificationModelManager, diff --git a/fortuna/output_calib_model/base.py b/fortuna/output_calib_model/base.py index 97f017cc..6007148b 100644 --- a/fortuna/output_calib_model/base.py +++ b/fortuna/output_calib_model/base.py @@ -10,19 +10,17 @@ from fortuna.output_calib_model.config.base import Config from fortuna.output_calib_model.loss import Loss -from fortuna.output_calib_model.output_calib_mixin import ( - WithOutputCalibCheckpointingMixin, +from fortuna.output_calib_model.output_calib_state_repository import ( + OutputCalibStateRepository, ) -from fortuna.output_calib_model.output_calib_model_calibrator import ( +from fortuna.output_calib_model.output_calibrator.output_calib_model_calibrator import ( JittedOutputCalibModelCalibrator, MultiDeviceOutputCalibModelCalibrator, OutputCalibModelCalibrator, ) -from fortuna.output_calib_model.output_calib_state_repository import ( - OutputCalibStateRepository, -) from fortuna.output_calib_model.state import OutputCalibState from fortuna.output_calibrator.output_calib_manager.state import OutputCalibManagerState +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.typing import ( Array, Outputs, @@ -34,7 +32,7 @@ from fortuna.utils.random import RandomNumberGenerator -class OutputCalibModel(WithOutputCalibCheckpointingMixin, abc.ABC): +class OutputCalibModel(WithCheckpointingMixin, abc.ABC): """ Abstract calibration model class. """ @@ -90,7 +88,7 @@ def _calibrate( early_stopping_patience=config.monitor.early_stopping_patience, ) - if config.checkpointer.restore_checkpoint_path is None: + if config.checkpointer.restore_checkpoint_dir is None: state = OutputCalibManagerState.init_from_dict( d=FrozenDict( output_calibrator=self.output_calib_manager.init( @@ -105,7 +103,7 @@ def _calibrate( ) else: state = self.restore_checkpoint( - config.checkpointer.restore_checkpoint_path, + config.checkpointer.restore_checkpoint_dir, optimizer=config.optimizer.method, ) @@ -133,35 +131,33 @@ def _calibrate( ) return status - def load_state(self, checkpoint_path: Path) -> None: + def load_state(self, checkpoint_dir: Path) -> None: """ Load a calibration state from a checkpoint path. The checkpoint must be compatible with the calibration model. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to a checkpoint file or directory to restore. """ try: - self.restore_checkpoint(checkpoint_path) + self.restore_checkpoint(checkpoint_dir) except ValueError: raise ValueError( - f"No checkpoint was found in `checkpoint_path={checkpoint_path}`." + f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`." ) self.predictive.state = OutputCalibStateRepository( - checkpoint_dir=checkpoint_path + checkpoint_dir=checkpoint_dir ) - def save_state( - self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the calibration state as a checkpoint. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to file or directory where to save the current state. keep_top_n_checkpoints : int Number of past checkpoint files to keep. 
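For orientation, a minimal usage sketch of the renamed `checkpoint_dir` checkpointing API of the output-calibration model; the `OutputCalibClassifier` instance, the skipped calibration call, and the directory name below are illustrative placeholders rather than part of this patch:

from fortuna.output_calib_model import OutputCalibClassifier

calib_model = OutputCalibClassifier()
# ... calibrate on held-out outputs and targets, then persist the state ...
calib_model.save_state(checkpoint_dir="ckpts/output_calib", keep_top_n_checkpoints=1)
# later, restore the same state from that directory
calib_model.load_state(checkpoint_dir="ckpts/output_calib")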
@@ -172,6 +168,6 @@ def save_state( ) return self.predictive.state.put( self.predictive.state.get(), - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, keep=keep_top_n_checkpoints, ) diff --git a/fortuna/output_calib_model/config/checkpointer.py b/fortuna/output_calib_model/config/checkpointer.py index 1a63b0a7..63946ddf 100644 --- a/fortuna/output_calib_model/config/checkpointer.py +++ b/fortuna/output_calib_model/config/checkpointer.py @@ -7,7 +7,7 @@ class Checkpointer: def __init__( self, save_checkpoint_dir: Optional[Path] = None, - restore_checkpoint_path: Optional[Path] = None, + restore_checkpoint_dir: Optional[Path] = None, save_every_n_steps: Optional[int] = None, keep_top_n_checkpoints: Optional[int] = 2, dump_state: bool = False, @@ -19,7 +19,7 @@ def __init__( ---------- save_checkpoint_dir: Optional[Path] = None Save directory location. - restore_checkpoint_path: Optional[Path] + restore_checkpoint_dir: Optional[Path] Path to checkpoint file or directory to restore. save_every_n_steps: int Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no @@ -32,6 +32,6 @@ def __init__( """ self.save_checkpoint_dir = save_checkpoint_dir self.save_every_n_steps = save_every_n_steps - self.restore_checkpoint_path = restore_checkpoint_path + self.restore_checkpoint_dir = restore_checkpoint_dir self.keep_top_n_checkpoints = keep_top_n_checkpoints self.dump_state = dump_state diff --git a/fortuna/output_calib_model/output_calib_mixin.py b/fortuna/output_calib_model/output_calib_mixin.py deleted file mode 100644 index a0551947..00000000 --- a/fortuna/output_calib_model/output_calib_mixin.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -from typing import Optional - -from flax.training import checkpoints - -from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.mixin import WithCheckpointingMixin -from fortuna.typing import ( - OptaxOptimizer, - Path, -) - - -class WithOutputCalibCheckpointingMixin(WithCheckpointingMixin): - def restore_checkpoint( - self, - restore_checkpoint_path: Path, - optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, - ) -> OutputCalibState: - if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile( - restore_checkpoint_path - ): - raise ValueError( - f"`restore_checkpoint_path={restore_checkpoint_path}` was not found." - ) - d = checkpoints.restore_checkpoint( - ckpt_dir=str(restore_checkpoint_path), - target=None, - step=None, - prefix=prefix, - parallel=True, - ) - if d is None: - raise ValueError( - f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." 
- ) - - return OutputCalibState.init_from_dict(d, optimizer, **kwargs) diff --git a/fortuna/output_calib_model/output_calibrator/__init__.py b/fortuna/output_calib_model/output_calibrator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/output_calib_model/output_calib_model_calibrator.py b/fortuna/output_calib_model/output_calibrator/base.py similarity index 98% rename from fortuna/output_calib_model/output_calib_model_calibrator.py rename to fortuna/output_calib_model/output_calibrator/base.py index d18ea327..4ebaba36 100644 --- a/fortuna/output_calib_model/output_calib_model_calibrator.py +++ b/fortuna/output_calib_model/output_calibrator/base.py @@ -31,11 +31,9 @@ TargetsLoader, ) from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.mixin import ( - InputValidatorMixin, - WithCheckpointingMixin, - WithEarlyStoppingMixin, -) +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin +from fortuna.training.mixins.input_validator import InputValidatorMixin from fortuna.typing import ( Array, CalibMutable, @@ -442,7 +440,7 @@ def val_epoch_end( ) # early stopping improved = self.early_stopping_update(val_losses_and_metrics_current_epoch) - if improved and self.save_checkpoint_dir: + if improved and self.save_checkpoint_dir is not None: self.save_checkpoint(state, self.save_checkpoint_dir, force_save=True) return val_losses_and_metrics_current_epoch @@ -567,7 +565,7 @@ def save_checkpoint( save_checkpoint_dir: Path, keep: int = 1, force_save: bool = False, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: state = self.sync_mutable(state) state = jax.device_get(tree_map(lambda x: x[0], state)) diff --git a/fortuna/output_calibrator/output_calib_manager/base.py b/fortuna/output_calibrator/output_calib_manager/base.py index fcc4b50d..2ebff850 100644 --- a/fortuna/output_calibrator/output_calib_manager/base.py +++ b/fortuna/output_calibrator/output_calib_manager/base.py @@ -6,10 +6,10 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp +from optax._src.base import PyTree from fortuna.typing import ( Array, @@ -106,7 +106,9 @@ def init( rng, params_key, dropout_key = random.split(rng, 3) rngs = {"params": params_key, "dropout": dropout_key} return ( - self.output_calibrator.init(rngs, jnp.zeros((1, output_dim)), **kwargs) + FrozenDict( + self.output_calibrator.init(rngs, jnp.zeros((1, output_dim)), **kwargs) + ) if self.output_calibrator is not None else None ) diff --git a/fortuna/partitioner/__init__.py b/fortuna/partitioner/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/partitioner/base.py b/fortuna/partitioner/base.py new file mode 100644 index 00000000..1bf90b46 --- /dev/null +++ b/fortuna/partitioner/base.py @@ -0,0 +1,39 @@ +from typing import ( + Dict, + Optional, + Tuple, +) + +from jax.sharding import PartitionSpec + +from fortuna.utils.mesh import get_mesh +from fortuna.utils.port import is_port_in_use + + +class Partitioner: + def __init__( + self, + axis_dims: Optional[Dict[str, int]] = None, + rules: Optional[Dict[str, Tuple[str, ...]]] = None, + coordinator_address: Optional[str] = None, + n_devices: Optional[int] = None, + ): + if axis_dims is None: + axis_dims = {"dp": 1, "fsdp": 1, "mp": -1} + if rules is None: + rules = {} + self.specs = { + k: 
PartitionSpec(*v) if v is not None else PartitionSpec(None) + for k, v in rules.items() + } + self.mesh = get_mesh(axis_dims) + + if coordinator_address is None: + port = 8888 + while is_port_in_use(port): + port += 1 + self.coordinator_address = f"localhost/{port}" + else: + self.coordinator_address = coordinator_address + + self.n_devices = n_devices diff --git a/fortuna/partitioner/partition_manager/__init__.py b/fortuna/partitioner/partition_manager/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/partitioner/partition_manager/base.py b/fortuna/partitioner/partition_manager/base.py new file mode 100644 index 00000000..20eacb13 --- /dev/null +++ b/fortuna/partitioner/partition_manager/base.py @@ -0,0 +1,78 @@ +from typing import ( + Any, + Callable, + List, + Optional, +) + +from jax import ( + device_put, + eval_shape, + random, +) +from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import pjit +from jax.sharding import ( + NamedSharding, + PartitionSpec, +) +from jax.tree_util import ( + tree_map, + tree_map_with_path, +) + +from fortuna.partitioner.base import Partitioner +from fortuna.training.train_state import TrainState +from fortuna.utils.partition import match_partition_specs +from fortuna.utils.random import WithRNG + + +class PartitionManager(WithRNG): + def __init__(self, partitioner: Partitioner): + self.partitioner = partitioner + self._shapes_dtypes = None + self._shardings = None + + @property + def shapes_dtypes(self): + return self._shapes_dtypes + + @shapes_dtypes.setter + def shapes_dtypes(self, shapes_dtypes: TrainState): + self._shapes_dtypes = shapes_dtypes + partitions = match_partition_specs(self.partitioner.specs, self._shapes_dtypes) + self._shardings = tree_map( + lambda p: NamedSharding(mesh=self.partitioner.mesh, spec=p), partitions + ) + + @property + def shardings(self): + return self._shardings + + @shardings.setter + def shardings(self, shardings: Optional[TrainState]): + self._shardings = shardings + + def init_sharded_state(self, init_state_fn: Callable[[Any], TrainState], *args): + self.shapes_dtypes = eval_shape(init_state_fn, random.PRNGKey(0)) + + with self.partitioner.mesh: + return pjit( + init_state_fn, + in_shardings=PartitionSpec(), + out_shardings=self.shardings, + )(*args) + + def reshard( + self, state: TrainState, exclude: Optional[List[str]] = None + ) -> TrainState: + if self.shardings is not None: + if exclude is None: + exclude = [] + return tree_map_with_path( + lambda p, _v, s: device_put(_v, s) + if _v is not None and p[0].name not in exclude + else _v, + state, + self.shardings, + ) diff --git a/fortuna/prob_model/base.py b/fortuna/prob_model/base.py index 9f9f4e94..6b6c5e5f 100644 --- a/fortuna/prob_model/base.py +++ b/fortuna/prob_model/base.py @@ -6,17 +6,17 @@ Optional, ) -import jax +from jax import eval_shape import jax.numpy as jnp from fortuna.data.loader import DataLoader from fortuna.output_calib_model.state import OutputCalibState +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.calib_config.base import CalibConfig from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.prob_model_calibrator import ( - JittedProbModelOutputCalibrator, - MultiDeviceProbModelOutputCalibrator, ProbModelOutputCalibrator, + ShardedProbModelOutputCalibrator, ) from fortuna.typing import ( Array, @@ -24,7 +24,6 @@ Status, ) from fortuna.utils.data import check_data_loader_is_not_random -from fortuna.utils.device import 
select_trainer_given_devices from fortuna.utils.random import RandomNumberGenerator @@ -46,6 +45,7 @@ def __set_rng(self): self.joint.rng = self.rng self.posterior.rng = self.rng self.predictive.rng = self.rng + self.partition_manager.rng = self.rng def train( self, @@ -136,7 +136,7 @@ def _calibrate( "Pre-compute ensemble of outputs on the calibration data loader." ) - distribute = jax.local_devices()[0].platform != "cpu" + shard = not calib_config.processor.disable_jit ( calib_ensemble_outputs_loader, @@ -145,7 +145,7 @@ def _calibrate( inputs_loader=calib_data_loader.to_inputs_loader(), n_output_samples=calib_config.processor.n_posterior_samples, return_size=True, - distribute=distribute, + shard=shard, ) if calib_config.monitor.verbose: logging.info( @@ -156,22 +156,39 @@ def _calibrate( inputs_loader=val_data_loader.to_inputs_loader(), n_output_samples=calib_config.processor.n_posterior_samples, return_size=True, - distribute=distribute, + shard=shard, ) if val_data_loader is not None else (None, None) ) - trainer_cls = select_trainer_given_devices( - devices=calib_config.processor.devices, - base_trainer_cls=ProbModelOutputCalibrator, - jitted_trainer_cls=JittedProbModelOutputCalibrator, - multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator, - disable_jit=calib_config.processor.disable_jit, + output_calib_partition_manager = PartitionManager( + partitioner=self.partition_manager.partitioner + ) + + if calib_config.checkpointer.restore_checkpoint_dir is None: + calib_dict = self.posterior.state.extract_calib_keys() + state = OutputCalibState.init( + params=calib_dict["calib_params"], + mutable=calib_dict["calib_mutable"], + optimizer=calib_config.optimizer.method, + ) + else: + state = self.posterior.state.restore_checkpoint( + calib_config.checkpointer.restore_checkpoint_dir, + optimizer=calib_config.optimizer.method, + ) + output_calib_partition_manager.shapes_dtypes = eval_shape(lambda: state) + + trainer_cls = ( + ShardedProbModelOutputCalibrator + if not calib_config.processor.disable_jit + else ProbModelOutputCalibrator ) calibrator = trainer_cls( calib_outputs_loader=calib_ensemble_outputs_loader, + partition_manager=output_calib_partition_manager, val_outputs_loader=val_ensemble_outputs_loader, predict_fn=self.prob_output_layer.predict, uncertainty_fn=uncertainty_fn, @@ -185,20 +202,6 @@ def _calibrate( early_stopping_patience=calib_config.monitor.early_stopping_patience, ) - if calib_config.checkpointer.restore_checkpoint_path is None: - calib_dict = self.posterior.state.extract_calib_keys() - - state = OutputCalibState.init( - params=calib_dict["calib_params"], - mutable=calib_dict["calib_mutable"], - optimizer=calib_config.optimizer.method, - ) - else: - state = self.posterior.restore_checkpoint( - calib_config.checkpointer.restore_checkpoint_path, - optimizer=calib_config.optimizer.method, - ) - if calib_config.monitor.verbose: logging.info("Start calibration.") state, status = calibrator.train( @@ -225,7 +228,7 @@ def _calibrate( if calib_config.monitor.verbose: logging.info("Dump state to disk.") self.save_state( - checkpoint_path=calib_config.checkpointer.save_checkpoint_dir + checkpoint_dir=calib_config.checkpointer.save_checkpoint_dir ) if calib_config.monitor.verbose: @@ -233,29 +236,27 @@ def _calibrate( return status - def load_state(self, checkpoint_path: Path) -> None: + def load_state(self, checkpoint_dir: Path) -> None: """ Load the state of the posterior distribution from a checkpoint path. 
The checkpoint must be compatible with the probabilistic model. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to a checkpoint file or directory to restore. """ - return self.posterior.load_state(checkpoint_path) + return self.posterior.load_state(checkpoint_dir) - def save_state( - self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the posterior distribution state as a checkpoint. Parameters ---------- - checkpoint_path : Path + checkpoint_dir : Path Path to file or directory where to save the current state. keep_top_n_checkpoints : int Number of past checkpoint files to keep. """ - return self.posterior.save_state(checkpoint_path, keep_top_n_checkpoints) + return self.posterior.save_state(checkpoint_dir, keep_top_n_checkpoints) diff --git a/fortuna/prob_model/calib_config/checkpointer.py b/fortuna/prob_model/calib_config/checkpointer.py index 8ce9c9c3..ee50b54c 100644 --- a/fortuna/prob_model/calib_config/checkpointer.py +++ b/fortuna/prob_model/calib_config/checkpointer.py @@ -7,7 +7,7 @@ class CalibCheckpointer: def __init__( self, save_checkpoint_dir: Optional[Path] = None, - restore_checkpoint_path: Optional[Path] = None, + restore_checkpoint_dir: Optional[Path] = None, save_every_n_steps: Optional[int] = None, keep_top_n_checkpoints: Optional[int] = 2, dump_state: bool = False, @@ -19,7 +19,7 @@ def __init__( ---------- save_checkpoint_dir: Optional[Path] = None Save directory location. - restore_checkpoint_path: Optional[Path] + restore_checkpoint_dir: Optional[Path] Path to checkpoint file or directory to restore. save_every_n_steps: int Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no @@ -32,6 +32,6 @@ def __init__( """ self.save_checkpoint_dir = save_checkpoint_dir self.save_every_n_steps = save_every_n_steps - self.restore_checkpoint_path = restore_checkpoint_path + self.restore_checkpoint_dir = restore_checkpoint_dir self.keep_top_n_checkpoints = keep_top_n_checkpoints self.dump_state = dump_state diff --git a/fortuna/prob_model/classification.py b/fortuna/prob_model/classification.py index 8f157256..e41300f3 100644 --- a/fortuna/prob_model/classification.py +++ b/fortuna/prob_model/classification.py @@ -20,6 +20,8 @@ from fortuna.model_editor.base import ModelEditor from fortuna.output_calibrator.classification import ClassificationTemperatureScaler from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager +from fortuna.partitioner.base import Partitioner +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.base import ProbModel from fortuna.prob_model.calib_config.base import CalibConfig from fortuna.prob_model.fit_config.base import FitConfig @@ -53,6 +55,7 @@ def __init__( posterior_approximator: PosteriorApproximator = SWAGPosteriorApproximator(), output_calibrator: Optional[nn.Module] = ClassificationTemperatureScaler(), model_editor: Optional[ModelEditor] = None, + partitioner: Partitioner = Partitioner(), seed: int = 0, ): r""" @@ -77,6 +80,8 @@ def __init__( calibration parameters. model_editor : ModelEditor A model_editor object. It takes the forward pass and transforms the outputs. + partitioner : Partitioner + An object that configures how data and model parameters are sharded across devices, supporting data, fully-sharded data, and model parallelism. seed: int A random seed. 
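A rough sketch of how the new `partitioner` argument is intended to be passed; the Flax module `my_model`, the axis sizes, and the sharding rule below are placeholders, and the rule keys must match the parameter names of the model actually used:

from fortuna.partitioner.base import Partitioner
from fortuna.prob_model import ProbClassifier

partitioner = Partitioner(
    # "dp" = data parallel, "fsdp" = fully-sharded data parallel, "mp" = model parallel
    axis_dims={"dp": 2, "fsdp": 2, "mp": 1},
    # map parameter-name patterns to mesh axes (model-dependent)
    rules={"dense/kernel": ("fsdp", "mp")},
)
prob_model = ProbClassifier(model=my_model, partitioner=partitioner)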
@@ -130,10 +135,14 @@ def __init__( self.model_manager, self.prob_output_layer, self.output_calib_manager ) self.joint = Joint(self.prior, self.likelihood) - + self.partition_manager = PartitionManager(partitioner) self.posterior = getattr( PosteriorApproximations, posterior_approximator.__str__() - ).value(joint=self.joint, posterior_approximator=posterior_approximator) + ).value( + joint=self.joint, + posterior_approximator=posterior_approximator, + partition_manager=self.partition_manager, + ) self.predictive = ClassificationPredictive(self.posterior) super().__init__(seed=seed) diff --git a/fortuna/prob_model/fit_config/checkpointer.py b/fortuna/prob_model/fit_config/checkpointer.py index 95ef2e9d..ea3b448e 100644 --- a/fortuna/prob_model/fit_config/checkpointer.py +++ b/fortuna/prob_model/fit_config/checkpointer.py @@ -7,11 +7,12 @@ class FitCheckpointer: def __init__( self, save_checkpoint_dir: Optional[Path] = None, - restore_checkpoint_path: Optional[Path] = None, + restore_checkpoint_dir: Optional[Path] = None, start_from_current_state: bool = False, save_every_n_steps: Optional[int] = None, keep_top_n_checkpoints: Optional[int] = 2, dump_state: bool = False, + checkpoint_type: str = "last", ): """ An object to configure saving and restoring of checkpoints during the posterior fitting. @@ -20,10 +21,10 @@ def __init__( ---------- save_checkpoint_dir: Optional[Path] Save directory location. - restore_checkpoint_path: Optional[Path] + restore_checkpoint_dir: Optional[Path] Path to checkpoint file or directory to restore. start_from_current_state: bool = False - If True, the optimization will start from the current state. If `restore_checkpoint_path` is given, then + If True, the optimization will start from the current state. If `restore_checkpoint_dir` is given, then `start_from_current_state` is ignored. save_every_n_steps: int Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no @@ -33,10 +34,26 @@ def __init__( dump_state: bool Dump the fitted posterior state as a checkpoint in `save_checkpoint_dir`. Any future call to the state will internally involve restoring it from memory. + checkpoint_type: str + Which checkpoint type to pass to the state. + There are two possible options: + + - "last": this is the state obtained at the end of training. + - "best": this is the best checkpoint with respect to the metric monitored by early stopping. Notice that + this might be available only if validation data is provided, and both checkpoint saving and early + stopping are enabled. """ self.save_checkpoint_dir = save_checkpoint_dir self.save_every_n_steps = save_every_n_steps - self.restore_checkpoint_path = restore_checkpoint_path + self.restore_checkpoint_dir = restore_checkpoint_dir self.start_from_current_state = start_from_current_state self.keep_top_n_checkpoints = keep_top_n_checkpoints self.dump_state = dump_state + + allowed_checkpoint_types = ["last", "best"] + if checkpoint_type not in allowed_checkpoint_types: + raise ValueError( + f"`checkpoint_type={checkpoint_type}` not recognised. " + f"Please select one of the following options: {allowed_checkpoint_types}." 
+ ) + self.checkpoint_type = checkpoint_type diff --git a/fortuna/prob_model/joint/base.py b/fortuna/prob_model/joint/base.py index b113c548..3a34ec8a 100755 --- a/fortuna/prob_model/joint/base.py +++ b/fortuna/prob_model/joint/base.py @@ -6,6 +6,7 @@ ) from flax.core import FrozenDict +from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp @@ -154,7 +155,9 @@ def _batched_negative_log_joint_prob( return loss, aux return -outs - def init(self, input_shape: Shape, **kwargs) -> JointState: + def init( + self, input_shape: Shape, rng: Optional[PRNGKeyArray] = None, **kwargs + ) -> JointState: """ Initialize the state of the joint distribution. @@ -162,15 +165,19 @@ def init(self, input_shape: Shape, **kwargs) -> JointState: ---------- input_shape : Shape The shape of the input variable. + rng: Optional[PRNGKeyArray] + A random number generator key. Returns ------- A state of the joint distribution. """ + if rng is None: + rng = self.rng.get() + key1, key2 = random.split(rng) + oms = ModelManagerState.init_from_dict( - self.likelihood.model_manager.init( - input_shape, rng=self.rng.get(), **kwargs - ) + self.likelihood.model_manager.init(input_shape, rng=key1, **kwargs) ) inputs = get_inputs_from_shape(input_shape) outputs = self.likelihood.model_manager.apply( @@ -184,7 +191,7 @@ def init(self, input_shape: Shape, **kwargs) -> JointState: ocms = OutputCalibManagerState.init_from_dict( FrozenDict( output_calibrator=self.likelihood.output_calib_manager.init( - output_dim=output_dim + output_dim=output_dim, rng=key2 ) ) ) diff --git a/fortuna/prob_model/posterior/base.py b/fortuna/prob_model/posterior/base.py index 0676bb9f..ee98e5b4 100755 --- a/fortuna/prob_model/posterior/base.py +++ b/fortuna/prob_model/posterior/base.py @@ -10,12 +10,13 @@ from flax.core import FrozenDict from jax._src.prng import PRNGKeyArray +from orbax.checkpoint.checkpoint_manager import CheckpointManager from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint from fortuna.prob_model.joint.state import JointState -from fortuna.prob_model.posterior.posterior_mixin import WithPosteriorCheckpointingMixin from fortuna.prob_model.posterior.posterior_state_repository import ( PosteriorStateRepository, ) @@ -24,6 +25,7 @@ Path, Status, ) +from fortuna.utils.checkpoint import get_checkpoint_manager from fortuna.utils.freeze import get_trainable_paths from fortuna.utils.nested_dicts import ( nested_get, @@ -46,10 +48,15 @@ def posterior_method_kwargs(self) -> Dict[str, Any]: return {} -class Posterior(WithRNG, WithPosteriorCheckpointingMixin): +class Posterior(WithRNG): state = None - def __init__(self, joint: Joint, posterior_approximator: PosteriorApproximator): + def __init__( + self, + joint: Joint, + posterior_approximator: PosteriorApproximator, + partition_manager: PartitionManager, + ): r""" Posterior distribution class. This refers to :math:`p(w|\mathcal{D}, \phi)`, where :math:`w` are the random model parameters, :math:`\mathcal{D}` is a training data set and :math:`\phi` are calibration parameters. @@ -60,35 +67,41 @@ def __init__(self, joint: Joint, posterior_approximator: PosteriorApproximator): A joint distribution object. posterior_approximator: PosteriorApproximator A posterior approximator. + partition_manager: PartitionManager + An object to manage partitions. 
""" super().__init__() self.joint = joint self.posterior_approximator = posterior_approximator + self.partition_manager = partition_manager def _restore_state_from_somewhere( self, fit_config: FitConfig, allowed_states: Optional[Tuple[Type[PosteriorState], ...]] = None, + checkpoint_manager: Optional[CheckpointManager] = None, ) -> PosteriorState: - if fit_config.checkpointer.restore_checkpoint_path is not None: - state = self.restore_checkpoint( - restore_checkpoint_path=fit_config.checkpointer.restore_checkpoint_path, - optimizer=fit_config.optimizer.method, + if checkpoint_manager is not None: + repo = PosteriorStateRepository( + partition_manager=self.partition_manager, + checkpoint_manager=checkpoint_manager, ) + state = repo.get(optimizer=fit_config.optimizer.method) elif fit_config.checkpointer.start_from_current_state is not None: state = self.state.get(optimizer=fit_config.optimizer.method) if allowed_states is not None and not isinstance(state, allowed_states): raise ValueError( f"The type of the restored checkpoint must be within {allowed_states}. " - f"However, {fit_config.checkpointer.restore_checkpoint_path} pointed to a state " - f"with type {type(state)}." + f"However, the restored checkpoint has type {type(state)}." ) return state - def _init_joint_state(self, data_loader: DataLoader) -> JointState: - return self.joint.init(input_shape=data_loader.input_shape) + def _init_joint_state( + self, data_loader: DataLoader, rng: Optional[PRNGKeyArray] = None + ) -> JointState: + return self.joint.init(input_shape=data_loader.input_shape, rng=rng) @staticmethod def _freeze_optimizer_in_state( @@ -165,33 +178,29 @@ def sample(self, rng: Optional[PRNGKeyArray] = None, *args, **kwargs) -> JointSt """ pass - def load_state(self, checkpoint_path: Path) -> None: + def load_state(self, checkpoint_dir: Path) -> None: """ Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the current probabilistic model. Parameters ---------- - checkpoint_path: Path + checkpoint_dir: Path Path to checkpoint file or directory to restore. """ - try: - self.restore_checkpoint(checkpoint_path) - except ValueError: - raise ValueError( - f"No checkpoint was found in `checkpoint_path={checkpoint_path}`." - ) - self.state = PosteriorStateRepository(checkpoint_dir=checkpoint_path) + self.state = PosteriorStateRepository( + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager(checkpoint_dir=checkpoint_dir), + ) + self.partition_manager.shapes_dtypes = self.state.get_shapes_dtypes_checkpoint() - def save_state( - self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the state of the posterior distribution to a checkpoint directory. Parameters ---------- - checkpoint_path: Path + checkpoint_dir: Path Path to checkpoint file or directory to restore. keep_top_n_checkpoints: int Number of past checkpoint files to keep. @@ -203,7 +212,7 @@ def save_state( ) return self.state.put( self.state.get(), - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, keep=keep_top_n_checkpoints, ) @@ -216,10 +225,9 @@ def _check_fit_config(self, fit_config: FitConfig): "`save_checkpoint_dir` must be passed when `dump_state` is set to True." 
) - @staticmethod - def _is_state_available_somewhere(fit_config: FitConfig) -> bool: + def _is_state_available_somewhere(self, fit_config: FitConfig) -> bool: return ( - fit_config.checkpointer.restore_checkpoint_path is not None + fit_config.checkpointer.restore_checkpoint_dir is not None or fit_config.checkpointer.start_from_current_state ) @@ -234,7 +242,7 @@ def _warn_frozen_params_start_from_random( logging.warning( "Parameters frozen via `fit_config.optimizer.freeze_fun` will not be updated. To start " "from sensible frozen parameters, you should configure " - "`fit_config.checkpointer.restore_checkpoint_path`, or " + "`fit_config.checkpointer.restore_checkpoint_dir`, or " "`fit_config.checkpointer.start_from_current_state`, or `map_fit_config`. " "Otherwise, " "a randomly initialized configuration of frozen parameters will be returned." diff --git a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py b/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py index 6fe54e6f..7d230ce0 100755 --- a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py +++ b/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py @@ -149,11 +149,11 @@ def _fit(i): rng=self.rng.get(), state=_state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_size, verbose=fit_config.monitor.verbose, callbacks=fit_config.callbacks, @@ -267,12 +267,12 @@ def _restore_state_from_somewhere( fit_config: FitConfig, allowed_states: Optional[Tuple[Type[MAPState], ...]] = None, ) -> MAPState: - if fit_config.checkpointer.restore_checkpoint_path is not None: - restore_checkpoint_path = pathlib.Path( - fit_config.checkpointer.restore_checkpoint_path + if fit_config.checkpointer.restore_checkpoint_dir is not None: + restore_checkpoint_dir = pathlib.Path( + fit_config.checkpointer.restore_checkpoint_dir ) / str(i) state = self.restore_checkpoint( - restore_checkpoint_path=restore_checkpoint_path, + restore_checkpoint_dir=restore_checkpoint_dir, optimizer=fit_config.optimizer.method, ) elif fit_config.checkpointer.start_from_current_state is not None: @@ -281,7 +281,7 @@ def _restore_state_from_somewhere( if allowed_states is not None and not isinstance(state, allowed_states): raise ValueError( f"The type of the restored checkpoint must be within {allowed_states}. " - f"However, {fit_config.checkpointer.restore_checkpoint_path} pointed to a state " + f"However, {fit_config.checkpointer.restore_checkpoint_dir} pointed to a state " f"with type {type(state)}." ) diff --git a/fortuna/prob_model/posterior/laplace/laplace_posterior.py b/fortuna/prob_model/posterior/laplace/laplace_posterior.py index ef852e6c..0103fd88 100755 --- a/fortuna/prob_model/posterior/laplace/laplace_posterior.py +++ b/fortuna/prob_model/posterior/laplace/laplace_posterior.py @@ -252,7 +252,7 @@ def fit( raise ValueError( "The Laplace approximation must start from a preliminary run of MAP or an existing " "checkpoint or state. Please configure `map_fit_config`, or " - "`fit_config.checkpointer.restore_checkpoint_path`, " + "`fit_config.checkpointer.restore_checkpoint_dir`, " "or `fit_config.checkpointer.start_from_current_state`." 
) diff --git a/fortuna/prob_model/posterior/map/map_posterior.py b/fortuna/prob_model/posterior/map/map_posterior.py index c86be8c1..3b5a5649 100755 --- a/fortuna/prob_model/posterior/map/map_posterior.py +++ b/fortuna/prob_model/posterior/map/map_posterior.py @@ -1,9 +1,12 @@ import logging +from pathlib import Path from typing import Optional +from jax import eval_shape from jax._src.prng import PRNGKeyArray from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint from fortuna.prob_model.joint.state import JointState @@ -12,16 +15,15 @@ from fortuna.prob_model.posterior.map.map_approximator import MAPPosteriorApproximator from fortuna.prob_model.posterior.map.map_state import MAPState from fortuna.prob_model.posterior.map.map_trainer import ( - JittedMAPTrainer, MAPTrainer, - MultiDeviceMAPTrainer, + ShardedMAPTrainer, ) from fortuna.prob_model.posterior.posterior_state_repository import ( PosteriorStateRepository, ) from fortuna.typing import Status from fortuna.utils.builtins import get_dynamic_scale_instance_from_model_dtype -from fortuna.utils.device import select_trainer_given_devices +from fortuna.utils.checkpoint import get_checkpoint_manager logger = logging.getLogger(__name__) @@ -31,6 +33,7 @@ def __init__( self, joint: Joint, posterior_approximator: MAPPosteriorApproximator, + partition_manager: PartitionManager, ): """ Maximum-a-Posteriori (MAP) approximate posterior class. @@ -41,8 +44,14 @@ def __init__( A Joint distribution object. posterior_approximator: MAPPosteriorApproximator A MAP posterior approximator. + partition_manager: PartitionManager + An object to manage partitions. 
""" - super().__init__(joint=joint, posterior_approximator=posterior_approximator) + super().__init__( + joint=joint, + posterior_approximator=posterior_approximator, + partition_manager=partition_manager, + ) def __str__(self): return MAP_NAME @@ -57,16 +66,17 @@ def fit( ) -> Status: super()._checks_on_fit_start(fit_config, map_fit_config) - trainer_cls = select_trainer_given_devices( - devices=fit_config.processor.devices, - base_trainer_cls=MAPTrainer, - jitted_trainer_cls=JittedMAPTrainer, - multi_device_trainer_cls=MultiDeviceMAPTrainer, - disable_jit=fit_config.processor.disable_jit, + trainer_cls = ( + ShardedMAPTrainer if not fit_config.processor.disable_jit else MAPTrainer ) trainer = trainer_cls( predict_fn=self.joint.likelihood.prob_output_layer.predict, + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager( + fit_config.checkpointer.save_checkpoint_dir, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ), save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, save_every_n_steps=fit_config.checkpointer.save_every_n_steps, keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, @@ -78,17 +88,37 @@ def fit( freeze_fun=fit_config.optimizer.freeze_fun, ) - if super()._is_state_available_somewhere(fit_config): + checkpoint_restorer = ( + get_checkpoint_manager( + str( + Path(fit_config.checkpointer.restore_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.restore_checkpoint_dir is not None + else None + ) + + if self._is_state_available_somewhere(fit_config): state = self._restore_state_from_somewhere( fit_config=fit_config, allowed_states=(MAPState,), + checkpoint_manager=checkpoint_restorer, ) + state = self._freeze_optimizer_in_state(state, fit_config) + self.partition_manager.shapes_dtypes = eval_shape(lambda: state) else: - state = self._init_state( - data_loader=train_data_loader, fit_config=fit_config - ) - state = super()._freeze_optimizer_in_state(state, fit_config) + def init_state_fn(rng): + _state = self._init_state( + data_loader=train_data_loader, fit_config=fit_config, rng=rng + ) + return self._freeze_optimizer_in_state(_state, fit_config) + + state = self.partition_manager.init_sharded_state( + init_state_fn, self.rng.get() + ) self._check_state(state) logging.info("Run MAP.") @@ -96,11 +126,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, @@ -110,11 +140,20 @@ def fit( gradient_accumulation_steps=fit_config.hyperparameters.gradient_accumulation_steps, ) self.state = PosteriorStateRepository( - fit_config.checkpointer.save_checkpoint_dir - if fit_config.checkpointer.dump_state is True - else None + partition_manager=self.partition_manager, + checkpoint_manager=get_checkpoint_manager( + checkpoint_dir=str( + Path(fit_config.checkpointer.save_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.save_checkpoint_dir is not None + and 
fit_config.checkpointer.dump_state + else None, ) - self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) + if self.state.checkpoint_manager is None: + self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) logging.info("Fit completed.") return status @@ -127,9 +166,13 @@ def sample(self, rng: Optional[PRNGKeyArray] = None, **kwargs) -> JointState: calib_mutable=state.calib_mutable, ) - def _init_state(self, data_loader: DataLoader, fit_config: FitConfig) -> MAPState: - state = super()._init_joint_state(data_loader=data_loader) - + def _init_state( + self, + data_loader: DataLoader, + fit_config: FitConfig, + rng: Optional[PRNGKeyArray] = None, + ) -> MAPState: + state = super()._init_joint_state(data_loader=data_loader, rng=rng) return MAPState.init( params=state.params, mutable=state.mutable, diff --git a/fortuna/prob_model/posterior/map/map_trainer.py b/fortuna/prob_model/posterior/map/map_trainer.py index d1d7d955..a9d0d686 100644 --- a/fortuna/prob_model/posterior/map/map_trainer.py +++ b/fortuna/prob_model/posterior/map/map_trainer.py @@ -17,10 +17,9 @@ from fortuna.prob_model.posterior.map import * from fortuna.prob_model.posterior.map.map_state import MAPState from fortuna.prob_model.posterior.posterior_trainer import PosteriorTrainerABC -from fortuna.training.trainer import ( - JittedMixin, - MultiDeviceMixin, -) +from fortuna.training.mixins.jitted import JittedMixin +from fortuna.training.mixins.multi_device import MultiDeviceMixin +from fortuna.training.mixins.sharding import ShardingMixin from fortuna.typing import ( Array, Batch, @@ -111,3 +110,7 @@ class JittedMAPTrainer(JittedMixin, MAPTrainer): class MultiDeviceMAPTrainer(MultiDeviceMixin, MAPTrainer): pass + + +class ShardedMAPTrainer(ShardingMixin, MAPTrainer): + pass diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py index 8f06738e..80e4e780 100755 --- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py +++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py @@ -172,11 +172,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py index 995276dd..bca4fc25 100644 --- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py +++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py @@ -18,10 +18,8 @@ from fortuna.prob_model.posterior.normalizing_flow.normalizing_flow_trainer import ( NormalizingFlowTrainer, ) -from fortuna.training.trainer import ( - JittedMixin, - MultiDeviceMixin, -) +from fortuna.training.mixins.jitted import JittedMixin +from fortuna.training.mixins.multi_device import MultiDeviceMixin from fortuna.typing import ( Params, Path, @@ -47,7 +45,7 @@ def save_checkpoint( save_checkpoint_dir: Path, keep: int = 1, force_save: bool = False, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: state = state.replace( 
params=self._unravel_params(state.params), @@ -96,7 +94,7 @@ def _unravel_params( def on_train_start( self, state: NormalizingFlowState, - dataloaders: List[DataLoader], + data_loaders: List[DataLoader], rng: PRNGKeyArray, ) -> Tuple[NormalizingFlowState, List[DataLoader], PRNGKeyArray]: if self.freeze_fun is not None: @@ -152,7 +150,7 @@ def on_train_start( ), ) - return state, dataloaders, rng + return state, data_loaders, rng class JittedADVITrainer(JittedMixin, ADVITrainer): diff --git a/fortuna/prob_model/posterior/posterior_mixin.py b/fortuna/prob_model/posterior/posterior_mixin.py index d73fd0b7..8390b10e 100644 --- a/fortuna/prob_model/posterior/posterior_mixin.py +++ b/fortuna/prob_model/posterior/posterior_mixin.py @@ -1,8 +1,8 @@ from typing import Optional from fortuna.prob_model.posterior.name_to_posterior_state import NameToPosteriorState -from fortuna.prob_model.posterior.state import PosteriorState -from fortuna.training.mixin import WithCheckpointingMixin +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.name_to_train_state import NameToTrainState from fortuna.typing import ( OptaxOptimizer, Path, @@ -12,15 +12,22 @@ class WithPosteriorCheckpointingMixin(WithCheckpointingMixin): def restore_checkpoint( self, - restore_checkpoint_path: Path, + restore_checkpoint_dir: Path, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - name_to_train_state: NameToPosteriorState = NameToPosteriorState, - **kwargs, - ) -> PosteriorState: + name_to_train_state: NameToTrainState = NameToPosteriorState, + ): return super().restore_checkpoint( - restore_checkpoint_path, - optimizer, - prefix, + restore_checkpoint_dir=restore_checkpoint_dir, + optimizer=optimizer, + name_to_train_state=name_to_train_state, + ) + + def get_shapes_dtypes_checkpoint( + self, + restore_checkpoint_dir: Optional[Path] = None, + name_to_train_state: NameToTrainState = NameToPosteriorState, + ): + return super().get_shapes_dtypes_checkpoint( + restore_checkpoint_dir=restore_checkpoint_dir, name_to_train_state=name_to_train_state, ) diff --git a/fortuna/prob_model/posterior/posterior_multi_state_repository.py b/fortuna/prob_model/posterior/posterior_multi_state_repository.py index 1cca1e86..ce52a1b5 100644 --- a/fortuna/prob_model/posterior/posterior_multi_state_repository.py +++ b/fortuna/prob_model/posterior/posterior_multi_state_repository.py @@ -31,14 +31,14 @@ def __init__(self, size: int, checkpoint_dir: Optional[Path] = None): def get( self, i: int = None, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", + prefix: str = "", **kwargs, ) -> Union[List[PosteriorState], PosteriorState]: def _get(_i): return self.state[_i].get( - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, prefix=prefix, **kwargs, @@ -55,13 +55,13 @@ def put( self, state: PosteriorState, i: int = None, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, keep: int = 1, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: def _put(_i): return self.state[_i].put( - state=state, checkpoint_path=checkpoint_path, keep=keep, prefix=prefix + state=state, checkpoint_dir=checkpoint_dir, keep=keep, prefix=prefix ) if i is not None: @@ -73,14 +73,14 @@ def _put(_i): def pull( self, i: int = None, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, - 
prefix: str = "checkpoint_", + prefix: str = "", **kwargs, ) -> PosteriorState: def _pull(_i): return self.state[_i].pull( - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, prefix=prefix, **kwargs, @@ -97,16 +97,16 @@ def update( self, variables: Dict, i: int = None, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, keep: int = 1, - prefix: str = "checkpoint_", + prefix: str = "", **kwargs, ): def _update(_i): self.state[_i].update( variables=variables, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, keep=keep, prefix=prefix, @@ -123,13 +123,13 @@ def extract( self, keys: List[str], i: int = None, - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", + checkpoint_dir: Optional[Path] = None, + prefix: str = "", **kwargs, ) -> Union[Dict, List[Dict]]: def _extract(_i): return self.state[_i].extract( - keys=keys, checkpoint_path=checkpoint_path, prefix=prefix, **kwargs + keys=keys, checkpoint_dir=checkpoint_dir, prefix=prefix, **kwargs ) if i is not None: @@ -141,10 +141,10 @@ def _extract(_i): def extract_calib_keys( self, - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", + checkpoint_dir: Optional[Path] = None, + prefix: str = "", **kwargs, ) -> Dict: return self.extract( - ["calib_params", "calib_mutable"], 0, checkpoint_path, prefix, **kwargs + ["calib_params", "calib_mutable"], 0, checkpoint_dir, prefix, **kwargs ) diff --git a/fortuna/prob_model/posterior/posterior_state_repository.py b/fortuna/prob_model/posterior/posterior_state_repository.py index 1d907940..6496c3a9 100644 --- a/fortuna/prob_model/posterior/posterior_state_repository.py +++ b/fortuna/prob_model/posterior/posterior_state_repository.py @@ -11,10 +11,6 @@ class PosteriorStateRepository(WithPosteriorCheckpointingMixin, TrainStateRepository): def extract_calib_keys( self, - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", - **kwargs, + checkpoint_dir: Optional[Path] = None, ) -> Dict: - return super().extract( - ["calib_params", "calib_mutable"], checkpoint_path, prefix, **kwargs - ) + return super().extract(["calib_params", "calib_mutable"], checkpoint_dir) diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py index d8f25150..7ad79090 100644 --- a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py @@ -176,11 +176,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py index 4641a71d..1f69dda0 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py @@ -172,11 +172,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_log_joint_prob, - 
training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py index ac48e64c..a4e60aca 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py @@ -84,12 +84,12 @@ def _restore_state_from_somewhere( fit_config: FitConfig, allowed_states: Optional[Tuple[Type[MAPState], ...]] = None, ) -> MAPState: - if fit_config.checkpointer.restore_checkpoint_path is not None: - restore_checkpoint_path = ( - pathlib.Path(fit_config.checkpointer.restore_checkpoint_path) / "c" + if fit_config.checkpointer.restore_checkpoint_dir is not None: + restore_checkpoint_dir = ( + pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir) / "c" ) state = self.restore_checkpoint( - restore_checkpoint_path=restore_checkpoint_path, + restore_checkpoint_dir=restore_checkpoint_dir, optimizer=fit_config.optimizer.method, ) elif fit_config.checkpointer.start_from_current_state is not None: @@ -101,7 +101,7 @@ def _restore_state_from_somewhere( if allowed_states is not None and not isinstance(state, allowed_states): raise ValueError( f"The type of the restored checkpoint must be within {allowed_states}. " - f"However, {fit_config.checkpointer.restore_checkpoint_path} pointed to a state " + f"However, {fit_config.checkpointer.restore_checkpoint_dir} pointed to a state " f"with type {type(state)}." 
) diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py index be343127..1d23ef8e 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py @@ -39,14 +39,14 @@ def __init__( def get( self, i: int = None, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", + prefix: str = "", **kwargs, ) -> Union[List[PosteriorState], PosteriorState]: state = super().get( i=i, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, prefix=prefix, **kwargs, @@ -57,15 +57,15 @@ def put( self, state: PosteriorState, i: int = None, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, keep: int = 1, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: state = self._update_state(state, modify="remove") return super().put( state=state, i=i, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, keep=keep, prefix=prefix, ) @@ -73,14 +73,14 @@ def put( def pull( self, i: int = None, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", + prefix: str = "", **kwargs, ) -> PosteriorState: state = super().pull( i=i, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, prefix=prefix, **kwargs, @@ -91,14 +91,14 @@ def extract( self, keys: List[str], i: int = None, - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", + checkpoint_dir: Optional[Path] = None, + prefix: str = "", **kwargs, ) -> Union[Dict, List[Dict]]: def _extract(_i): state = self.get( i=_i, - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, prefix=prefix, ) return {k: getattr(state, k) for k in keys} diff --git a/fortuna/prob_model/posterior/swag/swag_posterior.py b/fortuna/prob_model/posterior/swag/swag_posterior.py index 1cb4bb30..19bf40a8 100755 --- a/fortuna/prob_model/posterior/swag/swag_posterior.py +++ b/fortuna/prob_model/posterior/swag/swag_posterior.py @@ -112,7 +112,7 @@ def fit( raise ValueError( "The SWAG approximation must start from a preliminary run of MAP or an existing " "checkpoint or state. Please configure `map_fit_config`, or " - "`fit_config.checkpointer.restore_checkpoint_path`, " + "`fit_config.checkpointer.restore_checkpoint_dir`, " "or `fit_config.checkpointer.start_from_current_state`." 
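# A minimal sketch of the three ways of seeding SWAG listed in the error message above. Only the
# attribute paths (`map_fit_config`, `fit_config.checkpointer.restore_checkpoint_dir`,
# `fit_config.checkpointer.start_from_current_state`) appear in this patch; the `FitCheckpointer`
# class name, its import path, and the exact `train(...)` keywords are assumptions, so treat this
# purely as an illustration.
def _swag_seeding_sketch(prob_model, train_data_loader):
    from fortuna.prob_model import FitConfig, FitCheckpointer  # assumed import path

    # 1) run a preliminary MAP fit first, then SWAG on top of it
    prob_model.train(train_data_loader=train_data_loader, map_fit_config=FitConfig())

    # 2) restore an existing state from a checkpoint directory
    prob_model.train(
        train_data_loader=train_data_loader,
        fit_config=FitConfig(
            checkpointer=FitCheckpointer(restore_checkpoint_dir="path/to/checkpoints")
        ),
    )

    # 3) continue from the state currently held by the posterior
    prob_model.train(
        train_data_loader=train_data_loader,
        fit_config=FitConfig(
            checkpointer=FitCheckpointer(start_from_current_state=True)
        ),
    )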
) @@ -155,11 +155,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/swag/swag_trainer.py b/fortuna/prob_model/posterior/swag/swag_trainer.py index 30e4fa95..aae03118 100644 --- a/fortuna/prob_model/posterior/swag/swag_trainer.py +++ b/fortuna/prob_model/posterior/swag/swag_trainer.py @@ -16,10 +16,8 @@ from fortuna.prob_model.posterior.map.map_trainer import MAPTrainer from fortuna.prob_model.posterior.swag.swag_state import SWAGState from fortuna.training.callback import Callback -from fortuna.training.trainer import ( - JittedMixin, - MultiDeviceMixin, -) +from fortuna.training.mixins.jitted import JittedMixin +from fortuna.training.mixins.multi_device import MultiDeviceMixin from fortuna.typing import ( Array, Batch, @@ -98,7 +96,7 @@ def save_checkpoint( save_checkpoint_dir: Path, keep: int = 1, force_save: bool = False, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: state = self._update_state_with_stats(state) super().save_checkpoint(state, save_checkpoint_dir, keep, force_save, prefix) diff --git a/fortuna/prob_model/predictive/base.py b/fortuna/prob_model/predictive/base.py index 9e900eef..dc7f7420 100755 --- a/fortuna/prob_model/predictive/base.py +++ b/fortuna/prob_model/predictive/base.py @@ -1,5 +1,7 @@ import abc +import inspect from typing import ( + Any, Callable, Dict, List, @@ -8,24 +10,24 @@ Union, ) -from flax.training.common_utils import shard_prng_key from jax import ( jit, lax, - pmap, random, ) from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import pjit import jax.numpy as jnp import jax.scipy as jsp +from jax.sharding import PartitionSpec from jax.tree_util import tree_map from fortuna.data.loader import ( DataLoader, - DeviceDimensionAugmentedLoader, InputsLoader, TargetsLoader, ) +from fortuna.data.loader.base import ShardedPrefetchedLoader from fortuna.prob_model.posterior.base import Posterior from fortuna.typing import ( Array, @@ -54,7 +56,7 @@ def log_prob( data_loader: DataLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, **kwargs, ) -> jnp.ndarray: r""" @@ -77,8 +79,8 @@ def log_prob( that would be produced using the posterior distribution state. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
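# A minimal usage sketch for the sharded predictive log-probability; `prob_model` and `data_loader`
# are assumed to already exist, and the argument names follow the signature above.
def _log_prob_example(prob_model, data_loader):
    # Monte Carlo estimate of the predictive log-density over the data loader. With shard=True the
    # likelihood evaluation is pjit-ed with shardings taken from the posterior's partition manager
    # and executed under the partitioner's mesh (see _batched_log_prob below).
    return prob_model.predictive.log_prob(
        data_loader=data_loader,
        n_posterior_samples=30,
        shard=True,
    )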
Returns ------- @@ -88,12 +90,12 @@ def log_prob( if rng is None: rng = self.rng.get() - return self._loop_fun_through_data_loader( + return self._loop_fun_through_loader( self._batched_log_prob, data_loader, n_posterior_samples, rng, - distribute, + shard, **kwargs, ) @@ -102,25 +104,49 @@ def _batched_log_prob( batch: Batch, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, **kwargs, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]: if rng is None: rng = self.rng.get() keys = random.split(rng, n_posterior_samples) - def _lik_log_batched_prob(key): - sample = self.posterior.sample(inputs=batch[0], rng=key) + def _lik_log_batched_prob(params, mutable, calib_params, calib_mutable): return self.likelihood._batched_log_prob( - sample.params, + params, batch, - mutable=sample.mutable, - calib_params=sample.calib_params, - calib_mutable=sample.calib_mutable, + mutable=mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, **kwargs, ) + if shard: + _lik_log_batched_prob = pjit( + _lik_log_batched_prob, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + self.posterior.partition_manager.shardings.calib_params, + self.posterior.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("dp", "fsdp")), + ) + else: + _lik_log_batched_prob = jit(_lik_log_batched_prob) + + def _fun(key): + sample = self.posterior.sample(inputs=batch[0], rng=key) + with self.posterior.partition_manager.partitioner.mesh: + return _lik_log_batched_prob( + sample.params, + sample.mutable, + sample.calib_params, + sample.calib_mutable, + ) + return jsp.special.logsumexp( - lax.map(_lik_log_batched_prob, keys), axis=0 + jnp.stack(list(map(_fun, keys))), axis=0 ) - jnp.log(n_posterior_samples) def _batched_log_joint_prob( @@ -238,9 +264,9 @@ def sample( n_target_samples: int = 1, return_aux: Optional[List[str]] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, **kwargs, - ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]]: + ) -> Tuple[Array, Dict[str, Array]] | Array: r""" Sample from an approximation of the predictive distribution for each input data point, that is @@ -254,6 +280,7 @@ def sample( Parameters ---------- + **kwargs inputs_loader : InputsLoader A loader of input data points. n_target_samples : int @@ -262,63 +289,27 @@ def sample( Return auxiliary objects. We currently support 'outputs'. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- - Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]] + Tuple[Array, Dict[str, Array]] | Array Samples for each input data point. Optionally, an auxiliary object is returned. 
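# A minimal usage sketch for the predictive sampler documented above (assumes `prob_model` and
# `inputs_loader` already exist; per the docstring, `return_aux=["outputs"]` is the only auxiliary
# output currently supported).
def _sample_example(prob_model, inputs_loader):
    # Ten target samples per input, sharded across the mesh when multiple devices are available.
    return prob_model.predictive.sample(
        inputs_loader=inputs_loader,
        n_target_samples=10,
        shard=True,
    )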
""" if not rng: rng = self.rng.get() - def fun(_inputs): - return self._batched_sample( - _inputs, n_target_samples, return_aux, rng, **kwargs - ) - - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - fun = pmap(fun) - if return_aux is None or len(return_aux) == 0: - return jnp.concatenate( - [ - self.likelihood._unshard_array(fun(inputs)) - for inputs in inputs_loader - ], - 1, - ) - else: - samples, aux_outputs = [], [] - for inputs in inputs_loader: - _samples, _aux = fun(inputs) - samples.append(self.likelihood._unshard_array(_samples)) - if "outputs" in _aux: - aux_outputs.append( - self.likelihood._unshard_array(_aux["outputs"]) - ) - samples = jnp.concatenate(samples, axis=0) - aux = dict() - if "outputs" in aux: - aux["outputs"] = jnp.concatenate(aux_outputs, axis=0) - return samples, aux - else: - fun = jit(fun) - if return_aux is None or len(return_aux) == 0: - return jnp.concatenate([fun(inputs) for inputs in inputs_loader], 1) - else: - samples, aux_outputs = [], [] - for inputs in inputs_loader: - _samples, _aux = fun(inputs) - samples.append(_samples) - if "outputs" in _aux: - aux_outputs.append(_aux["outputs"]) - samples = jnp.concatenate(samples, axis=0) - aux = dict() - if "outputs" in aux: - aux["outputs"] = jnp.concatenate(aux_outputs, axis=0) - return samples, aux + return self._loop_fun_through_loader( + self._batched_sample, + inputs_loader, + n_target_samples, + rng, + shard, + is_fun_ensembled=False, + return_aux=return_aux, + **kwargs, + ) def _batched_sample( self, @@ -326,8 +317,9 @@ def _batched_sample( n_target_samples: int = 1, return_aux: Optional[List[str]] = None, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, **kwargs, - ) -> jnp.ndarray: + ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: if return_aux is None: return_aux = [] @@ -335,33 +327,57 @@ def _batched_sample( rng = self.rng.get() keys = random.split(rng, n_target_samples) - def _sample(key): - key1, key2 = random.split(key, 2) - _post_sample = self.posterior.sample(inputs=inputs, rng=key1) - outs = self.likelihood._batched_sample( + def _sample(rng, params, mutable, calib_params, calib_mutable): + return self.likelihood._batched_sample( 1, - _post_sample.params, + params, inputs, - mutable=_post_sample.mutable, - calib_params=_post_sample.calib_params, - calib_mutable=_post_sample.calib_mutable, + mutable=mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, return_aux=return_aux, - rng=key2, + rng=rng, **kwargs, ) - if len(return_aux) > 0: - _samples, aux = outs - return _samples.squeeze(0), aux - return outs.squeeze(0) - return lax.map(_sample, keys) + if shard: + _sample = pjit( + _sample, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + self.posterior.partition_manager.shardings.calib_params, + self.posterior.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("dp", "fsdp")) + if not len(return_aux) + else (PartitionSpec(("dp", "fsdp")), PartitionSpec()), + ) + + def _fun(key): + key1, key2 = random.split(key, 2) + with self.posterior.partition_manager.partitioner.mesh: + sample = self.posterior.sample(inputs=inputs, rng=key1) + return _sample( + key2, + sample.params, + sample.mutable, + sample.calib_params, + sample.calib_mutable, + ) + + samples = list(map(_fun, keys)) + if len(return_aux): + samples, aux = samples + return jnp.stack(samples), aux + return jnp.stack(samples) def sample_calibrated_outputs( self, inputs_loader: 
InputsLoader, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Sample parameters from the posterior distribution state and compute calibrated outputs. @@ -374,8 +390,8 @@ def sample_calibrated_outputs( Number of output samples to draw for each input. rng: Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -385,12 +401,13 @@ def sample_calibrated_outputs( if rng is None: rng = self.rng.get() - return self._loop_ensemble_fun_through_inputs_loader( + return self._loop_fun_through_loader( self._sample_batched_calibrated_outputs, inputs_loader, n_output_samples, rng, - distribute, + shard, + is_fun_ensembled=True, ) def _sample_batched_calibrated_outputs( @@ -398,39 +415,58 @@ def _sample_batched_calibrated_outputs( inputs: Array, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() keys = random.split(rng, n_output_samples) - def _sample(key): - sample = self.posterior.sample(inputs=inputs, rng=key) + def _apply_fn(params, mutable, calib_params, calib_mutable): return self.likelihood._get_batched_calibrated_outputs( - params=sample.params, + params=params, inputs=inputs, - mutable=sample.mutable, - calib_params=sample.calib_params, - calib_mutable=sample.calib_mutable, + mutable=mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, ) - return lax.map(_sample, keys) + if shard: + _apply_fn = pjit( + _apply_fn, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + self.posterior.partition_manager.shardings.calib_params, + self.posterior.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("fsdp", "dp")), + ) + else: + _apply_fn = jit(_apply_fn) + + def _sample(key): + sample = self.posterior.sample(inputs=inputs, rng=key) + with self.posterior.partition_manager.partitioner.mesh: + return _apply_fn(sample.params, sample.mutable) + + return jnp.stack(list(map(_sample, keys))) def _sample_outputs( self, inputs_loader: InputsLoader, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() - return self._loop_fun_through_inputs_loader( + return self._loop_fun_through_loader( self._sample_batched_outputs, inputs_loader, n_output_samples, rng, - distribute, + shard, ) def _sample_batched_outputs( @@ -438,18 +474,35 @@ def _sample_batched_outputs( inputs: Array, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() keys = random.split(rng, n_output_samples) - def _sample(key): - sample = self.posterior.sample(inputs=inputs, rng=key) + def _apply_fn(params, mutable): return self.likelihood.model_manager.apply( - params=sample.params, inputs=inputs, mutable=sample.mutable + params=params, inputs=inputs, mutable=mutable + ) + + if shard: + _apply_fn = pjit( + _apply_fn, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + ), + out_shardings=PartitionSpec(("fsdp", "dp")), ) + else: + _apply_fn = 
jit(_apply_fn) + + def _sample(key): + sample = self.posterior.sample(inputs=inputs, rng=key) + with self.posterior.partition_manager.partitioner.mesh: + return _apply_fn(sample.params, sample.mutable) - return lax.map(_sample, keys) + return jnp.stack(list(map(_sample, keys))) def _sample_outputs_loader( self, @@ -457,23 +510,12 @@ def _sample_outputs_loader( n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, return_size: bool = False, - distribute: bool = True, + shard: bool = True, ) -> Union[TargetsLoader, Tuple[TargetsLoader, int]]: if rng is None: rng = self.rng.get() keys = random.split(rng, n_output_samples) - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - - def _sample(key, _inputs): - sample = self.posterior.sample(inputs=_inputs, rng=key) - return self.likelihood.model_manager.apply( - params=sample.params, inputs=_inputs, mutable=sample.mutable - ) - - _sample = pmap(_sample) if distribute else jit(_sample) - iterable = [] size = 0 for inputs in inputs_loader: @@ -482,13 +524,18 @@ def _sample(key, _inputs): if not isinstance(inputs, dict) else inputs[list(inputs.keys())[0]].shape[0] ) - if distribute: - outputs = jnp.stack( - list(map(lambda key: _sample(shard_prng_key(key), inputs), keys)) + outputs = jnp.stack( + list( + map( + lambda key: self._sample_batched_outputs( + inputs=inputs, + rng=key, + shard=shard, + )[0], + keys, + ) ) - outputs = self._unshard_ensemble_arrays(outputs) - else: - outputs = lax.map(lambda key: _sample(key, inputs), keys) + ) iterable.append(outputs) iterable = TargetsLoader.from_iterable(iterable=iterable) if return_size: @@ -500,7 +547,7 @@ def mean( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive mean of the target variable, that is @@ -522,8 +569,8 @@ def mean( Number of samples to draw from the posterior distribution for each input. rng: Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
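# A minimal usage sketch of the predictive mean (assumes `prob_model` and `inputs_loader` exist).
# As implemented in `_batched_mean` below, this is a Monte Carlo average of the likelihood mean over
# `n_posterior_samples` posterior draws, evaluated under the partitioner's mesh when shard=True.
def _mean_example(prob_model, inputs_loader):
    return prob_model.predictive.mean(
        inputs_loader=inputs_loader,
        n_posterior_samples=30,
        shard=True,
    )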
Returns ------- @@ -533,8 +580,8 @@ def mean( if rng is None: rng = self.rng.get() - return self._loop_fun_through_inputs_loader( - self._batched_mean, inputs_loader, n_posterior_samples, rng, distribute + return self._loop_fun_through_loader( + self._batched_mean, inputs_loader, n_posterior_samples, rng, shard ) def _batched_mean( @@ -542,25 +589,46 @@ def _batched_mean( inputs: Array, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() keys = random.split(rng, n_posterior_samples) - def fun(i, _curr_sum): - _sample = self.posterior.sample(inputs=inputs, rng=keys[i]) - _curr_sum += self.likelihood._batched_mean( - _sample.params, + def _lik_batched_mean(params, mutable, calib_params, calib_mutable): + return self.likelihood._batched_mean( + params, inputs, - _sample.mutable, - calib_params=_sample.calib_params, - calib_mutable=_sample.calib_mutable, + mutable, + calib_params=calib_params, + calib_mutable=calib_mutable, ) - return _curr_sum - curr_sum = fun(0, 0.0) - curr_sum = lax.fori_loop(1, n_posterior_samples, fun, curr_sum) - return curr_sum / n_posterior_samples + if shard: + _lik_batched_mean = pjit( + _lik_batched_mean, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable, + self.posterior.partition_manager.shardings.calib_params, + self.posterior.partition_manager.shardings.calib_mutable, + ), + out_shardings=PartitionSpec(("dp", "fsdp")), + ) + else: + _lik_batched_mean = jit(_lik_batched_mean) + + def fun(key): + _sample = self.posterior.sample(inputs=inputs, rng=key) + with self.posterior.partition_manager.partitioner.mesh: + return _lik_batched_mean( + params=_sample.params, + mutable=_sample.mutable, + calib_params=_sample.calib_params, + calib_mutable=_sample.calib_mutable, + ) + + return jnp.mean(list(map(fun, keys)), axis=0) @abc.abstractmethod def mode( @@ -569,7 +637,7 @@ def mode( n_posterior_samples: int = 30, means: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive mode of the target variable, that is @@ -592,8 +660,8 @@ def mode( An estimate of the predictive mean. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -607,7 +675,7 @@ def aleatoric_variance( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive aleatoric variance of the target variable, that is @@ -629,8 +697,8 @@ def aleatoric_variance( Number of samples to draw from the posterior distribution for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
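# A minimal usage sketch of the predictive aleatoric variance (assumes `prob_model` and
# `inputs_loader` exist). Loosely, this averages the likelihood's conditional variance over
# posterior draws, mirroring the other predictive statistics in this file.
def _aleatoric_variance_example(prob_model, inputs_loader):
    return prob_model.predictive.aleatoric_variance(
        inputs_loader=inputs_loader,
        n_posterior_samples=30,
        shard=True,
    )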
Returns ------- @@ -645,7 +713,7 @@ def aleatoric_variance( inputs_loader, n_posterior_samples, rng, - distribute, + shard, ) def _batched_aleatoric_variance( @@ -678,7 +746,7 @@ def epistemic_variance( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive epistemic variance of the one-hot encoded target variable, that is @@ -700,8 +768,8 @@ def epistemic_variance( Number of samples to draw from the posterior distribution for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -711,12 +779,12 @@ def epistemic_variance( if rng is None: rng = self.rng.get() - return self._loop_fun_through_inputs_loader( + return self._loop_fun_through_loader( self._batched_epistemic_variance, inputs_loader, n_posterior_samples, rng, - distribute, + shard, ) def _batched_epistemic_variance( @@ -758,7 +826,7 @@ def variance( aleatoric_variances: Optional[jnp.ndarray] = None, epistemic_variances: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive variance of the target variable, that is @@ -785,8 +853,8 @@ def variance( An estimate of the epistemic predictive variance for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -801,7 +869,7 @@ def variance( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=key, - distribute=distribute, + shard=shard, ) if epistemic_variances is None: rng, key = random.split(rng) @@ -809,7 +877,7 @@ def variance( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=key, - distribute=distribute, + shard=shard, ) return aleatoric_variances + epistemic_variances @@ -819,7 +887,7 @@ def std( n_posterior_samples: int = 30, variances: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive standard deviation of the target variable, that is @@ -842,8 +910,8 @@ def std( An estimate of the predictive variance. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
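# A minimal sketch of how the standard deviation relates to the variance methods above (assumes
# `prob_model` and `inputs_loader` exist). As the code in `variance` and `std` shows, the total
# predictive variance is the sum of the aleatoric and epistemic estimates, and the standard
# deviation is its square root.
def _std_example(prob_model, inputs_loader):
    import jax.numpy as jnp

    variances = prob_model.predictive.variance(
        inputs_loader=inputs_loader, n_posterior_samples=30, shard=True
    )
    stds = prob_model.predictive.std(
        inputs_loader=inputs_loader, variances=variances, shard=True
    )
    # equivalent to taking the square root of the precomputed variances directly
    return stds, jnp.sqrt(variances)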
Returns ------- @@ -855,84 +923,44 @@ def std( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=rng, - distribute=distribute, + shard=shard, ) return jnp.sqrt(variances) - @staticmethod - def _unshard_ensemble_arrays(arr: Array) -> Array: - arr = arr.swapaxes(1, 2) - arr = arr.reshape((arr.shape[0] * arr.shape[1],) + arr.shape[2:]) - return arr.swapaxes(0, 1) - - def _loop_fun_through_inputs_loader( + def _loop_fun_through_loader( self, fun: Callable, - inputs_loader: InputsLoader, + loader: Union[InputsLoader, DataLoader, TargetsLoader], n_posterior_samples: int, rng: PRNGKeyArray, - distribute: bool = True, + shard: bool, + is_fun_ensembled: bool = False, + return_aux: Optional[List[str]] = None, **kwargs, - ) -> Array: - def fun2(_inputs): - return fun(_inputs, n_posterior_samples, rng, **kwargs) - - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - fun2 = pmap(fun2) - return jnp.concatenate( - [ - self.likelihood._unshard_array(fun2(inputs)) - for inputs in inputs_loader - ], - 0, + ) -> tuple[Any, ...] | Array: + if shard: + loader = ShardedPrefetchedLoader( + loader=loader, partition_manager=self.posterior.partition_manager ) - fun2 = jit(fun2) - return jnp.concatenate([fun2(inputs) for inputs in inputs_loader], 0) - def _loop_fun_through_data_loader( - self, - fun: Callable, - data_loader: DataLoader, - n_posterior_samples: int, - rng: PRNGKeyArray, - distribute: bool = True, - **kwargs, - ) -> Array: - def fun2(_batch): - return fun(_batch, n_posterior_samples, rng, **kwargs) - - if distribute: - data_loader = DeviceDimensionAugmentedLoader(data_loader) - fun2 = pmap(fun2) - return jnp.concatenate( - [self.likelihood._unshard_array(fun2(batch)) for batch in data_loader], - 0, - ) - fun2 = jit(fun2) - return jnp.concatenate([fun2(batch) for batch in data_loader], 0) + def fun2(_data): + if "return_aux" in inspect.getfullargspec(fun)[0]: + return fun( + _data, + n_posterior_samples, + rng, + shard, + return_aux=return_aux, + **kwargs, + ) + return fun(_data, n_posterior_samples, rng, shard, **kwargs) - def _loop_ensemble_fun_through_inputs_loader( - self, - fun: Callable, - inputs_loader: InputsLoader, - n_posterior_samples: int, - rng: PRNGKeyArray, - distribute: bool = True, - **kwargs, - ) -> Array: - def fun2(_inputs): - return fun(_inputs, n_posterior_samples, rng, **kwargs) - - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - fun2 = pmap(fun2) - return jnp.concatenate( + outs = [fun2(data) for data in loader] + if return_aux is not None: + return tuple( [ - self._unshard_ensemble_arrays(fun2(inputs)) - for inputs in inputs_loader - ], - 1, + tree_map(lambda v: jnp.concatenate(v, int(is_fun_ensembled)), out) + for out in zip(*outs) + ] ) - fun2 = jit(fun2) - return jnp.concatenate([fun2(inputs) for inputs in inputs_loader], 1) + return jnp.concatenate(outs, int(is_fun_ensembled)) diff --git a/fortuna/prob_model/predictive/regression.py b/fortuna/prob_model/predictive/regression.py index 7f9a67e4..96e31ffc 100644 --- a/fortuna/prob_model/predictive/regression.py +++ b/fortuna/prob_model/predictive/regression.py @@ -363,9 +363,6 @@ def quantile( if type(q) == list: q = jnp.array(q) samples = self.sample( - inputs_loader=inputs_loader, - n_target_samples=n_target_samples, - rng=rng, - distribute=distribute, + inputs_loader=inputs_loader, n_target_samples=n_target_samples, rng=rng ) return jnp.quantile(samples, q, axis=0) diff --git a/fortuna/prob_model/prob_model_calibrator.py 
b/fortuna/prob_model/prob_model_calibrator.py index 24cc1692..589840d7 100644 --- a/fortuna/prob_model/prob_model_calibrator.py +++ b/fortuna/prob_model/prob_model_calibrator.py @@ -6,19 +6,12 @@ Union, ) -from flax import jax_utils -import jax from jax._src.prng import PRNGKeyArray import jax.numpy as jnp -from jax.tree_util import tree_map -from fortuna.data import TargetsLoader from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.output_calibrator import ( - JittedMixin, - MultiDeviceMixin, - OutputCalibratorABC, -) +from fortuna.training.output_calibrator.base import OutputCalibratorABC +from fortuna.training.output_calibrator.mixins.sharding import ShardingMixin from fortuna.typing import ( Array, Batch, @@ -60,7 +53,7 @@ def training_loss_step( }, ) - def val_loss_step( + def validation_loss_step( self, state: OutputCalibState, batch: Batch, @@ -84,49 +77,5 @@ def __str__(self): return "calibration" -class ProbModelMultiDeviceMixin(MultiDeviceMixin): - @staticmethod - def _add_device_dim_to_outputs_loader( - outputs_loader: TargetsLoader, - ) -> TargetsLoader: - def _reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[1] % n_devices != 0: - raise ValueError( - f"The size of all output batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch of outputs with shape {batch.shape[1]} was found. " - f"Please set an appropriate batch size." - ) - shape = batch.shape - return ( - batch.swapaxes(0, 1) - .reshape(n_devices, shape[1] // n_devices, shape[0], shape[2]) - .swapaxes(1, 2) - ) - - class TargetsLoaderWrapper: - def __init__(self, outputs_loader: TargetsLoader): - self._outputs_loader = outputs_loader - - def __iter__(self): - outputs_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._outputs_loader - ) - outputs_loader = jax_utils.prefetch_to_device(outputs_loader, 2) - yield from outputs_loader - - return ( - TargetsLoaderWrapper(outputs_loader) - if outputs_loader is not None - else outputs_loader - ) - - -class JittedProbModelOutputCalibrator(JittedMixin, ProbModelOutputCalibrator): - pass - - -class MultiDeviceProbModelOutputCalibrator( - ProbModelMultiDeviceMixin, ProbModelOutputCalibrator -): +class ShardedProbModelOutputCalibrator(ShardingMixin, ProbModelOutputCalibrator): pass diff --git a/fortuna/prob_model/regression.py b/fortuna/prob_model/regression.py index f39cdbc0..58745a1d 100755 --- a/fortuna/prob_model/regression.py +++ b/fortuna/prob_model/regression.py @@ -12,6 +12,8 @@ from fortuna.model_editor.base import ModelEditor from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager from fortuna.output_calibrator.regression import RegressionTemperatureScaler +from fortuna.partitioner.base import Partitioner +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.base import ProbModel from fortuna.prob_model.calib_config.base import CalibConfig from fortuna.prob_model.fit_config.base import FitConfig @@ -39,6 +41,7 @@ def __init__( posterior_approximator: PosteriorApproximator = SWAGPosteriorApproximator(), output_calibrator: Optional[nn.Module] = RegressionTemperatureScaler(), model_editor: Optional[ModelEditor] = None, + partitioner: Partitioner = Partitioner(), seed: int = 0, ): r""" @@ -67,6 +70,8 @@ def __init__( calibration parameters. model_editor : ModelEditor A model_editor objects. It takes the forward pass and transforms the outputs. 
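# A minimal sketch of the partitioning objects behind the `partitioner` argument documented just
# below. The import paths and constructor calls are taken from this patch; everything else about
# the model definition is omitted.
from fortuna.partitioner.base import Partitioner
from fortuna.partitioner.partition_manager.base import PartitionManager

partitioner = Partitioner()                          # same default used in the new signature
partition_manager = PartitionManager(partitioner)    # what ProbRegressor now builds internally and
                                                     # forwards to the posterior approximation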
+ partitioner : Partitioner + A partitioning object for data, fully sharded data model parallelization. seed: int A random seed. @@ -119,10 +124,14 @@ def __init__( self.model_manager, self.prob_output_layer, self.output_calib_manager ) self.joint = Joint(self.prior, self.likelihood) - + self.partition_manager = PartitionManager(partitioner) self.posterior = getattr( PosteriorApproximations, posterior_approximator.__str__() - ).value(joint=self.joint, posterior_approximator=posterior_approximator) + ).value( + joint=self.joint, + posterior_approximator=posterior_approximator, + partition_manager=self.partition_manager, + ) self.predictive = RegressionPredictive(self.posterior) super().__init__(seed=seed) @@ -141,6 +150,7 @@ def _check_output_dim(self, data_loader: DataLoader): outputs = self.model_manager.apply( params=s.params, inputs=np.zeros((1,) + input_shape), mutable=s.mutable ) + if outputs.shape[1] != 2 * output_dim: raise ValueError( f"""The outputs dimension of both `model` and `likelihood_log_variance_model` must be the same as diff --git a/fortuna/training/mixin.py b/fortuna/training/mixin.py deleted file mode 100755 index ce9772d6..00000000 --- a/fortuna/training/mixin.py +++ /dev/null @@ -1,163 +0,0 @@ -import logging -import os -from typing import ( - Dict, - Optional, -) - -from flax.training import checkpoints -from flax.training.early_stopping import EarlyStopping - -from fortuna.training.name_to_train_state import NameToTrainState -from fortuna.training.train_state import TrainState -from fortuna.typing import ( - OptaxOptimizer, - Path, -) - -logger = logging.getLogger(__name__) - - -class WithCheckpointingMixin: - def __init__( - self, - **kwargs, - ): - """ - Mixin class for all trainers that need checkpointing capabilities. This is a wrapper around functions in - `flax.training.checkpoints.*`. - """ - super(WithCheckpointingMixin, self).__init__(**kwargs) - - def save_checkpoint( - self, - state: TrainState, - save_checkpoint_dir: Path, - keep: int = 1, - force_save: bool = False, - prefix: str = "checkpoint_", - ) -> None: - if save_checkpoint_dir: - save_ckpt_fn = lambda state: checkpoints.save_checkpoint( - ckpt_dir=str(save_checkpoint_dir), - target=state, - step=state.step, - prefix=prefix, - keep=keep, - overwrite=force_save, - ) - if ( - hasattr(state, "grad_accumulated") - and state.grad_accumulated is not None - ): - # do not save grad accumulated in the ckpt - state = state.replace(grad_accumulated=None) - save_ckpt_fn(state) - else: - save_ckpt_fn(state) - - def restore_checkpoint( - self, - restore_checkpoint_path: Path, - optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - name_to_train_state: NameToTrainState = NameToTrainState, - **kwargs, - ) -> TrainState: - if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile( - restore_checkpoint_path - ): - raise ValueError( - f"`restore_checkpoint_path={restore_checkpoint_path}` was not found." - ) - d = checkpoints.restore_checkpoint( - ckpt_dir=str(restore_checkpoint_path), - target=None, - step=None, - prefix=prefix, - parallel=True, - ) - if d is None: - raise ValueError( - f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." 
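# The flax-based checkpointing mixin removed here is replaced, later in this patch, by an
# orbax-backed mixin. A minimal sketch of the helper the new mixin relies on; the call mirrors the
# one in fortuna/training/mixins/checkpointing.py, and the directory path is only a placeholder.
from fortuna.utils.checkpoint import get_checkpoint_manager

checkpoint_manager = get_checkpoint_manager(
    checkpoint_dir="path/to/checkpoints",
    keep_top_n_checkpoints=2,  # mirrors the `keep` argument of save_checkpoint
)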
- ) - name = "".join([chr(n) for n in d["encoded_name"].tolist()]) - return name_to_train_state[name].value.init_from_dict(d, optimizer, **kwargs) - - def get_path_latest_checkpoint( - self, checkpoint_dir: Path, prefix: str = "checkpoint_" - ) -> Optional[str]: - return checkpoints.latest_checkpoint(ckpt_dir=checkpoint_dir, prefix=prefix) - - -class WithEarlyStoppingMixin: - def __init__( - self, - *, - early_stopping_monitor: str = "val_loss", - early_stopping_min_delta: float = 0.0, - early_stopping_patience: Optional[int] = 0, - early_stopping_mode: str = "min", - early_stopping_verbose: bool = True, - **kwargs, - ): - super(WithEarlyStoppingMixin, self).__init__(**kwargs) - self.early_stopping_monitor = early_stopping_monitor - self.early_stopping_mode = early_stopping_mode - self.early_stopping_patience = early_stopping_patience - - if early_stopping_patience is None or early_stopping_patience <= 0: - if early_stopping_verbose: - logging.info( - f"Early stopping not enabled. Set `early_stopping_patience>=0` to enable it." - ) - elif self.early_stopping_mode is None or self.early_stopping_mode not in ( - "min", - "max", - ): - if early_stopping_verbose: - logging.warning( - f"`early_stopping_mode={early_stopping_mode}` is not a valid. Early stopping will be disabled." - ) - else: - self._early_stopping = EarlyStopping( - min_delta=early_stopping_min_delta, patience=early_stopping_patience - ) - if early_stopping_verbose: - logging.info( - "If validation data are provided, early stopping will be enabled." - ) - - @property - def is_early_stopping_active(self) -> bool: - return not ( - (self.early_stopping_patience is None or self.early_stopping_patience <= 0) - or ( - self.early_stopping_mode is None - or self.early_stopping_mode not in ("min", "max") - ) - ) - - def early_stopping_update( - self, validation_metrics: Dict[str, float] - ) -> Optional[bool]: - improved = None - if self.is_early_stopping_active: - early_stopping_monitor = validation_metrics[self.early_stopping_monitor] - if self.early_stopping_mode == "max": - early_stopping_monitor = -early_stopping_monitor - improved, self._early_stopping = self._early_stopping.update( - early_stopping_monitor - ) - return improved - - -class InputValidatorMixin: - def __init__(self, *args, **kwargs): - if len(args) > 0: - raise AttributeError("Cannot recognize inputs arguments: {}".format(args)) - if len(kwargs) > 0: - raise AttributeError( - "{} are not valid input arguments.".format(list(kwargs.keys())) - ) - super(InputValidatorMixin, self).__init__(*args, **kwargs) diff --git a/fortuna/training/mixins/checkpointing.py b/fortuna/training/mixins/checkpointing.py new file mode 100644 index 00000000..a8e3ad5c --- /dev/null +++ b/fortuna/training/mixins/checkpointing.py @@ -0,0 +1,160 @@ +import logging +from typing import Optional + +from flax.training.orbax_utils import save_args_from_target +from jax import ( + ShapeDtypeStruct, + local_devices, +) +from jax.sharding import SingleDeviceSharding +from jax.tree_util import ( + tree_map, + tree_map_with_path, +) +from orbax.checkpoint import ( + ArrayRestoreArgs, + CheckpointManager, +) + +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.name_to_train_state import NameToTrainState +from fortuna.training.train_state import TrainState +from fortuna.typing import ( + OptaxOptimizer, + Path, +) +from fortuna.utils.checkpoint import get_checkpoint_manager + +logger = logging.getLogger(__name__) + + +class WithCheckpointingMixin: + def __init__( + 
self, + *, + partition_manager: PartitionManager, + checkpoint_manager: Optional[CheckpointManager] = None, + **kwargs, + ): + """ + Mixin class for all trainers that need checkpointing capabilities. This is a wrapper around functions in + `flax.training.checkpoints.*`. + + Parameters + ---------- + partition_manager: PartitionManager, + An object that manages partitions. + checkpoint_manager: CheckpointManager + A checkpoint manager + """ + super(WithCheckpointingMixin, self).__init__(**kwargs) + self.partition_manager = partition_manager + self.checkpoint_manager = checkpoint_manager + + def save_checkpoint( + self, + state: TrainState, + save_checkpoint_dir: Path, + keep: int = 1, + force_save: bool = False, + ) -> None: + checkpoint_manager = ( + get_checkpoint_manager( + checkpoint_dir=save_checkpoint_dir, keep_top_n_checkpoints=keep + ) + if save_checkpoint_dir is not None + else self.checkpoint_manager + ) + if checkpoint_manager is not None: + save_args = save_args_from_target(state) + + def save_ckpt_fn(_state): + return checkpoint_manager.save( + _state.step, + _state, + force=force_save, + save_kwargs={"save_args": save_args}, + ) + + if ( + hasattr(state, "grad_accumulated") + and state.grad_accumulated is not None + ): + # do not save grad accumulated in the ckpt + state = state.replace(grad_accumulated=None) + save_ckpt_fn(state) + + def restore_checkpoint( + self, + restore_checkpoint_dir: Path, + optimizer: Optional[OptaxOptimizer] = None, + name_to_train_state: NameToTrainState = NameToTrainState, + ) -> TrainState: + ref = self._get_ref() + restored = self.checkpoint_manager.restore( + self.checkpoint_manager.latest_step(), + items=ref, + restore_kwargs={"restore_args": ref}, + directory=restore_checkpoint_dir, + ) + + if optimizer is not None: + restored = restored.replace( + tx=optimizer, opt_state=optimizer.init(restored.params) + ) + + return restored + + def get_shapes_dtypes_checkpoint( + self, + restore_checkpoint_dir: Path, + name_to_train_state: NameToTrainState = NameToTrainState, + ): + ref = self._get_ref_without_shardings() + state = self.checkpoint_manager.restore( + self.checkpoint_manager.latest_step(), + items=ref, + restore_kwargs=dict(restore_args=ref), + directory=restore_checkpoint_dir, + ) + name = "".join([chr(n) for n in state["encoded_name"].get().tolist()]) + state = name_to_train_state[name].value.init_from_dict(state) + return tree_map(lambda v: _get_shapes_dtypes(v.get()), state) + + def _get_ref_from_shardings(self): + return tree_map_with_path( + lambda p, sharding, shape_dtype: ArrayRestoreArgsWithShape( + mesh=self.partition_manager.partitioner.mesh, + sharding=sharding, + dtype=shape_dtype.dtype, + shape=shape_dtype.shape, + ), + self.partition_manager.shardings, + self.partition_manager.shapes_dtypes, + ) + + def _get_ref_without_shardings(self): + return tree_map_with_path( + lambda p, v: ArrayRestoreArgs( + lazy=True, sharding=SingleDeviceSharding(device=local_devices()[0]) + ), + self.checkpoint_manager.structure(), + ) + + def _get_ref(self): + if ( + self.partition_manager.shardings is not None + and self.partition_manager.shapes_dtypes is not None + ): + return self._get_ref_from_shardings() + return self._get_ref_without_shardings() + + +class ArrayRestoreArgsWithShape(ArrayRestoreArgs): + def __init__(self, shape, *args, **kwargs): + super().__init__(*args, **kwargs) + self.shape = shape + + +def _get_shapes_dtypes(v): + return ShapeDtypeStruct(shape=v.shape, dtype=v.dtype) diff --git a/fortuna/training/mixins/early_stopping.py 
b/fortuna/training/mixins/early_stopping.py new file mode 100644 index 00000000..543e9258 --- /dev/null +++ b/fortuna/training/mixins/early_stopping.py @@ -0,0 +1,71 @@ +import logging +from typing import ( + Dict, + Optional, +) + +from flax.training.early_stopping import EarlyStopping + +logger = logging.getLogger(__name__) + + +class WithEarlyStoppingMixin: + def __init__( + self, + *, + early_stopping_monitor: str = "val_loss", + early_stopping_min_delta: float = 0.0, + early_stopping_patience: Optional[int] = 0, + early_stopping_mode: str = "min", + early_stopping_verbose: bool = True, + **kwargs, + ): + super(WithEarlyStoppingMixin, self).__init__(**kwargs) + self.early_stopping_monitor = early_stopping_monitor + self.early_stopping_mode = early_stopping_mode + self.early_stopping_patience = early_stopping_patience + + if early_stopping_patience is None or early_stopping_patience <= 0: + if early_stopping_verbose: + logging.info( + f"Early stopping not enabled. Set `early_stopping_patience>=0` to enable it." + ) + elif self.early_stopping_mode is None or self.early_stopping_mode not in ( + "min", + "max", + ): + if early_stopping_verbose: + logging.warning( + f"`early_stopping_mode={early_stopping_mode}` is not a valid. Early stopping will be disabled." + ) + else: + self._early_stopping = EarlyStopping( + min_delta=early_stopping_min_delta, patience=early_stopping_patience + ) + if early_stopping_verbose: + logging.info( + "If validation data are provided, early stopping will be enabled." + ) + + @property + def is_early_stopping_active(self) -> bool: + return not ( + (self.early_stopping_patience is None or self.early_stopping_patience <= 0) + or ( + self.early_stopping_mode is None + or self.early_stopping_mode not in ("min", "max") + ) + ) + + def early_stopping_update( + self, validation_metrics: Dict[str, float] + ) -> Optional[bool]: + improved = None + if self.is_early_stopping_active: + early_stopping_monitor = validation_metrics[self.early_stopping_monitor] + if self.early_stopping_mode == "max": + early_stopping_monitor = -early_stopping_monitor + improved, self._early_stopping = self._early_stopping.update( + early_stopping_monitor + ) + return improved diff --git a/fortuna/training/mixins/input_validator.py b/fortuna/training/mixins/input_validator.py new file mode 100644 index 00000000..19df7fdd --- /dev/null +++ b/fortuna/training/mixins/input_validator.py @@ -0,0 +1,9 @@ +class InputValidatorMixin: + def __init__(self, *args, **kwargs): + if len(args) > 0: + raise AttributeError("Cannot recognize inputs arguments: {}".format(args)) + if len(kwargs) > 0: + raise AttributeError( + "{} are not valid input arguments.".format(list(kwargs.keys())) + ) + super(InputValidatorMixin, self).__init__(*args, **kwargs) diff --git a/fortuna/training/mixins/jitted.py b/fortuna/training/mixins/jitted.py new file mode 100644 index 00000000..a7323aeb --- /dev/null +++ b/fortuna/training/mixins/jitted.py @@ -0,0 +1,54 @@ +from functools import partial +from typing import ( + Any, + Callable, + Dict, + Optional, + Tuple, +) + +from flax import jax_utils +from flax.core import FrozenDict +import jax +from jax._src.prng import PRNGKeyArray +import jax.numpy as jnp +from optax._src.base import PyTree + +from fortuna.training.train_state import TrainState +from fortuna.typing import ( + Array, + Batch, +) + + +class JittedMixin: + @partial(jax.jit, static_argnums=(0, 3, 5, 6, 7)) + def training_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + 
n_data: int, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[TrainState, Dict[str, Any]]: + return super().training_step( + state, batch, loss_fun, rng, n_data, unravel, kwargs + ) + + @partial(jax.jit, static_argnums=(0, 3, 5, 6, 7, 8)) + def validation_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Dict[str, jnp.ndarray]: + return super().validation_step( + state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs + ) diff --git a/fortuna/training/mixins/multi_device.py b/fortuna/training/mixins/multi_device.py new file mode 100644 index 00000000..a38b4eef --- /dev/null +++ b/fortuna/training/mixins/multi_device.py @@ -0,0 +1,158 @@ +from functools import partial +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, +) + +from flax import jax_utils +from flax.core import FrozenDict +import jax +from jax import ( + lax, + random, +) +from jax._src.prng import PRNGKeyArray +import jax.numpy as jnp +from jax.tree_util import tree_map +from optax._src.base import PyTree + +from fortuna.data.loader import DataLoader +from fortuna.training.callback import Callback +from fortuna.training.train_state import TrainState +from fortuna.typing import ( + Array, + Batch, +) + + +class MultiDeviceMixin: + all_reduce_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.multi_device = True + + @staticmethod + def _add_device_dim_to_input_data_loader(data_loader: DataLoader) -> DataLoader: + def _reshape_input_batch(batch): + n_devices = jax.local_device_count() + if batch.shape[0] % n_devices != 0: + raise ValueError( + f"The size of all batches must be a multiple of {n_devices}, that is the number of " + f"available devices. Please set an appropriate batch size in the data loader." 
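# A small worked example of the device-dimension reshape this helper performs (a sketch assuming
# 4 local devices): a batch of shape (8, 3) becomes (4, 2, 3), i.e.
# (n_devices, per_device_batch, *single_input_shape), which is the layout pmap expects.
import jax.numpy as jnp

example_batch = jnp.zeros((8, 3))
n_local_devices = 4
reshaped = example_batch.reshape((n_local_devices, -1) + example_batch.shape[1:])
assert reshaped.shape == (4, 2, 3)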
+ ) + single_input_shape = batch.shape[1:] + # reshape to (local_devices, device_batch_size, *single_input_shape) + return batch.reshape((n_devices, -1) + single_input_shape) + + class DataLoaderWrapper: + def __init__(self, data_loader): + self.data_loader = data_loader + + def __iter__(self): + data_loader = map( + lambda batch: tree_map(_reshape_input_batch, batch), + self.data_loader, + ) + data_loader = jax_utils.prefetch_to_device(data_loader, 2) + yield from data_loader + + return ( + DataLoaderWrapper(data_loader) if data_loader is not None else data_loader + ) + + @staticmethod + def _sync_mutable(state: TrainState) -> TrainState: + return ( + state.replace(mutable=MultiDeviceMixin.all_reduce_mean(state.mutable)) + if state.mutable is not None + else state + ) + + @staticmethod + def _sync_array(arr: jnp.ndarray) -> jnp.ndarray: + arr = lax.pmean(arr, axis_name="batch") + return arr + + def _sync_state(self, state: TrainState) -> TrainState: + state = self._sync_mutable(state) + return jax.device_get(tree_map(lambda x: x[0], state)) + + def on_train_start( + self, state: TrainState, data_loaders: List[DataLoader], rng: PRNGKeyArray + ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: + state, data_loaders, rng = super(MultiDeviceMixin, self).on_train_start( + state, data_loaders, rng + ) + state = jax_utils.replicate(state) + data_loaders = [ + self._add_device_dim_to_input_data_loader(dl) for dl in data_loaders + ] + model_key = random.split(rng, jax.local_device_count()) + return state, data_loaders, model_key + + def on_train_end(self, state: TrainState) -> TrainState: + state = super(MultiDeviceMixin, self).on_train_end(state) + return jax.device_get(tree_map(lambda x: x[0], state)) + + def training_step_start(self, rng: PRNGKeyArray, step: int) -> PRNGKeyArray: + step = step if isinstance(step, int) or step.ndim == 0 else step[0] + return jax.vmap(lambda r: random.fold_in(r, step))(rng) + + @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7)) + def training_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[TrainState, Dict[str, Any]]: + return super().training_step( + state, batch, loss_fun, rng, n_data, unravel, kwargs + ) + + def training_step_end( + self, + current_epoch: int, + state: TrainState, + aux: Dict[str, Any], + batch: Batch, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]], + callbacks: Optional[List[Callback]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[TrainState, Dict[str, jnp.ndarray]]: + state, training_losses_and_metrics = super( + MultiDeviceMixin, self + ).training_step_end( + current_epoch, state, aux, batch, metrics, callbacks, kwargs + ) + return state, tree_map(lambda x: x.mean(), training_losses_and_metrics) + + def on_validation_start(self, state: TrainState) -> TrainState: + state = super(MultiDeviceMixin, self).on_validation_start(state) + state = self._sync_mutable(state) + return state + + @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7, 8)) + def validation_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Dict[str, 
jnp.ndarray]: + validation_losses_and_metrics = super().validation_step( + state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs + ) + return lax.pmean(validation_losses_and_metrics, axis_name="batch") diff --git a/fortuna/training/mixins/sharding.py b/fortuna/training/mixins/sharding.py new file mode 100644 index 00000000..4cf4be8b --- /dev/null +++ b/fortuna/training/mixins/sharding.py @@ -0,0 +1,101 @@ +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, +) + +from flax.core import FrozenDict +from jax import eval_shape +from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import pjit +import jax.numpy as jnp +from jax.sharding import PartitionSpec +from optax._src.base import PyTree + +from fortuna.data.loader import DataLoader +from fortuna.data.loader.base import ShardedPrefetchedLoader +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.train_state import TrainState +from fortuna.typing import ( + Array, + Batch, +) + + +class ShardingMixin: + def __init__(self, *, partition_manager: PartitionManager, **kwargs): + super().__init__(partition_manager=partition_manager, **kwargs) + self.partition_manager = partition_manager + + def training_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[TrainState, Dict[str, Any]]: + with self.partition_manager.partitioner.mesh: + return pjit( + super().training_step, + static_argnums=(2, 4, 5, 6), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + PartitionSpec(), + ), + out_shardings=( + self.partition_manager.shardings, + PartitionSpec(), + ), + )(state, batch, loss_fun, rng, n_data, unravel, kwargs) + + def validation_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Dict[str, jnp.ndarray]: + with self.partition_manager.partitioner.mesh: + return pjit( + super().validation_step, + static_argnums=(2, 4, 5, 6, 7), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + PartitionSpec(), + ), + )(state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs) + + def on_train_start( + self, state: TrainState, data_loaders: List[DataLoader], rng: PRNGKeyArray + ) -> Tuple[TrainState, List[ShardedPrefetchedLoader], PRNGKeyArray]: + state, data_loaders, rng = super(ShardingMixin, self).on_train_start( + state, data_loaders, rng + ) + + if self.freeze_fun is not None: + self.partition_manager = PartitionManager( + partitioner=self.partition_manager.partitioner + ) + self.partition_manager.shapes_dtypes = eval_shape(lambda: state) + + data_loaders = [ + ShardedPrefetchedLoader( + loader=dl, + partition_manager=self.partition_manager, + partition_spec=PartitionSpec(("dp", "fsdp")), + ) + for dl in data_loaders + ] + return state, data_loaders, rng diff --git a/fortuna/training/output_calibrator/__init__.py b/fortuna/training/output_calibrator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/training/output_calibrator.py b/fortuna/training/output_calibrator/base.py similarity index 55% rename from fortuna/training/output_calibrator.py rename to 
fortuna/training/output_calibrator/base.py index ca6e7373..9bee26db 100644 --- a/fortuna/training/output_calibrator.py +++ b/fortuna/training/output_calibrator/base.py @@ -1,6 +1,5 @@ import abc import collections -from functools import partial import logging from typing import ( Any, @@ -12,11 +11,8 @@ Union, ) -from flax import jax_utils from flax.training.common_utils import stack_forest -import jax from jax import ( - lax, random, value_and_grad, ) @@ -31,11 +27,10 @@ TargetsLoader, ) from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.mixin import ( - InputValidatorMixin, - WithCheckpointingMixin, - WithEarlyStoppingMixin, -) +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin +from fortuna.training.mixins.input_validator import InputValidatorMixin from fortuna.typing import ( Array, Batch, @@ -58,6 +53,7 @@ def __init__( self, *args, calib_outputs_loader: TargetsLoader, + partition_manager: PartitionManager, predict_fn: Callable[[jnp.ndarray], jnp.ndarray], uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray], val_outputs_loader: Optional[TargetsLoader] = None, @@ -68,7 +64,9 @@ def __init__( eval_every_n_epochs: int = 1, **kwargs, ): - super(OutputCalibratorABC, self).__init__(*args, **kwargs) + super(OutputCalibratorABC, self).__init__( + *args, partition_manager=partition_manager, **kwargs + ) self._calib_outputs_loader = calib_outputs_loader self._val_outputs_loader = val_outputs_loader self.predict_fn = predict_fn @@ -78,7 +76,6 @@ def __init__( self.keep_top_n_checkpoints = keep_top_n_checkpoints self.disable_training_metrics_computation = disable_training_metrics_computation self.eval_every_n_epochs = eval_every_n_epochs - self.multi_device = False def train( self, @@ -96,7 +93,7 @@ def train( verbose: bool = True, ) -> Tuple[OutputCalibState, Status]: training_losses_and_metrics = collections.defaultdict(list) - val_losses_and_metrics = collections.defaultdict(list) + validation_losses_and_metrics = collections.defaultdict(list) state, data_loaders, outputs_loaders, rng = self.on_train_start( state, @@ -135,11 +132,11 @@ def train( # validation loop if self.should_perform_validation(val_data_loader, epoch): # performance evaluation on the whole validation dataset - state = self.on_val_start(state) + state = self.on_validation_start(state) ( - val_losses_and_metrics_current_epoch, - val_epoch_metrics_str, - ) = self._val_loop( + validation_losses_and_metrics_current_epoch, + validation_epoch_metrics_str, + ) = self._validation_loop( loss_fun=loss_fun, metrics=metrics, rng=rng, @@ -150,11 +147,13 @@ def train( verbose=verbose, ) if verbose: - logging.info(f"Epoch: {epoch + 1} | " + val_epoch_metrics_str) + logging.info( + f"Epoch: {epoch + 1} | " + validation_epoch_metrics_str + ) # keep track of training losses and metrics [granularity=epoch] and check for early stopping - for k in val_losses_and_metrics_current_epoch.keys(): - val_losses_and_metrics[k].append( - val_losses_and_metrics_current_epoch[k] + for k in validation_losses_and_metrics_current_epoch.keys(): + validation_losses_and_metrics[k].append( + validation_losses_and_metrics_current_epoch[k] ) # check for early stopping if self.is_early_stopping_active and self._early_stopping.should_stop: @@ -165,8 +164,10 @@ def train( training_status = { k: jnp.array(v) for k, v in training_losses_and_metrics.items() } - 
val_status = {k: jnp.array(v) for k, v in val_losses_and_metrics.items()} - status = dict(**training_status, **val_status) + validation_status = { + k: jnp.array(v) for k, v in validation_losses_and_metrics.items() + } + status = dict(**training_status, **validation_status) state = self.on_train_end(state) return state, status @@ -302,27 +303,14 @@ def training_step_end( if not self.disable_training_metrics_computation and metrics is not None: preds = self.predict_fn(aux["outputs"]) uncertainties = self.uncertainty_fn(aux["outputs"]) - if self.multi_device: - training_batch_metrics = self.compute_metrics( - preds.reshape((preds.shape[0] * preds.shape[1],) + preds.shape[2:]), - uncertainties.reshape( - (uncertainties.shape[0] * uncertainties.shape[1],) - + uncertainties.shape[2:] - ), - batch[1].reshape( - (batch[1].shape[0] * batch[1].shape[1],) + batch[1].shape[2:] - ), - metrics, - ) - else: - training_batch_metrics = self.compute_metrics( - preds, uncertainties, batch[1], metrics - ) + training_batch_metrics = self.compute_metrics( + preds, uncertainties, batch[1], metrics + ) for k, v in training_batch_metrics.items(): training_losses_and_metrics[k] = v return training_losses_and_metrics - def _val_loop( + def _validation_loop( self, loss_fun: Callable, metrics: Optional[ @@ -335,10 +323,10 @@ def _val_loop( val_dataset_size: int, verbose: bool = True, ) -> Tuple[Dict[str, float], str]: - val_losses_and_metrics_epoch_all_steps = [] - val_epoch_metrics_str = "" + validation_losses_and_metrics_epoch_all_steps = [] + validation_epoch_metrics_str = "" for batch, outputs in zip(val_data_loader, val_outputs_loader): - val_losses_and_metrics_current_batch = self.val_step( + validation_losses_and_metrics_current_batch = self.validation_step( state, batch, outputs, @@ -347,24 +335,24 @@ def _val_loop( val_dataset_size, metrics, ) - val_losses_and_metrics_epoch_all_steps.append( - val_losses_and_metrics_current_batch + validation_losses_and_metrics_epoch_all_steps.append( + validation_losses_and_metrics_current_batch ) # compute validation losses and metrics for the current epoch - val_losses_and_metrics_current_epoch = self.val_epoch_end( - val_losses_and_metrics_epoch_all_steps, state + validation_losses_and_metrics_current_epoch = self.validation_epoch_end( + validation_losses_and_metrics_epoch_all_steps, state ) # logging if verbose: - val_epoch_metrics_str = " | ".join( + validation_epoch_metrics_str = " | ".join( [ f"{m}: {round(float(v), 5)}" - for m, v in val_losses_and_metrics_current_epoch.items() + for m, v in validation_losses_and_metrics_current_epoch.items() ] ) - return val_losses_and_metrics_current_epoch, val_epoch_metrics_str + return validation_losses_and_metrics_current_epoch, validation_epoch_metrics_str - def val_step( + def validation_step( self, state: OutputCalibState, batch: Batch, @@ -376,12 +364,14 @@ def val_step( Tuple[Callable[[jnp.ndarray, jnp.ndarray, Array], Array], ...] 
] = None, ) -> Dict[str, jnp.ndarray]: - val_loss, aux = self.val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - val_metrics = self.val_metrics_step(aux, batch, metrics) - return {"val_loss": val_loss, **val_metrics} + validation_loss, aux = self.validation_loss_step( + state, batch, outputs, loss_fun, rng, n_data + ) + validation_metrics = self.validation_metrics_step(aux, batch, metrics) + return {"validation_loss": validation_loss, **validation_metrics} @abc.abstractmethod - def val_loss_step( + def validation_loss_step( self, state: OutputCalibState, batch: Batch, @@ -392,7 +382,7 @@ def val_loss_step( ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]: pass - def val_metrics_step( + def validation_metrics_step( self, aux: Dict[str, jnp.ndarray], batch: Batch, @@ -401,13 +391,13 @@ def val_metrics_step( ] = None, ) -> Dict[str, jnp.ndarray]: if metrics is not None: - val_metrics = self.compute_metrics( + validation_metrics = self.compute_metrics( self.predict_fn(aux["outputs"]), self.uncertainty_fn(aux["outputs"]), batch[1], metrics, ) - return {f"val_{m}": v for m, v in val_metrics.items()} + return {f"validation_{m}": v for m, v in validation_metrics.items()} else: return {} @@ -418,19 +408,21 @@ def training_epoch_end( training_losses_and_metrics_current_epoch ) - def val_epoch_end( + def validation_epoch_end( self, - val_losses_and_metrics_current_epoch: List[Dict[str, jnp.ndarray]], + validation_losses_and_metrics_current_epoch: List[Dict[str, jnp.ndarray]], state: OutputCalibState, ) -> Dict[str, float]: - val_losses_and_metrics_current_epoch = self._get_mean_losses_and_metrics( - val_losses_and_metrics_current_epoch + validation_losses_and_metrics_current_epoch = self._get_mean_losses_and_metrics( + validation_losses_and_metrics_current_epoch ) # early stopping - improved = self.early_stopping_update(val_losses_and_metrics_current_epoch) - if improved and self.save_checkpoint_dir: + improved = self.early_stopping_update( + validation_losses_and_metrics_current_epoch + ) + if improved and self.save_checkpoint_dir is not None: self.save_checkpoint(state, self.save_checkpoint_dir, force_save=True) - return val_losses_and_metrics_current_epoch + return validation_losses_and_metrics_current_epoch def _get_mean_losses_and_metrics( self, losses_and_metrics: List[Dict[str, jnp.ndarray]] @@ -472,7 +464,7 @@ def on_train_end(self, state: OutputCalibState) -> OutputCalibState: ) return state - def on_val_start(self, state: OutputCalibState) -> OutputCalibState: + def on_validation_start(self, state: OutputCalibState) -> OutputCalibState: return state def compute_metrics( @@ -488,214 +480,3 @@ def compute_metrics( for metric in metrics: metrics_vals[metric.__name__] = metric(preds, uncertainties, targets) return metrics_vals - - -class JittedMixin: - @partial(jax.jit, static_argnums=(0, 4, 6)) - def training_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Tuple[OutputCalibState, Dict[str, Any]]: - return super().training_step(state, batch, outputs, loss_fun, rng, n_data) - - @partial(jax.jit, static_argnums=(0, 4, 6)) - def val_loss_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Dict[str, jnp.ndarray]: - return super().val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - - -class MultiDeviceMixin: - all_reduce_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x") - - def __init__(self, *args, 
**kwargs): - super().__init__(*args, **kwargs) - self.multi_device = True - - @staticmethod - def _add_device_dim_to_data_loader(data_loader: DataLoader) -> DataLoader: - def _reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch with shape {batch.shape[0]} was found. " - f"Please set an appropriate batch size." - ) - return batch.reshape((n_devices, -1) + batch.shape[1:]) - - class DataLoaderWrapper: - def __init__(self, data_loader: DataLoader): - self._data_loader = data_loader - - def __iter__(self): - data_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._data_loader - ) - data_loader = jax_utils.prefetch_to_device(data_loader, 2) - yield from data_loader - - return ( - DataLoaderWrapper(data_loader) if data_loader is not None else data_loader - ) - - @staticmethod - def _add_device_dim_to_outputs_loader( - outputs_loader: TargetsLoader, - ) -> TargetsLoader: - def _reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all output batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch of outputs with shape {batch.shape[0]} was found. " - f"Please set an appropriate batch size." - ) - return batch.reshape((n_devices, -1) + batch.shape[1:]) - - class TargetsLoaderWrapper: - def __init__(self, outputs_loader: TargetsLoader): - self._outputs_loader = outputs_loader - - def __iter__(self): - outputs_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._outputs_loader - ) - outputs_loader = jax_utils.prefetch_to_device(outputs_loader, 2) - yield from outputs_loader - - return ( - TargetsLoaderWrapper(outputs_loader) - if outputs_loader is not None - else outputs_loader - ) - - @staticmethod - def sync_mutable(state: OutputCalibState) -> OutputCalibState: - return ( - state.replace(mutable=MultiDeviceMixin.all_reduce_mean(state.mutable)) - if state.mutable["output_calibrator"] is not None - else state - ) - - @staticmethod - def sync_gradients_and_loss( - grads: jnp.ndarray, loss: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - grad = lax.pmean(grads, axis_name="batch") - loss = lax.pmean(loss, axis_name="batch") - return grad, loss - - def save_checkpoint( - self, - state: OutputCalibState, - save_checkpoint_dir: Path, - keep: int = 1, - force_save: bool = False, - prefix: str = "checkpoint_", - ) -> None: - state = self.sync_mutable(state) - state = jax.device_get(tree_map(lambda x: x[0], state)) - return super(MultiDeviceMixin, self).save_checkpoint( - state, save_checkpoint_dir, keep, force_save, prefix - ) - - def on_train_start( - self, - state: OutputCalibState, - data_loaders: List[DataLoader], - outputs_loaders: List[TargetsLoader], - rng: PRNGKeyArray, - ) -> Tuple[OutputCalibState, List[DataLoader], List[TargetsLoader], PRNGKeyArray]: - state, data_loaders, outputs_loaders, rng = super( - MultiDeviceMixin, self - ).on_train_start(state, data_loaders, outputs_loaders, rng) - state = jax_utils.replicate(state) - data_loaders = [ - self._add_device_dim_to_data_loader(dl) if dl is not None else dl - for dl in data_loaders - ] - outputs_loaders = [ - self._add_device_dim_to_outputs_loader(ol) if ol is not None else ol - for ol in outputs_loaders - ] - model_key = random.split(rng, jax.local_device_count()) - return state, 
data_loaders, outputs_loaders, model_key - - def on_train_end(self, state: OutputCalibState) -> OutputCalibState: - state = super(MultiDeviceMixin, self).on_train_end(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 4, 6)) - def training_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Tuple[OutputCalibState, Dict[str, Any]]: - return super().training_step(state, batch, outputs, loss_fun, rng, n_data) - - def training_step_end( - self, - current_epoch: int, - state: OutputCalibState, - aux: Dict[str, Any], - batch: Batch, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]], - ) -> Dict[str, jnp.ndarray]: - training_losses_and_metrics = super(MultiDeviceMixin, self).training_step_end( - current_epoch, state, aux, batch, metrics - ) - return tree_map(lambda x: x.mean(), training_losses_and_metrics) - - def on_val_start(self, state: OutputCalibState) -> OutputCalibState: - state = super(MultiDeviceMixin, self).on_val_start(state) - if state.mutable["output_calibrator"] is not None: - state = self.sync_mutable(state) - return state - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 4, 6)) - def val_loss_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Dict[str, jnp.ndarray]: - val_losses = super().val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - return lax.pmean(val_losses, axis_name="batch") - - def val_metrics_step( - self, - aux: Dict[str, jnp.ndarray], - batch: Batch, - metrics: Optional[ - Tuple[Callable[[jnp.ndarray, jnp.ndarray, Array], Array], ...] 
- ] = None, - ) -> Dict[str, jnp.ndarray]: - outputs = aux["outputs"] - outputs = outputs.reshape(outputs.shape[0] * outputs.shape[1], -1) - targets = batch[1].reshape(batch[1].shape[0] * batch[1].shape[1], -1) - if metrics is not None: - val_metrics = self.compute_metrics( - self.predict_fn(outputs), self.uncertainty_fn(outputs), targets, metrics - ) - return {f"val_{m}": v for m, v in val_metrics.items()} - else: - return {} diff --git a/fortuna/training/output_calibrator/mixins/sharding.py b/fortuna/training/output_calibrator/mixins/sharding.py new file mode 100644 index 00000000..95e0e231 --- /dev/null +++ b/fortuna/training/output_calibrator/mixins/sharding.py @@ -0,0 +1,104 @@ +from typing import ( + Any, + Callable, + Dict, + List, + Tuple, +) + +from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import pjit +import jax.numpy as jnp +from jax.sharding import PartitionSpec + +from fortuna.data.loader import ( + DataLoader, + TargetsLoader, +) +from fortuna.data.loader.base import ShardedPrefetchedLoader +from fortuna.output_calib_model.state import OutputCalibState +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.typing import ( + Array, + Batch, +) + + +class ShardingMixin: + def __init__(self, *, partition_manager: PartitionManager, **kwargs): + super().__init__(partition_manager=partition_manager, **kwargs) + self.partition_manager = partition_manager + + def training_step( + self, + state: OutputCalibState, + batch: Batch, + outputs: Array, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + ) -> Tuple[OutputCalibState, Dict[str, Any]]: + with self.partition_manager.partitioner.mesh: + return pjit( + super().training_step, + static_argnums=(3, 5), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + outputs.sharding, + PartitionSpec(), + ), + )(state, batch, outputs, loss_fun, rng, n_data) + + def validation_loss_step( + self, + state: OutputCalibState, + batch: Batch, + outputs: Array, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + ) -> Dict[str, jnp.ndarray]: + with self.partition_manager.partitioner.mesh: + fun = super().validation_loss_step + return pjit( + fun, + static_argnums=(3, 5), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + outputs.sharding, + PartitionSpec(), + ), + )(state, batch, outputs, loss_fun, rng, n_data) + + def on_train_start( + self, + state: OutputCalibState, + data_loaders: List[DataLoader], + outputs_loaders: List[TargetsLoader], + rng: PRNGKeyArray, + ) -> Tuple[ + OutputCalibState, + List[ShardedPrefetchedLoader], + List[ShardedPrefetchedLoader], + PRNGKeyArray, + ]: + state, data_loaders, output_loaders, rng = super( + ShardingMixin, self + ).on_train_start(state, data_loaders, outputs_loaders, rng) + data_loaders = [ + ShardedPrefetchedLoader( + loader=data_loader, + partition_manager=self.partition_manager, + partition_spec=PartitionSpec(("dp", "fsdp")), + ) + for data_loader in data_loaders + ] + outputs_loaders = [ + ShardedPrefetchedLoader( + loader=output_loader, partition_manager=self.partition_manager + ) + for output_loader in output_loaders + ] + return state, data_loaders, outputs_loaders, rng diff --git a/fortuna/training/train_state_repository.py b/fortuna/training/train_state_repository.py index 25cd7fdd..b54b94a1 100644 --- a/fortuna/training/train_state_repository.py +++ b/fortuna/training/train_state_repository.py @@ -1,5 +1,4 @@ -from copy import deepcopy -import os +from shutil 
import rmtree from typing import ( Dict, List, @@ -7,7 +6,11 @@ Union, ) -from fortuna.training.mixin import WithCheckpointingMixin +from jax import eval_shape +from orbax.checkpoint import CheckpointManager + +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.training.train_state import TrainState from fortuna.typing import ( OptaxOptimizer, @@ -16,95 +19,78 @@ class TrainStateRepository(WithCheckpointingMixin): - def __init__(self, checkpoint_dir: Optional[Path] = None): - super().__init__() - self.checkpoint_dir = checkpoint_dir - self.__state = None + def __init__( + self, + partition_manager: PartitionManager, + checkpoint_manager: Optional[CheckpointManager] = None, + ): + super().__init__(partition_manager=partition_manager) + self.checkpoint_manager = checkpoint_manager + self._state = None def get( self, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, ) -> Union[Dict, TrainState]: - if not checkpoint_path and not self.checkpoint_dir and not self.__state: + if not checkpoint_dir and not self.checkpoint_manager and not self._state: raise ValueError("No state available.") - if checkpoint_path or self.checkpoint_dir: + if checkpoint_dir or self.checkpoint_manager: return self.restore_checkpoint( - restore_checkpoint_path=checkpoint_path or self.checkpoint_dir, - optimizer=optimizer, - prefix=prefix, - **kwargs, + restore_checkpoint_dir=checkpoint_dir, optimizer=optimizer ) if optimizer is not None: - self.__state = self.__state.replace( - tx=optimizer, opt_state=optimizer.init(self.__state.params) - ) - return deepcopy(self.__state) + state = self.partition_manager.reshard(self._state) + state = state.replace(tx=optimizer, opt_state=optimizer.init(state.params)) + return state + return self._state def put( self, state: TrainState, - checkpoint_path: Optional[Path] = None, + checkpoint_dir: Optional[Path] = None, keep: int = 1, - prefix: str = "checkpoint_", ) -> None: - if checkpoint_path or self.checkpoint_dir: + if checkpoint_dir or self.checkpoint_manager: self.save_checkpoint( state=state, - save_checkpoint_dir=checkpoint_path or self.checkpoint_dir, + save_checkpoint_dir=checkpoint_dir, keep=keep, force_save=True, - prefix=prefix, ) else: - self.__state = state + self._state = state def pull( self, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "checkpoint_", - **kwargs, ) -> TrainState: state = self.get( - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, - prefix=prefix, - **kwargs, ) - if checkpoint_path or self.checkpoint_dir: - os.remove( - checkpoint_path - or self.get_path_latest_checkpoint(self.checkpoint_dir, prefix=prefix) - ) + if checkpoint_dir or self.checkpoint_manager: + if checkpoint_dir is None: + self.checkpoint_manager.delete(self.checkpoint_manager.latest_step()) + else: + rmtree(checkpoint_dir) return state def update( self, variables: Dict, - checkpoint_path: Path = None, + checkpoint_dir: Path = None, optimizer: Optional[OptaxOptimizer] = None, keep: int = 1, - prefix: str = "checkpoint_", - **kwargs, ): state = self.pull( - checkpoint_path=checkpoint_path, + checkpoint_dir=checkpoint_dir, optimizer=optimizer, - prefix=prefix, - **kwargs, ) state = state.replace(**variables) - self.put(state, 
checkpoint_path=checkpoint_path, keep=keep, prefix=prefix) + self.put(state, checkpoint_dir=checkpoint_dir, keep=keep) - def extract( - self, - keys: List[str], - checkpoint_path: Optional[Path] = None, - prefix: str = "checkpoint_", - **kwargs, - ) -> Dict: - state = self.get(checkpoint_path=checkpoint_path, prefix=prefix, **kwargs) + def extract(self, keys: List[str], checkpoint_dir: Optional[Path] = None) -> Dict: + state = self.get(checkpoint_dir=checkpoint_dir) return {k: getattr(state, k) for k in keys} diff --git a/fortuna/training/trainer.py b/fortuna/training/trainer.py index 5959bc2f..7f54950e 100755 --- a/fortuna/training/trainer.py +++ b/fortuna/training/trainer.py @@ -2,6 +2,7 @@ import collections from functools import partial import logging +from pathlib import Path as _Path from typing import ( Any, Callable, @@ -12,29 +13,28 @@ Union, ) -from flax import jax_utils from flax.core import FrozenDict from flax.training.common_utils import stack_forest import jax from jax import ( - lax, random, value_and_grad, + vmap, ) from jax._src.prng import PRNGKeyArray import jax.numpy as jnp from jax.tree_util import tree_map from optax._src.base import PyTree +from orbax.checkpoint import CheckpointManager from tqdm import trange from tqdm.std import tqdm as TqdmDecorator from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.training.callback import Callback -from fortuna.training.mixin import ( - InputValidatorMixin, - WithCheckpointingMixin, - WithEarlyStoppingMixin, -) +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin +from fortuna.training.mixins.input_validator import InputValidatorMixin from fortuna.training.train_state import TrainState from fortuna.typing import ( AnyKey, @@ -71,6 +71,8 @@ def __init__( self, *args, predict_fn: Callable[[jnp.ndarray], jnp.ndarray], + partition_manager: Optional[PartitionManager], + checkpoint_manager: Optional[CheckpointManager], uncertainty_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, save_checkpoint_dir: Optional[Path] = None, save_every_n_steps: Optional[int] = None, @@ -80,7 +82,12 @@ def __init__( freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None, **kwargs, ): - super(TrainerABC, self).__init__(*args, **kwargs) + super(TrainerABC, self).__init__( + *args, + partition_manager=partition_manager, + checkpoint_manager=checkpoint_manager, + **kwargs, + ) self.predict_fn = predict_fn self.uncertainty_fn = uncertainty_fn self.save_checkpoint_dir = save_checkpoint_dir @@ -348,8 +355,13 @@ def validation_epoch_end( improved = self.early_stopping_update( validation_losses_and_metrics_current_epoch ) - if improved and self.save_checkpoint_dir: - self.save_checkpoint(state, self.save_checkpoint_dir, force_save=True) + if improved and self.save_checkpoint_dir is not None: + self.save_checkpoint( + state, + str(_Path(self.save_checkpoint_dir) / "best"), + force_save=True, + prefix="", + ) return validation_losses_and_metrics_current_epoch def train( @@ -357,11 +369,11 @@ def train( rng: PRNGKeyArray, state: TrainState, loss_fun: Callable, - training_dataloader: DataLoader, + training_data_loader: DataLoader, training_dataset_size: int, n_epochs: int = 1, metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, - validation_dataloader: Optional[DataLoader] = None, + validation_data_loader: Optional[DataLoader] = None, 
validation_dataset_size: Optional[int] = None, verbose: bool = True, unravel: Optional[Callable[[any], PyTree]] = None, @@ -369,18 +381,18 @@ def train( **kwargs, ) -> Tuple[TrainState, Status]: training_kwargs = FrozenDict(kwargs) - if validation_dataloader: + if validation_data_loader: assert ( validation_dataset_size is not None - ), "`validation_dataset_size` is required when `validation_dataloader` is provided." + ), "`validation_dataset_size` is required when `validation_data_loader` is provided." training_losses_and_metrics = collections.defaultdict(list) validation_losses_and_metrics = collections.defaultdict(list) - state, dataloaders, rng = self.on_train_start( - state, [training_dataloader, validation_dataloader], rng + state, data_loaders, rng = self.on_train_start( + state, [training_data_loader, validation_data_loader], rng ) - training_dataloader, validation_dataloader = dataloaders + training_data_loader, validation_data_loader = data_loaders progress_bar = trange(n_epochs, desc="Epoch") for epoch in progress_bar: @@ -395,7 +407,7 @@ def train( metrics, rng, state, - training_dataloader, + training_data_loader, training_dataset_size, training_kwargs, verbose, @@ -410,7 +422,7 @@ def train( ) # validation loop - if self.should_perform_validation(validation_dataloader, epoch): + if self.should_perform_validation(validation_data_loader, epoch): # performance evaluation on the whole validation dataset state = self.on_validation_start(state) ( @@ -422,7 +434,7 @@ def train( rng=rng, state=state, training_kwargs=training_kwargs, - validation_dataloader=validation_dataloader, + validation_data_loader=validation_data_loader, validation_dataset_size=validation_dataset_size, verbose=verbose, unravel=unravel, @@ -460,7 +472,7 @@ def _training_loop( metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], jnp.ndarray], ...]], rng: PRNGKeyArray, state: TrainState, - training_dataloader: DataLoader, + training_data_loader: DataLoader, training_dataset_size: int, training_kwargs: FrozenDict[str, Any], verbose: bool, @@ -476,7 +488,7 @@ def _training_loop( state = self.training_epoch_start(state, callbacks) # ensure to use a different key at each step model_key = self.training_step_start(rng, state.step) - for step, batch in enumerate(training_dataloader): + for step, batch in enumerate(training_data_loader): # forward and backward pass state, aux = self.training_step( state, @@ -539,14 +551,14 @@ def _validation_loop( rng: PRNGKeyArray, state: TrainState, training_kwargs: FrozenDict[str, Any], - validation_dataloader: DataLoader, + validation_data_loader: DataLoader, validation_dataset_size: int, verbose: bool = True, unravel: Optional[Callable[[any], PyTree]] = None, ) -> Tuple[Dict[str, float], str]: validation_losses_and_metrics_epoch_all_steps = [] validation_epoch_metrics_str = "" - for batch in validation_dataloader: + for batch in validation_data_loader: validation_losses_and_metrics_current_batch = self.validation_step( state, batch, @@ -590,10 +602,10 @@ def _get_mean_losses_and_metrics( return losses_and_metrics def should_perform_validation( - self, validation_dataloader: Optional[DataLoader], epoch: int + self, validation_data_loader: Optional[DataLoader], epoch: int ) -> bool: return ( - validation_dataloader is not None + validation_data_loader is not None and self.eval_every_n_epochs > 0 and epoch % self.eval_every_n_epochs == 0 ) @@ -605,7 +617,7 @@ def _sync_array(arr: jnp.ndarray) -> jnp.ndarray: def on_train_start( self, state: TrainState, - dataloaders: List[DataLoader], 
+ data_loaders: List[DataLoader], rng: PRNGKeyArray, ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: if self.freeze_fun is not None: @@ -639,14 +651,17 @@ def on_train_start( ) ), ) - return state, dataloaders, rng + return state, data_loaders, rng def on_train_end(self, state: TrainState) -> TrainState: self.save_checkpoint( state, - save_checkpoint_dir=self.save_checkpoint_dir, + save_checkpoint_dir=str(_Path(self.save_checkpoint_dir) / "last") + if self.save_checkpoint_dir is not None + else None, keep=self.keep_top_n_checkpoints, force_save=True, + prefix="", ) if self.freeze_fun is not None: @@ -677,7 +692,7 @@ def compute_metrics( def training_step_start( self, rng: PRNGKeyArray, step: Union[int, jax.Array] ) -> PRNGKeyArray: - step = step if isinstance(step, int) or step.ndim == 0 else step[0] + # step = step if isinstance(step, int) or step.ndim == 0 else step[0] return random.fold_in(rng, step) def _sync_state(self, state: TrainState) -> TrainState: @@ -689,14 +704,14 @@ def save_checkpoint( save_checkpoint_dir: Path, keep: int = 1, force_save: bool = False, - prefix: str = "checkpoint_", + prefix: str = "", ) -> None: if self.freeze_fun is not None: state = state.replace( params=self._get_all_params(state), frozen_params=None ) return super().save_checkpoint( - self._sync_state(state), save_checkpoint_dir, keep, force_save, prefix + self._sync_state(state), save_checkpoint_dir, keep, force_save ) def _get_all_params( @@ -712,162 +727,3 @@ def _get_all_params( ) ) return trainable_params if trainable_params is not None else state.params - - -class JittedMixin: - @partial(jax.jit, static_argnums=(0, 3, 5, 6, 7)) - def training_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Tuple[TrainState, Dict[str, Any]]: - return super().training_step( - state, batch, loss_fun, rng, n_data, unravel, kwargs - ) - - @partial(jax.jit, static_argnums=(0, 3, 5, 6, 7, 8)) - def validation_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Dict[str, jnp.ndarray]: - return super().validation_step( - state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs - ) - - -class MultiDeviceMixin: - all_reduce_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x") - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.multi_device = True - - @staticmethod - def _add_device_dim_to_input_dataloader(dataloader: DataLoader) -> DataLoader: - def _reshape_input_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all batches must be a multiple of {n_devices}, that is the number of " - f"available devices. Please set an appropriate batch size in the data loader." 
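With the changes above, checkpoints are no longer written flat into save_checkpoint_dir under a "checkpoint_" prefix: the best validation state and the final state go to separate Orbax-managed subdirectories. A small sketch of the resulting layout, using an illustrative base path that is not taken from the patch:

    from pathlib import Path

    save_checkpoint_dir = Path("/tmp/run")   # assumed value of save_checkpoint_dir
    best_dir = save_checkpoint_dir / "best"  # written by validation_epoch_end when validation improves
    last_dir = save_checkpoint_dir / "last"  # written by on_train_end when training finishes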
- ) - single_input_shape = batch.shape[1:] - # reshape to (local_devices, device_batch_size, *single_input_shape) - return batch.reshape((n_devices, -1) + single_input_shape) - - class DataLoaderWrapper: - def __init__(self, dataloader): - self.dataloader = dataloader - - def __iter__(self): - dataloader = map( - lambda batch: tree_map(_reshape_input_batch, batch), self.dataloader - ) - dataloader = jax_utils.prefetch_to_device(dataloader, 2) - yield from dataloader - - return DataLoaderWrapper(dataloader) if dataloader is not None else dataloader - - @staticmethod - def _sync_mutable(state: TrainState) -> TrainState: - return ( - state.replace(mutable=MultiDeviceMixin.all_reduce_mean(state.mutable)) - if state.mutable is not None - else state - ) - - @staticmethod - def _sync_array(arr: jnp.ndarray) -> jnp.ndarray: - arr = lax.pmean(arr, axis_name="batch") - return arr - - def _sync_state(self, state: TrainState) -> TrainState: - state = self._sync_mutable(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - def on_train_start( - self, state: TrainState, dataloaders: List[DataLoader], rng: PRNGKeyArray - ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: - state, dataloaders, rng = super(MultiDeviceMixin, self).on_train_start( - state, dataloaders, rng - ) - state = jax_utils.replicate(state) - dataloaders = [ - self._add_device_dim_to_input_dataloader(dl) for dl in dataloaders - ] - model_key = random.split(rng, jax.local_device_count()) - return state, dataloaders, model_key - - def on_train_end(self, state: TrainState) -> TrainState: - state = super(MultiDeviceMixin, self).on_train_end(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - def training_step_start(self, rng: PRNGKeyArray, step: int) -> PRNGKeyArray: - step = step if isinstance(step, int) or step.ndim == 0 else step[0] - return jax.vmap(lambda r: random.fold_in(r, step))(rng) - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7)) - def training_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Tuple[TrainState, Dict[str, Any]]: - return super().training_step( - state, batch, loss_fun, rng, n_data, unravel, kwargs - ) - - def training_step_end( - self, - current_epoch: int, - state: TrainState, - aux: Dict[str, Any], - batch: Batch, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]], - callbacks: Optional[List[Callback]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Tuple[TrainState, Dict[str, jnp.ndarray]]: - state, training_losses_and_metrics = super( - MultiDeviceMixin, self - ).training_step_end( - current_epoch, state, aux, batch, metrics, callbacks, kwargs - ) - return state, tree_map(lambda x: x.mean(), training_losses_and_metrics) - - def on_validation_start(self, state: TrainState) -> TrainState: - state = super(MultiDeviceMixin, self).on_validation_start(state) - state = self._sync_mutable(state) - return state - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7, 8)) - def validation_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Dict[str, jnp.ndarray]: - 
validation_losses_and_metrics = super().validation_step( - state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs - ) - return lax.pmean(validation_losses_and_metrics, axis_name="batch") diff --git a/fortuna/utils/checkpoint.py b/fortuna/utils/checkpoint.py new file mode 100644 index 00000000..f8ad6214 --- /dev/null +++ b/fortuna/utils/checkpoint.py @@ -0,0 +1,21 @@ +from typing import Optional + +from orbax.checkpoint import ( + Checkpointer, + CheckpointManager, + CheckpointManagerOptions, + PyTreeCheckpointHandler, +) + + +def get_checkpoint_manager( + checkpoint_dir: str, keep_top_n_checkpoints: Optional[int] = None +): + if checkpoint_dir is not None: + options = CheckpointManagerOptions( + create=True, max_to_keep=keep_top_n_checkpoints + ) + return CheckpointManager( + checkpoint_dir, Checkpointer(PyTreeCheckpointHandler()), options + ) + return None diff --git a/fortuna/utils/mesh.py b/fortuna/utils/mesh.py index bb845669..af502e57 100644 --- a/fortuna/utils/mesh.py +++ b/fortuna/utils/mesh.py @@ -1,356 +1,89 @@ -from functools import partial -import re -import random -from ml_collections import ConfigDict -from ml_collections.config_dict.config_dict import placeholder +from typing import Dict -import flax -import jax -import jax.numpy as jnp -from jax.sharding import PartitionSpec as PS -from jax.sharding import Mesh -from jax.experimental import mesh_utils -from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint -from jax.experimental.pjit import pjit +from jax import local_device_count +from jax.experimental.mesh_utils import create_device_mesh from jax.interpreters import pxla +from jax.lax import with_sharding_constraint +from jax.sharding import ( + Mesh, + PartitionSpec, +) import numpy as np +from fortuna.utils.partition import get_names_from_partition_spec -class DistributedConfig(object): - """ Utility class for initializing JAX distributed. """ - @staticmethod - def get_default_config(updates=None): - config = ConfigDict() - config.initialize_jax_distributed = False - config.coordinator_address = placeholder(str) - config.num_processes = placeholder(int) - config.process_id = placeholder(int) - config.local_device_ids = placeholder(str) +def get_mesh(axis_dims: Dict[str, int]): + keys = tuple(axis_dims.keys()) + dims = tuple(axis_dims.values()) - if updates is not None: - config.update(ConfigDict(updates).copy_and_resolve_references()) - return config - - @classmethod - def initialize(cls, config): - config = cls.get_default_config(config) - if config.initialize_jax_distributed: - if config.local_device_ids is not None: - local_device_ids = [int(x) for x in config.local_device_ids.split(',')] - else: - local_device_ids = None - - jax.distributed.initialize( - coordinator_address=config.coordinator_address, - num_processes=config.num_processes, - process_id=config.process_id, - local_device_ids=local_device_ids, - ) - - -def make_shard_and_gather_fns(partition_specs, dtype_specs=None): - """ Create pytree of sharding and gathering functions from pytree of - partition specs. 
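A minimal usage sketch for the get_checkpoint_manager helper added above in fortuna/utils/checkpoint.py; the directory path and the parameter pytree are illustrative assumptions, not part of the patch:

    import jax.numpy as jnp

    from fortuna.utils.checkpoint import get_checkpoint_manager

    # wrap Orbax's CheckpointManager; at most two checkpoints are kept on disk
    manager = get_checkpoint_manager("/tmp/fortuna_ckpts", keep_top_n_checkpoints=2)
    params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))}}

    if manager is not None:
        manager.save(0, params)                            # save the pytree as step 0
        restored = manager.restore(manager.latest_step())  # restore the most recent step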
- """ - float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) - - def make_to_dtype_fn(dtype_spec): - def to_dtype(tensor): - if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes: - # Convert all float tensors to the same dtype - return tensor.astype(dtype_specs) - elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'): - return tensor.astype(dtype_spec.dtype) - return tensor - return to_dtype - - def make_shard_fn(partition_spec, dtype_spec=None): - jax_shard_function = pjit( - make_to_dtype_fn(dtype_spec), - in_shardings=None, - out_shardings=partition_spec + allowed_keys = ("dp", "fsdp", "mp") + if set(keys) != set(allowed_keys): + raise ValueError( + f"`axis_dims` must contain exactly the following keys: {allowed_keys}." ) - def shard_fn(tensor): - return jax_shard_function(tensor).block_until_ready() - return shard_fn - - def make_gather_fn(partition_spec, dtype_spec=None): - jax_gather_fn = pjit( - make_to_dtype_fn(dtype_spec), - in_shardings=partition_spec, - out_shardings=None + for v in dims: + if type(v) != int: + raise ValueError("All values in `axis_dims` must be integers or `-1`.") + if len(np.where(np.array(dims) == -1)[0]) > 1: + raise ValueError("At most one axis dimension can be `-1`.") + + n_devices = local_device_count() + + fixed_prod = np.prod([v for v in dims if v != -1]) + reminder = n_devices % fixed_prod + if fixed_prod > n_devices: + raise ValueError( + f"The product of the specified axis dimensions cannot be greater than {n_devices}, " + f"the number of available devices." ) - def gather_fn(tensor): - return jax.device_get(jax_gather_fn(tensor)) - return gather_fn - - if dtype_specs is None or dtype_specs in float_dtypes: - shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs) - gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs) - else: - shard_fns = jax.tree_util.tree_map( - make_shard_fn, partition_specs, dtype_specs - ) - gather_fns = jax.tree_util.tree_map( - make_gather_fn, partition_specs, dtype_specs + if reminder != 0: + raise ValueError( + "The product of the axis dimensions must divide the number of available devices. " + f"However, {n_devices} were found, and {fixed_prod} to be the product of the specified axis " + f"dimensions." ) - return shard_fns, gather_fns + dims = tuple([dims[np.where(np.array(keys) == k)[0][0]] for k in allowed_keys]) + mesh_shape = np.arange(n_devices).reshape(dims).shape + physical_mesh = create_device_mesh(mesh_shape) + return Mesh(physical_mesh, allowed_keys) -def get_jax_mesh(axis_dims, names): - if axis_dims.startswith('!'): - # Allow splitting a physical mesh axis if needed - mesh_axis_splitting = True - axis_dims = axis_dims[1:] - else: - mesh_axis_splitting = False - if ':' in axis_dims: - dims = [] - dim_names = [] - for axis in axis_dims.split(','): - name, dim = axis.split(':') - assert name in names - dims.append(int(dim)) - dim_names.append(name) - assert(set(dim_names) == set(names)) - else: - dims = [int(x) for x in axis_dims.split(',')] - dim_names = names - assert len(dims) == len(names) - mesh_shape = np.arange(jax.device_count()).reshape(dims).shape - if mesh_axis_splitting: - physical_mesh = np.array(jax.devices()).reshape(mesh_shape) - else: - physical_mesh = mesh_utils.create_device_mesh(mesh_shape) - return Mesh(physical_mesh, dim_names) +def names_in_current_mesh(*names) -> bool: + """ + Check if the axis names in the current mesh contain the names provided. + Parameters + ---------- + names: List[str] + Provided names. 
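A sketch of building a device mesh with the get_mesh utility defined above, assuming a host whose local device count is compatible with the fixed axis sizes:

    from fortuna.utils.mesh import get_mesh

    # all of "dp", "fsdp" and "mp" must be given; the single -1 entry absorbs the remaining devices
    mesh = get_mesh({"dp": 1, "fsdp": -1, "mp": 1})

    with mesh:
        ...  # computations compiled here see the ("dp", "fsdp", "mp") mesh axes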
-def names_in_current_mesh(*names): - """ Check if current mesh axes contain these names. """ + Returns + ------- + bool: + Whether the list of provided names is contained in the list of axis names from the current mesh. + """ mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names return set(names) <= set(mesh_axis_names) -def get_names_from_partition_spec(partition_specs): - """ Return axis names from partition specs. """ - names = set() - if isinstance(partition_specs, dict): - partition_specs = partition_specs.values() - for item in partition_specs: - if item is None: - continue - elif isinstance(item, str): - names.add(item) - else: - names.update(get_names_from_partition_spec(item)) +def with_conditional_sharding_constraint(x, partition_specs): + """ - return list(names) + Parameters + ---------- + x + partition_specs + Returns + ------- -def with_sharding_constraint(x, partition_specs): + """ """ A smarter version of with_sharding_constraint that only applies the constraint if the current mesh contains the axes in the partition specs. """ axis_names = get_names_from_partition_spec(partition_specs) if names_in_current_mesh(*axis_names): - x = _with_sharding_constraint(x, partition_specs) + x = with_sharding_constraint(x, partition_specs) return x - - -def wrap_function_with_rng(rng): - """ To be used as decorator, automatically bookkeep a RNG for the wrapped function. """ - def wrap_function(function): - def wrapped(*args, **kwargs): - nonlocal rng - rng, split_rng = jax.random.split(rng) - return function(split_rng, *args, **kwargs) - return wrapped - return wrap_function - - -def init_rng(seed): - global jax_utils_rng - jax_utils_rng = JaxRNG.from_seed(seed) - - -def next_rng(*args, **kwargs): - global jax_utils_rng - return jax_utils_rng(*args, **kwargs) - - -def get_metrics(metrics, unreplicate=False, stack=False): - if unreplicate: - metrics = flax.jax_utils.unreplicate(metrics) - metrics = jax.device_get(metrics) - if stack: - return jax.tree_map(lambda *args: np.stack(args), *metrics) - else: - return {key: float(val) for key, val in metrics.items()} - - -def mse_loss(val, target, valid=None): - if valid is None: - valid = jnp.ones((*target.shape[:2], 1)) - valid = valid.astype(jnp.float32) - loss = jnp.mean( - jnp.where( - valid > 0.0, - jnp.square(val - target), - 0.0 - ) - ) - return loss - - -def cross_entropy_loss_and_accuracy(logits, tokens, valid=None): - if valid is None: - valid = jnp.ones(tokens.shape[:2]) - valid = valid.astype(jnp.float32) - valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10) - logits = logits.astype(jnp.float32) # for numerical stability - token_log_prob = jnp.squeeze( - jnp.take_along_axis( - jax.nn.log_softmax(logits, axis=-1), - jnp.expand_dims(tokens, -1), - axis=-1, - ), - -1, - ) - token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0)) - loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length) - correct = jnp.where( - valid > 0.0, - jnp.argmax(logits, axis=-1) == tokens, - jnp.array(False) - ) - accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length) - return loss, accuracy - - -def global_norm(tree): - """ Return the global L2 norm of a pytree. 
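A sketch of the conditional sharding constraint defined above. The mesh layout and array shape are assumptions; the point is that the constraint is applied only when the named axes exist in the mesh currently in scope, so the same code can run sharded or unsharded:

    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec

    from fortuna.utils.mesh import get_mesh, with_conditional_sharding_constraint


    def forward(batch):
        # shard the batch over the data-parallel axes, but only if they are in scope
        batch = with_conditional_sharding_constraint(batch, PartitionSpec(("dp", "fsdp")))
        return batch.sum()


    mesh = get_mesh({"dp": -1, "fsdp": 1, "mp": 1})
    x = jnp.ones((8, 16))  # leading dimension assumed divisible by the device count

    with mesh:
        out = jax.jit(forward)(x)      # "dp" and "fsdp" are in scope: constraint applied

    out_no_mesh = jax.jit(forward)(x)  # no mesh in scope: the constraint is skipped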
""" - squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree) - flattened, _ = jax.flatten_util.ravel_pytree(squared) - return jnp.sqrt(jnp.sum(flattened)) - - -def average_metrics(metrics): - return jax.tree_map( - lambda *args: jnp.mean(jnp.stack(args)), - *metrics - ) - - -def get_float_dtype_by_name(dtype): - return { - 'bf16': jnp.bfloat16, - 'bfloat16': jnp.bfloat16, - 'fp16': jnp.float16, - 'float16': jnp.float16, - 'fp32': jnp.float32, - 'float32': jnp.float32, - 'fp64': jnp.float64, - 'float64': jnp.float64, - }[dtype] - - -def float_tensor_to_dtype(tensor, dtype): - if dtype is None or dtype == '': - return tensor - if isinstance(dtype, str): - dtype = get_float_dtype_by_name(dtype) - float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) - if getattr(tensor, 'dtype', None) in float_dtypes: - tensor = tensor.astype(dtype) - return tensor - - -def float_to_dtype(tree, dtype): - return jax.tree_util.tree_map( - partial(float_tensor_to_dtype, dtype=dtype), tree - ) - - -def get_gradient_checkpoint_policy(name): - return { - 'everything_saveable': jax.checkpoint_policies.everything_saveable, - 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, - 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, - 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, - }[name] - - -def tree_path_to_string(path, sep=None): - keys = [] - for key in path: - if isinstance(key, jax.tree_util.SequenceKey): - keys.append(str(key.idx)) - elif isinstance(key, jax.tree_util.DictKey): - keys.append(str(key.key)) - elif isinstance(key, jax.tree_util.GetAttrKey): - keys.append(str(key.name)) - elif isinstance(key, jax.tree_util.FlattenedIndexKey): - keys.append(str(key.key)) - else: - keys.append(str(key)) - if sep is None: - return tuple(keys) - return sep.join(keys) - - -def flatten_tree(xs, is_leaf=None, sep=None): - flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf) - output = {} - for key, val in flattened: - output[tree_path_to_string(key, sep=sep)] = val - return output - - -def named_tree_map(f, tree, *rest, is_leaf=None, sep=None): - """ An extended version of jax.tree_util.tree_map, where the mapped function - f takes both the name (path) and the tree leaf as input. - """ - return jax.tree_util.tree_map_with_path( - lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r), - tree, *rest, - is_leaf=is_leaf - ) - - -def match_partition_rules(rules, params): - """ Returns a pytree of PartitionSpec according to rules. Supports handling - Flax TrainState and Optax optimizer state. - """ - def get_partition_spec(name, leaf): - if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1: - """ Don't partition scalar values. """ - return PS() - for rule, ps in rules: - if re.search(rule, name) is not None: - return ps - raise ValueError(f'Partition rule not found for param: {name}') - return named_tree_map(get_partition_spec, params, sep='/') - - -def get_weight_decay_mask(exclusions): - """ Return a weight decay mask function that computes the pytree masks - according to the given exclusion rules. - """ - def decay(name, _): - for rule in exclusions: - if re.search(rule, name) is not None: - return False - return True - - def weight_decay_mask(params): - return named_tree_map(decay, params, sep='/') - - return weight_decay_mask - - -def tree_apply(fns, tree): - """ Apply a pytree of functions to the pytree. 
""" - return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree) - diff --git a/fortuna/utils/nested_dicts.py b/fortuna/utils/nested_dicts.py index f778036a..754255ce 100644 --- a/fortuna/utils/nested_dicts.py +++ b/fortuna/utils/nested_dicts.py @@ -8,6 +8,13 @@ ) from flax.core import FrozenDict +from jax.tree_util import ( + DictKey, + FlattenedIndexKey, + GetAttrKey, + SequenceKey, + tree_map_with_path, +) from fortuna.typing import AnyKey @@ -213,3 +220,39 @@ def nested_update( else: updated_mapping[k] = v return updated_mapping + + +def path_to_string( + path: Tuple[Union[DictKey, SequenceKey, GetAttrKey, FlattenedIndexKey, AnyKey]], + separator: str = None, +) -> Union[str, Tuple[str]]: + """ + Transform a sequence of keys into a string. + + Parameters + ---------- + path: Tuple[Union[DictKey, SequenceKey, GetAttrKey, FlattenedIndexKey, AnyKey]] + A sequence of keys. + separator: str + A string to interpret as separator. + + Returns + ------- + Union[str, Tuple[str]] + A sequence of keys. + """ + keys = [] + for key in path: + if isinstance(key, SequenceKey): + keys.append(str(key.idx)) + elif isinstance(key, DictKey): + keys.append(str(key.key)) + elif isinstance(key, GetAttrKey): + keys.append(str(key.name)) + elif isinstance(key, FlattenedIndexKey): + keys.append(str(key.key)) + else: + keys.append(str(key)) + if separator is None: + return tuple(keys) + return separator.join(keys) diff --git a/fortuna/utils/partition.py b/fortuna/utils/partition.py new file mode 100644 index 00000000..a181b599 --- /dev/null +++ b/fortuna/utils/partition.py @@ -0,0 +1,73 @@ +import re +from typing import ( + Dict, + Tuple, +) + +import jax.numpy as jnp +from jax.sharding import ( + PartitionSpec, + Sharding, +) +from jax.tree_util import tree_map_with_path +import numpy as np +from optax._src.base import PyTree + +from fortuna.utils.nested_dicts import path_to_string + + +def named_tree_map(f, tree, *rest, is_leaf=None, separator=None): + return tree_map_with_path( + lambda string_path, x, *r: f( + path_to_string(string_path, separator=separator), x, *r + ), + tree, + *rest, + is_leaf=is_leaf, + ) + + +def match_partition_specs( + partition_specs: Dict[str, PartitionSpec], tree: PyTree +) -> PyTree: + """ + Match partition specifics to a tree structure. + + Parameters + ---------- + partition_specs: Dict[str, Tuple[str]] + tree: PyTree + + Returns + ------- + PyTree + A tree of partition specifics. 
+ """ + + def get_partition_spec(path, shape_leaf): + if len(shape_leaf.shape) == 0 or np.prod(shape_leaf.shape) == 1: + # do not partition scalar values + return PartitionSpec() + for rule, ps in partition_specs.items(): + if re.search(rule, path) is not None: + return ps + # raise ValueError(f"A partition rule for the following path was not found: `{path}`") + return PartitionSpec() + + return named_tree_map(get_partition_spec, tree, separator="/") + + +def get_names_from_partition_spec(partition_specs): + """Return axis names from partition specs.""" + names = set() + if isinstance(partition_specs, dict): + partition_specs = partition_specs.values() + for item in partition_specs: + if item is None: + continue + elif isinstance(item, str): + names.add(item) + else: + names.update(get_names_from_partition_spec(item)) + + return list(names) diff --git a/fortuna/utils/port.py b/fortuna/utils/port.py new file mode 100644 index 00000000..1bd99bef --- /dev/null +++ b/fortuna/utils/port.py @@ -0,0 +1,6 @@ +import socket + + +def is_port_in_use(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 diff --git a/poetry.lock b/poetry.lock index 5450bd50..cc9ffade 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "absl-py" version = "1.4.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -15,6 +16,7 @@ files = [ name = "absolufy-imports" version = "0.3.1" description = "A tool to automatically replace relative imports with absolute ones." +category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -26,6 +28,7 @@ files = [ name = "aiohttp" version = "3.8.4" description = "Async http client/server framework (asyncio)" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -134,6 +137,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -148,6 +152,7 @@ frozenlist = ">=1.1.0" name = "alabaster" version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -159,6 +164,7 @@ files = [ name = "antlr4-python3-runtime" version = "4.9.3" description = "ANTLR 4.9.3 runtime for Python 3.7" +category = "main" optional = true python-versions = "*" files = [ @@ -169,6 +175,7 @@ files = [ name = "anyio" version = "3.6.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" optional = true python-versions = ">=3.6.2" files = [ @@ -189,6 +196,7 @@ trio = ["trio (>=0.16,<0.22)"] name = "appdirs" version = "1.4.4" description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
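A sketch of the match_partition_specs utility added in fortuna/utils/partition.py above. The rule keys are regular expressions matched with re.search against the "/"-joined parameter path produced by named_tree_map and path_to_string; the toy parameter tree and rules below are illustrative assumptions:

    from jax import eval_shape
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec

    from fortuna.utils.partition import match_partition_specs

    params = {"encoder": {"kernel": jnp.ones((512, 512)), "bias": jnp.zeros((512,))}}
    rules = {
        "encoder/kernel": PartitionSpec("fsdp", "mp"),  # shard the matrix over fsdp and mp
        "bias": PartitionSpec(),                        # keep biases replicated
    }

    # leaf shapes are all that is inspected, so shape structs work as well as concrete arrays
    specs = match_partition_specs(rules, eval_shape(lambda: params))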
+category = "dev" optional = false python-versions = "*" files = [ @@ -200,6 +208,7 @@ files = [ name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "main" optional = false python-versions = "*" files = [ @@ -211,6 +220,7 @@ files = [ name = "argon2-cffi" version = "21.3.0" description = "The secure Argon2 password hashing algorithm." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -230,6 +240,7 @@ tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -267,6 +278,7 @@ tests = ["pytest"] name = "array-record" version = "0.2.0" description = "A file format that achieves a new frontier of IO efficiency" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -283,6 +295,7 @@ etils = {version = "*", extras = ["epath"]} name = "arrow" version = "1.2.3" description = "Better dates & times for Python" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -297,6 +310,7 @@ python-dateutil = ">=2.7.0" name = "asttokens" version = "2.2.1" description = "Annotate AST trees with source code positions" +category = "main" optional = false python-versions = "*" files = [ @@ -314,6 +328,7 @@ test = ["astroid", "pytest"] name = "astunparse" version = "1.6.3" description = "An AST unparser for Python" +category = "main" optional = false python-versions = "*" files = [ @@ -329,6 +344,7 @@ wheel = ">=0.23.0,<1.0" name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -340,6 +356,7 @@ files = [ name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -358,6 +375,7 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "babel" version = "2.12.1" description = "Internationalization utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -372,6 +390,7 @@ pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" +category = "main" optional = false python-versions = "*" files = [ @@ -383,6 +402,7 @@ files = [ name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" +category = "main" optional = false python-versions = ">=3.6.0" files = [ @@ -401,6 +421,7 @@ lxml = ["lxml"] name = "bleach" version = "6.0.0" description = "An easy safelist-based HTML-sanitizing tool." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -419,6 +440,7 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] name = "boto3" version = "1.26.145" description = "The AWS SDK for Python" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -438,6 +460,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.29.145" description = "Low-level, data-driven core of boto 3." +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -457,6 +480,7 @@ crt = ["awscrt (==0.16.9)"] name = "cached-property" version = "1.5.2" description = "A decorator for caching properties in classes." 
+category = "main" optional = false python-versions = "*" files = [ @@ -468,6 +492,7 @@ files = [ name = "cachetools" version = "5.3.0" description = "Extensible memoizing collections and decorators" +category = "main" optional = false python-versions = "~=3.7" files = [ @@ -479,6 +504,7 @@ files = [ name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -490,6 +516,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." +category = "main" optional = false python-versions = "*" files = [ @@ -566,6 +593,7 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." +category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -577,6 +605,7 @@ files = [ name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -661,6 +690,7 @@ files = [ name = "chex" version = "0.1.7" description = "Chex: Testing made fun, in JAX!" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -681,6 +711,7 @@ typing-extensions = {version = ">=4.2.0", markers = "python_version < \"3.11\""} name = "click" version = "8.1.3" description = "Composable command line interface toolkit" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -695,6 +726,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "cloudpickle" version = "2.2.1" description = "Extended pickling support for Python objects" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -706,6 +738,7 @@ files = [ name = "codespell" version = "2.2.4" description = "Codespell" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -723,6 +756,7 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -734,6 +768,7 @@ files = [ name = "comm" version = "0.1.3" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -753,6 +788,7 @@ typing = ["mypy (>=0.990)"] name = "contextlib2" version = "21.6.0" description = "Backports and enhancements for the contextlib module" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -764,6 +800,7 @@ files = [ name = "contourpy" version = "1.0.7" description = "Python library for calculating contours of 2D quadrilateral grids" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -838,6 +875,7 @@ test-no-images = ["pytest"] name = "coverage" version = "7.2.5" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -904,6 +942,7 @@ toml = ["tomli"] name = "cryptography" version = "41.0.0" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -945,6 +984,7 @@ test-randomorder = ["pytest-randomly"] name = "cycler" version = "0.11.0" description = "Composable style cycles" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -956,6 +996,7 @@ files = [ name = "datasets" version = "2.12.0" description = "HuggingFace community-driven open-source library of datasets" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -999,6 +1040,7 @@ vision = ["Pillow (>=6.2.1)"] name = "debugpy" version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1026,6 +1068,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1037,6 +1080,7 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -1048,6 +1092,7 @@ files = [ name = "dill" version = "0.3.6" description = "serialize all of python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1062,6 +1107,7 @@ graph = ["objgraph (>=1.7.2)"] name = "distlib" version = "0.3.6" description = "Distribution utilities" +category = "dev" optional = false python-versions = "*" files = [ @@ -1073,6 +1119,7 @@ files = [ name = "dm-tree" version = "0.1.8" description = "Tree is a library for working with nested data structures." +category = "main" optional = false python-versions = "*" files = [ @@ -1121,6 +1168,7 @@ files = [ name = "docutils" version = "0.19" description = "Docutils -- Python Documentation Utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1132,6 +1180,7 @@ files = [ name = "et-xmlfile" version = "1.1.0" description = "An implementation of lxml.xmlfile for the standard library" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1143,6 +1192,7 @@ files = [ name = "etils" version = "1.2.0" description = "Collection of common python utils" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1177,6 +1227,7 @@ lazy-imports = ["etils[ecolab]"] name = "exceptiongroup" version = "1.1.1" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1191,6 +1242,7 @@ test = ["pytest (>=6)"] name = "executing" version = "1.2.0" description = "Get the currently executing AST node of a frame, and other information" +category = "main" optional = false python-versions = "*" files = [ @@ -1205,6 +1257,7 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] name = "fastjsonschema" version = "2.16.3" description = "Fastest Python implementation of JSON schema" +category = "main" optional = false python-versions = "*" files = [ @@ -1219,6 +1272,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.12.0" description = "A platform independent file lock." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1234,6 +1288,7 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "p name = "flatbuffers" version = "23.5.9" description = "The FlatBuffers serialization format for Python" +category = "main" optional = false python-versions = "*" files = [ @@ -1243,13 +1298,14 @@ files = [ [[package]] name = "flax" -version = "0.6.10" +version = "0.6.11" description = "Flax: A neural network library for JAX designed for flexibility" +category = "main" optional = false python-versions = "*" files = [ - {file = "flax-0.6.10-py3-none-any.whl", hash = "sha256:8dccc7b84b00ff6f59a36dc0e79f5919498cfeb009a41f8c07f68bf2513198db"}, - {file = "flax-0.6.10.tar.gz", hash = "sha256:e2174a0df7bb4921f29b2cbd33f55ddf6eed161d6df61809fe374a25e473fb2f"}, + {file = "flax-0.6.11-py3-none-any.whl", hash = "sha256:3ce6843ed47a35abfd86a7eb47db3934a156d08d6513dc8dcb58d461b0dd6f39"}, + {file = "flax-0.6.11.tar.gz", hash = "sha256:ecedf179ceb16c0b511982a293834bb13086168dce1dff697ac083efa818fc72"}, ] [package.dependencies] @@ -1271,6 +1327,7 @@ testing = ["atari-py (==0.2.5)", "clu", "einops", "gym (==0.18.3)", "jaxlib", "j name = "fonttools" version = "4.39.4" description = "Tools to manipulate font files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1296,6 +1353,7 @@ woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] name = "fqdn" version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" +category = "main" optional = true python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" files = [ @@ -1307,6 +1365,7 @@ files = [ name = "frozendict" version = "2.3.8" description = "A simple immutable dictionary" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1353,6 +1412,7 @@ files = [ name = "frozenlist" version = "1.3.3" description = "A list-like structure which implements collections.abc.MutableSequence" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1436,6 +1496,7 @@ files = [ name = "fsspec" version = "2023.5.0" description = "File-system specification" +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -1475,6 +1536,7 @@ tqdm = ["tqdm"] name = "gast" version = "0.4.0" description = "Python AST that abstracts the underlying Python version" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1486,6 +1548,7 @@ files = [ name = "google-auth" version = "2.18.0" description = "Google Authentication Library" +category = "main" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" files = [ @@ -1511,6 +1574,7 @@ requests = ["requests (>=2.20.0,<3.0.0dev)"] name = "google-auth-oauthlib" version = "1.0.0" description = "Google Authentication Library" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -1529,6 +1593,7 @@ tool = ["click (>=6.0.0)"] name = "google-pasta" version = "0.2.0" description = "pasta is an AST-based Python refactoring library" +category = "main" optional = false python-versions = "*" files = [ @@ -1544,6 +1609,7 @@ six = "*" name = "googleapis-common-protos" version = "1.59.0" description = "Common protobufs used in Google APIs" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1561,6 +1627,7 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"] name = "greenlet" version = 
"2.0.2" description = "Lightweight in-process concurrent programming" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*" files = [ @@ -1634,6 +1701,7 @@ test = ["objgraph", "psutil"] name = "grpcio" version = "1.54.0" description = "HTTP/2-based RPC framework" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1691,6 +1759,7 @@ protobuf = ["grpcio-tools (>=1.54.0)"] name = "h5py" version = "3.8.0" description = "Read and write HDF5 files from Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1728,6 +1797,7 @@ numpy = ">=1.14.5" name = "html5lib" version = "1.1" description = "HTML parser based on the WHATWG HTML specification" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -1749,6 +1819,7 @@ lxml = ["lxml"] name = "huggingface-hub" version = "0.14.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -1780,6 +1851,7 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t name = "hydra-core" version = "1.3.2" description = "A framework for elegantly configuring complex applications" +category = "main" optional = true python-versions = "*" files = [ @@ -1788,7 +1860,7 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = "==4.9.*" +antlr4-python3-runtime = ">=4.9.0,<4.10.0" importlib-resources = {version = "*", markers = "python_version < \"3.9\""} omegaconf = ">=2.2,<2.4" packaging = "*" @@ -1797,6 +1869,7 @@ packaging = "*" name = "identify" version = "2.5.24" description = "File identification library for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1811,6 +1884,7 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1822,6 +1896,7 @@ files = [ name = "imagesize" version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1833,6 +1908,7 @@ files = [ name = "importlib-metadata" version = "4.13.0" description = "Read metadata from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1852,6 +1928,7 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.12.0" description = "Read resources from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1870,6 +1947,7 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1881,6 +1959,7 @@ files = [ name = "ipykernel" version = "6.23.0" description = "IPython Kernel for Jupyter" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1894,7 +1973,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -1914,6 +1993,7 @@ test = ["flaky", 
"ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipython" version = "8.12.2" description = "IPython: Productive Interactive Computing" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1953,6 +2033,7 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa name = "ipython-genutils" version = "0.2.0" description = "Vestigial utilities from IPython" +category = "main" optional = true python-versions = "*" files = [ @@ -1964,6 +2045,7 @@ files = [ name = "ipywidgets" version = "8.0.6" description = "Jupyter interactive widgets" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1985,6 +2067,7 @@ test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] name = "isoduration" version = "20.11.0" description = "Operations with ISO 8601 durations" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1997,15 +2080,17 @@ arrow = ">=0.15.0" [[package]] name = "jax" -version = "0.4.10" +version = "0.4.13" description = "Differentiate, compile, and transform Numpy code." +category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jax-0.4.10.tar.gz", hash = "sha256:1bf0f2720f778f2937301a16a4d5cd3497f13a4d6c970c24a88918a81816a888"}, + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, ] [package.dependencies] +importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} ml_dtypes = ">=0.1.0" numpy = ">=1.21" opt_einsum = "*" @@ -2013,37 +2098,40 @@ scipy = ">=1.7" [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.9)"] -cpu = ["jaxlib (==0.4.10)"] -cuda = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-cudnn82 = ["jaxlib (==0.4.10+cuda11.cudnn82)"] -cuda11-cudnn86 = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-local = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-pip = ["jaxlib (==0.4.10+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.6)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] -cuda12-local = ["jaxlib (==0.4.10+cuda12.cudnn88)"] -cuda12-pip = ["jaxlib (==0.4.10+cuda12.cudnn88)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] -minimum-jaxlib = ["jaxlib (==0.4.7)"] -tpu = ["jaxlib (==0.4.10)", "libtpu-nightly (==0.1.dev20230511)", "requests"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib 
(==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] [[package]] name = "jaxlib" -version = "0.4.10" +version = "0.4.13" description = "XLA library for JAX" +category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jaxlib-0.4.10-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:0814c478382e82b0f90aacd820fb898c4a9caa705a1f515c5fd0928198c814f3"}, - {file = "jaxlib-0.4.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:97e7b2b0f32debbb011556cb2dc82cdfb0087b618e302f92319475727408a64e"}, - {file = "jaxlib-0.4.10-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:f46edb93332285ab6f57b2843869183cbd495b4f35bea0fba25a3766a7429306"}, - {file = "jaxlib-0.4.10-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1b705e495149945defe478781865d403bd3994c11e326829aea7aafda0dfa639"}, - {file = "jaxlib-0.4.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:377757745d5e2097fccce71c31292973d544a36329b7ed85bf9c41837e107f74"}, - {file = "jaxlib-0.4.10-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a6349c98c3ffd879b390a3532390e8e49f084aa523c1553aa5c21374ca8b4ea9"}, - {file = "jaxlib-0.4.10-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:7dc9c89b2b07cf8c576d5fca433181f324fed52e51db60873d2b6d3e496588e2"}, - {file = "jaxlib-0.4.10-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:62f3d2bad0476bb6728d1be813894cf3421a3d31706a0208b1f57eec86d310d5"}, - {file = "jaxlib-0.4.10-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:6ea7ad6b520732994e25429768d6bd731d55c59c75ef6f9faa2f59e419fb0ada"}, - {file = "jaxlib-0.4.10-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:749a1135a452db1afb4e5de7770fc5dafebb310c35d9db077ed925fcab028471"}, - {file = "jaxlib-0.4.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c966b13467c41ff44ba1b3b7cdceb37a76a75f0420f454a8a51543f8bbaabe4a"}, - {file = "jaxlib-0.4.10-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:07557fcf1e4c7c60bbb48c4f4f426909fcf610a7bfa56cbb139719ba3900722d"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:532ebc4fb11386282ad63b83941d4557f4038c1144acf026f1f8565f64c7e9c0"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a259bb35429bfbd3b76e43019dfc8f7d6ea94bb217400b78f7d0824ce07a58ac"}, + {file = "jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:ea1bc9811ef7d73a15e3213115e88fe7f5d14b59d95027bea9fccc98e5a14af8"}, + {file = "jaxlib-0.4.13-cp310-cp310-win_amd64.whl", hash = "sha256:fde66a93e9be89d99e5792f677ed8e319667d6b2396865b1c52c1312844c47f9"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:49690fcdd26560515fd15399fc3a44777e0bfc5db5c48fe76ff7bc7228e8b2fb"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4e9e34e5d8a6556f62fead14aee0b1614c2c6296f0078d8e6139d6aff109649"}, + {file = "jaxlib-0.4.13-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:8000c0d15c107328e8f7b7b3ac91dd822f5c287a80231882b620503ed141fa89"}, + {file = "jaxlib-0.4.13-cp311-cp311-win_amd64.whl", hash = "sha256:19ae4c316b17a49342432c69f7f89f190b975333f3f9e9e175f686a651bc7347"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:522635d5e159401a386c79f1236c218c1f68fbb4ca6648115c3ad3c2c3f518ab"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:411334d903df07dc1ace8d52fc53c17f6bc1d55aff7f6e0e5cf61ec149f758a0"}, + {file = "jaxlib-0.4.13-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:839173b2e9593f5e9a6d3c42852cd15070fe80a939246efbb5cf40eec815de89"}, + {file = 
"jaxlib-0.4.13-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:c230ef85712e608d0f048869766a5a63afeb2e72309943db0df9f959ab17307f"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d19c05c15f962e098d49b45e2758aacf19330d192ec5395f9ef136f62db90edc"}, + {file = "jaxlib-0.4.13-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:b5c0a9737efd95fe18fd7715ce30dfce476546705ea8934aad6731777a9631a5"}, + {file = "jaxlib-0.4.13-cp39-cp39-win_amd64.whl", hash = "sha256:bebb4cf001f180dc431f9604daf930c2d9cc778e4dda26f401ac939b7bac912e"}, ] [package.dependencies] @@ -2051,10 +2139,15 @@ ml-dtypes = ">=0.1.0" numpy = ">=1.21" scipy = ">=1.7" +[package.extras] +cuda11-pip = ["nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-pip = ["nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] + [[package]] name = "jedi" version = "0.18.2" description = "An autocompletion tool for Python that can be used for text editors." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2074,6 +2167,7 @@ testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2091,6 +2185,7 @@ i18n = ["Babel (>=2.7)"] name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2102,6 +2197,7 @@ files = [ name = "joblib" version = "1.2.0" description = "Lightweight pipelining with Python functions" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2113,6 +2209,7 @@ files = [ name = "jsonpointer" version = "2.3" description = "Identify specific nodes in a JSON document (RFC 6901)" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2124,6 +2221,7 @@ files = [ name = "jsonschema" version = "4.17.3" description = "An implementation of JSON Schema validation for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2153,6 +2251,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jupyter" version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." +category = "main" optional = true python-versions = "*" files = [ @@ -2173,6 +2272,7 @@ qtconsole = "*" name = "jupyter-cache" version = "0.6.1" description = "A defined interface for working with a cache of jupyter notebooks." 
+category = "dev" optional = false python-versions = "~=3.8" files = [ @@ -2200,6 +2300,7 @@ testing = ["coverage", "ipykernel", "jupytext", "matplotlib", "nbdime", "nbforma name = "jupyter-client" version = "8.2.0" description = "Jupyter protocol implementation and client libraries" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2209,7 +2310,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -2223,6 +2324,7 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-console" version = "6.6.3" description = "Jupyter terminal console" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2234,7 +2336,7 @@ files = [ ipykernel = ">=6.14" ipython = "*" jupyter-client = ">=7.0.0" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" prompt-toolkit = ">=3.0.30" pygments = "*" pyzmq = ">=17" @@ -2247,6 +2349,7 @@ test = ["flaky", "pexpect", "pytest"] name = "jupyter-core" version = "5.3.0" description = "Jupyter core package. A base package on which Jupyter projects rely." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2267,6 +2370,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-events" version = "0.6.3" description = "Jupyter Event System library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2291,6 +2395,7 @@ test = ["click", "coverage", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>= name = "jupyter-server" version = "2.5.0" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -2303,7 +2408,7 @@ anyio = ">=3.1.0" argon2-cffi = "*" jinja2 = "*" jupyter-client = ">=7.4.4" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" jupyter-events = ">=0.4.0" jupyter-server-terminals = "*" nbconvert = ">=6.4.4" @@ -2326,6 +2431,7 @@ test = ["ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", " name = "jupyter-server-terminals" version = "0.4.4" description = "A Jupyter Server Extension Providing Terminals." +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -2345,6 +2451,7 @@ test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2356,6 +2463,7 @@ files = [ name = "jupyterlab-widgets" version = "3.0.7" description = "Jupyter interactive widgets for JupyterLab" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2367,6 +2475,7 @@ files = [ name = "jupytext" version = "1.14.5" description = "Jupyter notebooks as Markdown documents, Julia, Python or R scripts" +category = "dev" optional = false python-versions = "~=3.6" files = [ @@ -2389,6 +2498,7 @@ toml = ["toml"] name = "keras" version = "2.12.0" description = "Deep learning for humans." 
+category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2399,6 +2509,7 @@ files = [ name = "kiwisolver" version = "1.4.4" description = "A fast implementation of the Cassowary constraint solver" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2476,6 +2587,7 @@ files = [ name = "libclang" version = "16.0.0" description = "Clang Python Bindings, mirrored from the official LLVM repo: https://github.com/llvm/llvm-project/tree/main/clang/bindings/python, to make the installation process easier." +category = "main" optional = false python-versions = "*" files = [ @@ -2493,6 +2605,7 @@ files = [ name = "lxml" version = "4.9.2" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" files = [ @@ -2585,6 +2698,7 @@ source = ["Cython (>=0.29.7)"] name = "markdown" version = "3.4.3" description = "Python implementation of John Gruber's Markdown." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2602,6 +2716,7 @@ testing = ["coverage", "pyyaml"] name = "markdown-it-py" version = "2.2.0" description = "Python port of markdown-it. Markdown parsing, done right!" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2626,6 +2741,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.2" description = "Safely add untrusted strings to HTML/XML markup." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2685,6 +2801,7 @@ files = [ name = "matplotlib" version = "3.7.1" description = "Python plotting package" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2747,6 +2864,7 @@ python-dateutil = ">=2.7" name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -2761,6 +2879,7 @@ traitlets = "*" name = "mdit-py-plugins" version = "0.3.5" description = "Collection of plugins for markdown-it-py" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2780,6 +2899,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2791,6 +2911,7 @@ files = [ name = "mistune" version = "2.0.5" description = "A sane Markdown parser with useful plugins and renderers" +category = "main" optional = true python-versions = "*" files = [ @@ -2802,6 +2923,7 @@ files = [ name = "ml-dtypes" version = "0.1.0" description = "" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2837,6 +2959,7 @@ dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] name = "msgpack" version = "1.0.5" description = "MessagePack serializer" +category = "main" optional = false python-versions = "*" files = [ @@ -2909,6 +3032,7 @@ files = [ name = "multidict" version = "6.0.4" description = "multidict implementation" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2992,6 +3116,7 @@ files = [ name = "multiprocess" version = "0.70.14" description = "better multiprocessing and multithreading in python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3018,6 +3143,7 @@ dill = ">=0.3.6" name = "multitasking" version = 
"0.0.11" description = "Non-blocking Python methods using decorators" +category = "dev" optional = false python-versions = "*" files = [ @@ -3029,6 +3155,7 @@ files = [ name = "myst-nb" version = "0.17.2" description = "A Jupyter Notebook Sphinx reader built on top of the MyST markdown parser." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3057,6 +3184,7 @@ testing = ["beautifulsoup4", "coverage (>=6.4,<8.0)", "ipykernel (>=5.5,<6.0)", name = "myst-parser" version = "0.18.1" description = "An extended commonmark compliant parser, with bridges to docutils & sphinx." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3083,6 +3211,7 @@ testing = ["beautifulsoup4", "coverage[toml]", "pytest (>=6,<7)", "pytest-cov", name = "nbclassic" version = "1.0.0" description = "Jupyter Notebook as a Jupyter Server extension." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3118,6 +3247,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-jupyter", "pytest-p name = "nbclient" version = "0.7.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3127,7 +3257,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" nbformat = ">=5.1" traitlets = ">=5.3" @@ -3140,6 +3270,7 @@ test = ["flaky", "ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "p name = "nbconvert" version = "7.4.0" description = "Converting Jupyter Notebooks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3178,6 +3309,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.8.0" description = "The Jupyter Notebook format" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3199,6 +3331,7 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] name = "nbsphinx" version = "0.8.12" description = "Jupyter Notebook Tools for Sphinx" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3218,6 +3351,7 @@ traitlets = ">=5" name = "nbsphinx-link" version = "1.3.0" description = "A sphinx extension for including notebook files outside sphinx source root" +category = "main" optional = true python-versions = "*" files = [ @@ -3233,6 +3367,7 @@ sphinx = ">=1.8" name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -3244,6 +3379,7 @@ files = [ name = "nodeenv" version = "1.7.0" description = "Node.js virtual environment builder" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -3258,6 +3394,7 @@ setuptools = "*" name = "notebook" version = "6.5.4" description = "A web-based notebook environment for interactive computing" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3292,6 +3429,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "requests", "requests-unixs name = "notebook-shim" version = "0.2.3" description = "A shim layer for notebook traits and config" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3309,6 +3447,7 @@ test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync" name = "numpy" version = "1.23.5" description = "NumPy is the fundamental package for 
array computing with Python." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3346,6 +3485,7 @@ files = [ name = "oauthlib" version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3362,6 +3502,7 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] name = "omegaconf" version = "2.3.0" description = "A flexible configuration library" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3370,13 +3511,14 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = "==4.9.*" +antlr4-python3-runtime = ">=4.9.0,<4.10.0" PyYAML = ">=5.1.0" [[package]] name = "openpyxl" version = "3.1.2" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3391,6 +3533,7 @@ et-xmlfile = "*" name = "opt-einsum" version = "3.3.0" description = "Optimizing numpys einsum function" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -3409,6 +3552,7 @@ tests = ["pytest", "pytest-cov", "pytest-pep8"] name = "optax" version = "0.1.5" description = "A gradient processing and optimisation library in JAX." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3427,6 +3571,7 @@ numpy = ">=1.18.0" name = "orbax-checkpoint" version = "0.2.2" description = "Orbax Checkpoint" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3455,6 +3600,7 @@ dev = ["flax", "pytest", "pytest-xdist"] name = "packaging" version = "23.1" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3466,6 +3612,7 @@ files = [ name = "pandas" version = "2.0.1" description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3532,6 +3679,7 @@ xml = ["lxml (>=4.6.3)"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3543,6 +3691,7 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3558,6 +3707,7 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathos" version = "0.3.0" description = "parallel graph management and execution in heterogeneous computing" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3575,6 +3725,7 @@ ppft = ">=1.7.6.6" name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." +category = "main" optional = false python-versions = "*" files = [ @@ -3589,6 +3740,7 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" +category = "main" optional = false python-versions = "*" files = [ @@ -3600,6 +3752,7 @@ files = [ name = "pillow" version = "9.5.0" description = "Python Imaging Library (Fork)" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3679,6 +3832,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." 
+category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3690,6 +3844,7 @@ files = [ name = "platformdirs" version = "3.5.1" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3705,6 +3860,7 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3720,6 +3876,7 @@ testing = ["pytest", "pytest-benchmark"] name = "pox" version = "0.3.2" description = "utilities for filesystem exploration and automated builds" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3731,6 +3888,7 @@ files = [ name = "ppft" version = "1.7.6.6" description = "distributed and parallel python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3745,6 +3903,7 @@ dill = ["dill (>=0.3.6)"] name = "pre-commit" version = "3.3.1" description = "A framework for managing and maintaining multi-language pre-commit hooks." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3763,6 +3922,7 @@ virtualenv = ">=20.10.0" name = "prometheus-client" version = "0.16.0" description = "Python client for the Prometheus monitoring system." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3777,6 +3937,7 @@ twisted = ["twisted"] name = "promise" version = "2.3" description = "Promises/A+ implementation for Python" +category = "dev" optional = false python-versions = "*" files = [ @@ -3793,6 +3954,7 @@ test = ["coveralls", "futures", "mock", "pytest (>=2.7.3)", "pytest-benchmark", name = "prompt-toolkit" version = "3.0.38" description = "Library for building powerful interactive command lines in Python" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3807,6 +3969,7 @@ wcwidth = "*" name = "protobuf" version = "3.20.3" description = "Protocol Buffers" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3838,6 +4001,7 @@ files = [ name = "protobuf3-to-dict" version = "0.1.5" description = "Ben Hodgson: A teeny Python library for creating Python dicts from protocol buffers and the reverse. Useful as an intermediate step before serialisation (e.g. to JSON). Kapor: upgrade it to PB3 and PY3, rename it to protobuf3-to-dict" +category = "main" optional = true python-versions = "*" files = [ @@ -3852,6 +4016,7 @@ six = "*" name = "psutil" version = "5.9.5" description = "Cross-platform lib for process and system monitoring in Python." 
+category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3878,6 +4043,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "main" optional = false python-versions = "*" files = [ @@ -3889,6 +4055,7 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" +category = "main" optional = false python-versions = "*" files = [ @@ -3903,6 +4070,7 @@ tests = ["pytest"] name = "pyarrow" version = "12.0.0" description = "Python library for Apache Arrow" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3940,6 +4108,7 @@ numpy = ">=1.16.6" name = "pyasn1" version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3951,6 +4120,7 @@ files = [ name = "pyasn1-modules" version = "0.3.0" description = "A collection of ASN.1-based protocols modules" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3965,6 +4135,7 @@ pyasn1 = ">=0.4.6,<0.6.0" name = "pycparser" version = "2.21" description = "C parser in Python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3976,6 +4147,7 @@ files = [ name = "pydata-sphinx-theme" version = "0.12.0" description = "Bootstrap-based Sphinx theme from the PyData community" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4000,6 +4172,7 @@ test = ["pydata-sphinx-theme[doc]", "pytest"] name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4014,6 +4187,7 @@ plugins = ["importlib-metadata"] name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -4028,6 +4202,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyrsistent" version = "0.19.3" description = "Persistent/Functional/Immutable data structures" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4064,6 +4239,7 @@ files = [ name = "pytest" version = "7.3.1" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4086,6 +4262,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "4.0.0" description = "Pytest plugin for measuring coverage." 
+category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4104,6 +4281,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -4118,6 +4296,7 @@ six = ">=1.5" name = "python-json-logger" version = "2.0.7" description = "A python library adding a json log formatter" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4129,6 +4308,7 @@ files = [ name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" files = [ @@ -4140,6 +4320,7 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" +category = "main" optional = false python-versions = "*" files = [ @@ -4163,6 +4344,7 @@ files = [ name = "pywinpty" version = "2.0.10" description = "Pseudo terminal support for Windows from Python." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4178,6 +4360,7 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4227,6 +4410,7 @@ files = [ name = "pyzmq" version = "25.0.2" description = "Python bindings for 0MQ" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4316,6 +4500,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "qtconsole" version = "5.4.3" description = "Jupyter Qt console" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -4342,6 +4527,7 @@ test = ["flaky", "pytest", "pytest-qt"] name = "qtpy" version = "2.3.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4359,6 +4545,7 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] name = "regex" version = "2023.5.5" description = "Alternative regular expression module, to replace re." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4456,6 +4643,7 @@ files = [ name = "requests" version = "2.31.0" description = "Python HTTP for Humans." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4477,6 +4665,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "requests-oauthlib" version = "1.3.1" description = "OAuthlib authentication support for Requests." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -4495,6 +4684,7 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"] name = "responses" version = "0.18.0" description = "A utility library for mocking out the `requests` Python library." 
+category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4513,6 +4703,7 @@ tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=4.6)", "pytest-cov", name = "rfc3339-validator" version = "0.1.4" description = "A pure python RFC3339 validator" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -4527,6 +4718,7 @@ six = "*" name = "rfc3986-validator" version = "0.1.1" description = "Pure python rfc3986 validator" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -4538,6 +4730,7 @@ files = [ name = "rich" version = "13.3.5" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -4557,6 +4750,7 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" +category = "main" optional = false python-versions = ">=3.6,<4" files = [ @@ -4571,6 +4765,7 @@ pyasn1 = ">=0.1.3" name = "s3transfer" version = "0.6.1" description = "An Amazon S3 Transfer Manager" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -4588,6 +4783,7 @@ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] name = "safetensors" version = "0.3.1" description = "Fast and Safe Tensor serialization" +category = "main" optional = true python-versions = "*" files = [ @@ -4648,6 +4844,7 @@ torch = ["torch (>=1.10)"] name = "sagemaker" version = "2.161.0" description = "Open source library for training and deploying models on Amazon SageMaker." +category = "main" optional = true python-versions = ">= 3.6" files = [ @@ -4683,6 +4880,7 @@ test = ["Jinja2 (==3.0.3)", "PyYAML (==6.0)", "apache-airflow (==2.6.0)", "apach name = "sagemaker-utils" version = "0.3.6" description = "Helper functions to work with SageMaker" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4700,6 +4898,7 @@ yaspin = "*" name = "schema" version = "0.7.5" description = "Simple data validation library" +category = "main" optional = true python-versions = "*" files = [ @@ -4714,6 +4913,7 @@ contextlib2 = ">=0.5.5" name = "scikit-learn" version = "1.2.2" description = "A set of python modules for machine learning and data mining" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4756,6 +4956,7 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy ( name = "scipy" version = "1.10.1" description = "Fundamental algorithms for scientific computing in Python" +category = "main" optional = false python-versions = "<3.12,>=3.8" files = [ @@ -4794,6 +4995,7 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo name = "send2trash" version = "1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" +category = "main" optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -4810,6 +5012,7 @@ win32 = ["pywin32"] name = "setuptools" version = "67.7.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4826,6 +5029,7 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = 
">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -4837,6 +5041,7 @@ files = [ name = "smdebug-rulesconfig" version = "1.0.1" description = "SMDebug RulesConfig" +category = "main" optional = true python-versions = ">=2.7" files = [ @@ -4848,6 +5053,7 @@ files = [ name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4859,6 +5065,7 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." +category = "main" optional = false python-versions = "*" files = [ @@ -4870,6 +5077,7 @@ files = [ name = "soupsieve" version = "2.4.1" description = "A modern CSS selector implementation for Beautiful Soup." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4881,6 +5089,7 @@ files = [ name = "sphinx" version = "5.3.0" description = "Python documentation generator" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4916,6 +5125,7 @@ test = ["cython", "html5lib", "pytest (>=4.6)", "typed_ast"] name = "sphinx-autodoc-typehints" version = "1.23.0" description = "Type hints (PEP 484) support for the Sphinx autodoc extension" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4935,6 +5145,7 @@ type-comment = ["typed-ast (>=1.5.4)"] name = "sphinx-gallery" version = "0.11.1" description = "A Sphinx extension that builds an HTML version of any Python script and puts it into an examples gallery." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4949,6 +5160,7 @@ sphinx = ">=3" name = "sphinxcontrib-applehelp" version = "1.0.4" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4964,6 +5176,7 @@ test = ["pytest"] name = "sphinxcontrib-devhelp" version = "1.0.2" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -4979,6 +5192,7 @@ test = ["pytest"] name = "sphinxcontrib-htmlhelp" version = "2.0.1" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4994,6 +5208,7 @@ test = ["html5lib", "pytest"] name = "sphinxcontrib-jsmath" version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5008,6 +5223,7 @@ test = ["flake8", "mypy", "pytest"] name = "sphinxcontrib-qthelp" version = "1.0.3" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5023,6 +5239,7 @@ test = ["pytest"] name = "sphinxcontrib-serializinghtml" version = "1.1.5" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." 
+category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5038,6 +5255,7 @@ test = ["pytest"] name = "sqlalchemy" version = "2.0.13" description = "Database Abstraction Library" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5085,7 +5303,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} typing-extensions = ">=4.2.0" [package.extras] @@ -5115,6 +5333,7 @@ sqlcipher = ["sqlcipher3-binary"] name = "stack-data" version = "0.6.2" description = "Extract data from python stack frames and tracebacks for informative displays" +category = "main" optional = false python-versions = "*" files = [ @@ -5134,6 +5353,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5148,6 +5368,7 @@ widechars = ["wcwidth"] name = "tblib" version = "1.7.0" description = "Traceback serialization library." +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -5159,6 +5380,7 @@ files = [ name = "tensorboard" version = "2.12.3" description = "TensorBoard lets you watch Tensors Flow" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5183,6 +5405,7 @@ wheel = ">=0.26" name = "tensorboard-data-server" version = "0.7.0" description = "Fast data loading for TensorBoard" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5195,6 +5418,7 @@ files = [ name = "tensorflow-cpu" version = "2.12.0" description = "TensorFlow is an open source machine learning framework for everyone." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5240,6 +5464,7 @@ wrapt = ">=1.11.0,<1.15" name = "tensorflow-datasets" version = "4.9.2" description = "tensorflow/datasets is a library of datasets ready to use with TensorFlow." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -5306,6 +5531,7 @@ youtube-vis = ["pycocotools"] name = "tensorflow-estimator" version = "2.12.0" description = "TensorFlow Estimator." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5316,6 +5542,7 @@ files = [ name = "tensorflow-io-gcs-filesystem" version = "0.32.0" description = "TensorFlow IO" +category = "main" optional = false python-versions = ">=3.7, <3.12" files = [ @@ -5346,13 +5573,12 @@ tensorflow-rocm = ["tensorflow-rocm (>=2.12.0,<2.13.0)"] name = "tensorflow-macos" version = "2.12.0" description = "TensorFlow is an open source machine learning framework for everyone." 
+category = "main" optional = false python-versions = ">=3.8" files = [ {file = "tensorflow_macos-2.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:db464c88e10e927725997f9b872a21c9d057789d3b7e9a26e4ef1af41d0bcc8c"}, {file = "tensorflow_macos-2.12.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:172277c33cb1ae0da19f98c5bcd4946149cfa73c8ea05c6ba18365d58dd3c6f2"}, - {file = "tensorflow_macos-2.12.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:9c9b14fbb73ec4cb0f209722a1489020fd8614c92ae22589f2309c48cefdf21f"}, - {file = "tensorflow_macos-2.12.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:6a54539bd076746f69ae8bef7282f981674fe4dbf59c3a84c4af86ae6bae9d5c"}, {file = "tensorflow_macos-2.12.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e3fa53e63672fd71998bbd71cc5478c74dbe5a2d9291d1801c575358c28403c2"}, {file = "tensorflow_macos-2.12.0-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:5499312c21ed3ed47cc6b4cf861896e9564c2c32d8d3c2ef1437c5ca31adfc73"}, {file = "tensorflow_macos-2.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:84cb873c90be63efabfecca53fdc48b734a037d0750532b55cb7ce7c343b5cac"}, @@ -5387,6 +5613,7 @@ wrapt = ">=1.11.0,<1.15" name = "tensorflow-metadata" version = "1.13.1" description = "Library and standards for schema and statistics." +category = "dev" optional = false python-versions = ">=3.8,<4" files = [ @@ -5402,6 +5629,7 @@ protobuf = ">=3.20.3,<5" name = "tensorstore" version = "0.1.36" description = "Read and write large, multi-dimensional arrays" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5431,6 +5659,7 @@ numpy = ">=1.16.0" name = "termcolor" version = "2.3.0" description = "ANSI color formatting for output in terminal" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5445,6 +5674,7 @@ tests = ["pytest", "pytest-cov"] name = "terminado" version = "0.17.1" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5465,6 +5695,7 @@ test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -5476,6 +5707,7 @@ files = [ name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5494,6 +5726,7 @@ test = ["flake8", "isort", "pytest"] name = "tokenizers" version = "0.13.3" description = "Fast and Customizable Tokenizers" +category = "main" optional = true python-versions = "*" files = [ @@ -5548,6 +5781,7 @@ testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -5559,6 +5793,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5570,6 +5805,7 @@ files = [ name = "toolz" version = "0.12.0" description = "List processing tools and functional utilities" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5581,6 +5817,7 @@ files = [ name = "tornado" version = "6.3.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
+category = "main" optional = false python-versions = ">= 3.8" files = [ @@ -5601,6 +5838,7 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5621,6 +5859,7 @@ telegram = ["requests"] name = "traitlets" version = "5.9.0" description = "Traitlets Python configuration system" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5636,6 +5875,7 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] name = "transformers" version = "4.30.0" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -5705,6 +5945,7 @@ vision = ["Pillow"] name = "typing-extensions" version = "4.5.0" description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5716,6 +5957,7 @@ files = [ name = "tzdata" version = "2023.3" description = "Provider of IANA time zone data" +category = "main" optional = false python-versions = ">=2" files = [ @@ -5727,6 +5969,7 @@ files = [ name = "uri-template" version = "1.2.0" description = "RFC 6570 URI Template Processor" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -5741,6 +5984,7 @@ dev = ["flake8 (<4.0.0)", "flake8-annotations", "flake8-bugbear", "flake8-commas name = "urllib3" version = "1.26.15" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -5757,6 +6001,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "virtualenv" version = "20.23.0" description = "Virtual Python Environment builder" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5777,6 +6022,7 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "coverage-enable-subprocess name = "wcwidth" version = "0.2.6" description = "Measures the displayed width of unicode strings in a terminal" +category = "main" optional = false python-versions = "*" files = [ @@ -5788,6 +6034,7 @@ files = [ name = "webcolors" version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5803,6 +6050,7 @@ tests = ["pytest", "pytest-cov"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "main" optional = false python-versions = "*" files = [ @@ -5814,6 +6062,7 @@ files = [ name = "websocket-client" version = "1.5.1" description = "WebSocket client for Python with low level API options" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5830,6 +6079,7 @@ test = ["websockets"] name = "werkzeug" version = "2.3.4" description = "The comprehensive WSGI web application library." 
+category = "main"
 optional = false
 python-versions = ">=3.8"
 files = [
@@ -5847,6 +6097,7 @@ watchdog = ["watchdog (>=2.3)"]
 name = "wheel"
 version = "0.40.0"
 description = "A built-package format for Python"
+category = "main"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -5861,6 +6112,7 @@ test = ["pytest (>=6.0.0)"]
 name = "widgetsnbextension"
 version = "4.0.7"
 description = "Jupyter interactive widgets for Jupyter Notebook"
+category = "main"
 optional = true
 python-versions = ">=3.7"
 files = [
@@ -5872,6 +6124,7 @@ files = [
 name = "wrapt"
 version = "1.14.1"
 description = "Module for decorators, wrappers and monkey patching."
+category = "main"
 optional = false
 python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
 files = [
@@ -5945,6 +6198,7 @@ files = [
 name = "xlrd"
 version = "2.0.1"
 description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files"
+category = "dev"
 optional = false
 python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
 files = [
@@ -5961,6 +6215,7 @@ test = ["pytest", "pytest-cov"]
 name = "xxhash"
 version = "3.2.0"
 description = "Python binding for xxHash"
+category = "main"
 optional = true
 python-versions = ">=3.6"
 files = [
@@ -6068,6 +6323,7 @@ files = [
 name = "yarl"
 version = "1.9.2"
 description = "Yet another URL library"
+category = "main"
 optional = true
 python-versions = ">=3.7"
 files = [
@@ -6155,6 +6411,7 @@ multidict = ">=4.0"
 name = "yaspin"
 version = "2.3.0"
 description = "Yet Another Terminal Spinner"
+category = "main"
 optional = true
 python-versions = ">=3.7.2,<4.0.0"
 files = [
@@ -6169,6 +6426,7 @@ termcolor = ">=2.2,<3.0"
 name = "yfinance"
 version = "0.2.18"
 description = "Download market data from Yahoo! Finance API"
+category = "dev"
 optional = false
 python-versions = "*"
 files = [
@@ -6193,6 +6451,7 @@ requests = ">=2.26"
 name = "zipp"
 version = "3.15.0"
 description = "Backport of pathlib-compatible object wrapper for zip files"
+category = "main"
 optional = false
 python-versions = ">=3.7"
 files = [
diff --git a/tests/fortuna/calib_model/test_calib_model.py b/tests/fortuna/calib_model/test_calib_model.py
index 437e3d6c..5d83edfc 100644
--- a/tests/fortuna/calib_model/test_calib_model.py
+++ b/tests/fortuna/calib_model/test_calib_model.py
@@ -128,7 +128,7 @@ def __init__(self, *args, **kwargs):
         )
         self.class_config_restore = lambda restore_dir: Config(
             optimizer=Optimizer(n_epochs=3),
-            checkpointer=Checkpointer(restore_checkpoint_path=restore_dir),
+            checkpointer=Checkpointer(restore_checkpoint_dir=restore_dir),
         )
         self.reg_config_nodir_nodump = Config(
             optimizer=Optimizer(n_epochs=3), monitor=Monitor(metrics=(scaled_mse,))
@@ -150,7 +150,7 @@ def __init__(self, *args, **kwargs):
         )
         self.reg_config_restore = lambda restore_dir: Config(
             optimizer=Optimizer(n_epochs=3),
-            checkpointer=Checkpointer(restore_checkpoint_path=restore_dir),
+            checkpointer=Checkpointer(restore_checkpoint_dir=restore_dir),
         )

     def test_dryrun_reg(self):
@@ -217,10 +217,10 @@ def test_dryrun_reg(self):
             )

             # load state
-            calib_reg.load_state(checkpoint_path=tmp_dir)
+            calib_reg.load_state(checkpoint_dir=tmp_dir)

             # save state
-            calib_reg.save_state(checkpoint_path=tmp_dir)
+            calib_reg.save_state(checkpoint_dir=tmp_dir)

             # model_editor
             calib_reg = CalibRegressor(
@@ -304,10 +304,10 @@ def test_dryrun_class(self):
             )

             # load state
-            calib_class.load_state(checkpoint_path=tmp_dir)
+            calib_class.load_state(checkpoint_dir=tmp_dir)

             # save state
-            calib_class.save_state(checkpoint_path=tmp_dir)
+            calib_class.save_state(checkpoint_dir=tmp_dir)

             # model_editor
             calib_class = CalibClassifier(model=model, model_editor=ModelEditor())
diff --git a/tests/fortuna/calib_model/test_output_calib_model.py b/tests/fortuna/calib_model/test_output_calib_model.py
index 96b7a3a8..4f180448 100644
--- a/tests/fortuna/calib_model/test_output_calib_model.py
+++ b/tests/fortuna/calib_model/test_output_calib_model.py
@@ -123,7 +123,7 @@ def __init__(self, *args, **kwargs):
         self.calib_config_restore = lambda directory, metric: Config(
             optimizer=Optimizer(n_epochs=3),
             monitor=Monitor(metrics=(metric,)),
-            checkpointer=Checkpointer(restore_checkpoint_path=directory),
+            checkpointer=Checkpointer(restore_checkpoint_dir=directory),
         )

     def test_dryrun_reg_map(self):
@@ -178,10 +178,10 @@ def test_dryrun_reg_map(self):
             )

             # load state
-            calib_model.load_state(checkpoint_path=tmp_dir)
+            calib_model.load_state(checkpoint_dir=tmp_dir)

             # save state
-            calib_model.save_state(checkpoint_path=tmp_dir)
+            calib_model.save_state(checkpoint_dir=tmp_dir)

     def test_dryrun_class_map(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -234,7 +234,7 @@ def test_dryrun_class_map(self):
             )

             # load state
-            calib_model.load_state(checkpoint_path=tmp_dir)
+            calib_model.load_state(checkpoint_dir=tmp_dir)

             # save state
-            calib_model.save_state(checkpoint_path=tmp_dir)
+            calib_model.save_state(checkpoint_dir=tmp_dir)
diff --git a/tests/fortuna/prob_model/test_train.py b/tests/fortuna/prob_model/test_train.py
index 05fda2f0..d39f8c16 100755
--- a/tests/fortuna/prob_model/test_train.py
+++ b/tests/fortuna/prob_model/test_train.py
@@ -1,6 +1,6 @@
 import tempfile

-import flax.linen as nn
+from flax import linen as nn
 import jax.numpy as jnp
 import numpy as np
 import pytest
@@ -8,9 +8,11 @@
 from fortuna.data.loader import DataLoader
 from fortuna.metric.classification import accuracy
 from fortuna.metric.regression import rmse
+from fortuna.partitioner.base import Partitioner
 from fortuna.prob_model import (
     CalibConfig,
     CalibOptimizer,
+    CalibProcessor,
     FitConfig,
     FitMonitor,
     SNGPPosteriorApproximator,
@@ -45,7 +47,7 @@
 OUTPUT_DIM = 2
 BATCH_SIZE = 8
 INPUT_SHAPE = (3,)
-N_DATA = 10
+N_DATA = 16

 METHODS = {
     "map": MAPPosteriorApproximator(),
@@ -81,14 +83,14 @@ def make_data_loader(


 def fit_config(
-    task, restore_path, start_current, save_dir, dump_state, save_n_steps, freeze
+    task, restore_dir, start_current, save_dir, dump_state, save_n_steps, freeze
 ):
     return FitConfig(
         optimizer=FitOptimizer(n_epochs=3, freeze_fun=freeze),
         monitor=FitMonitor(metrics=(accuracy if task == "classification" else rmse,)),
         checkpointer=FitCheckpointer(
             start_from_current_state=start_current,
-            restore_checkpoint_path=restore_path,
+            restore_checkpoint_dir=restore_dir,
             save_checkpoint_dir=save_dir,
             dump_state=dump_state,
             save_every_n_steps=save_n_steps,
@@ -96,7 +98,10 @@ def fit_config(
     )


-calib_config = CalibConfig(optimizer=CalibOptimizer(n_epochs=3))
+calib_config = CalibConfig(
+    optimizer=CalibOptimizer(n_epochs=3),
+    processor=CalibProcessor(n_posterior_samples=2),
+)


 def train(
@@ -105,7 +110,7 @@ def train(
     train_data_loader,
     val_data_loader,
     calib_data_loader,
-    restore_path=None,
+    restore_dir=None,
     start_current=False,
     save_dir=None,
     dump_state=False,
@@ -119,7 +124,7 @@ def train(
         calib_data_loader=calib_data_loader,
         fit_config=fit_config(
             task,
-            restore_path,
+            restore_dir,
             start_current,
             save_dir,
             dump_state,
@@ -147,7 +152,7 @@ def train_and_sample(
     train_data_loader,
     val_data_loader,
     calib_data_loader,
-    restore_path=None,
+    restore_dir=None,
     start_current=False,
     save_dir=None,
     dump_state=False,
@@ -161,7 +166,7 @@ def train_and_sample(
         train_data_loader,
         val_data_loader,
         calib_data_loader,
-        restore_path,
+        restore_dir,
         start_current,
         save_dir,
         dump_state,
@@ -173,12 +178,18 @@


 def define_prob_model(task, method, model_editor=None):
+    partitioner = Partitioner(
+        axis_dims={"mp": 2, "fsdp": 1, "dp": 2},
+        rules={"l1/kernel": ("fsdp", "mp"), "bn1": ("mp",)},
+    )
+
     if task == "regression":
         return ProbRegressor(
             model=MyModel(OUTPUT_DIM),
             likelihood_log_variance_model=MyModel(OUTPUT_DIM),
             posterior_approximator=METHODS[method],
             model_editor=model_editor,
+            partitioner=partitioner,
         )
     else:
         return ProbClassifier(
@@ -187,6 +198,7 @@ def define_prob_model(task, method, model_editor=None):
             else MyModelWithSpectralNorm(OUTPUT_DIM),
             posterior_approximator=METHODS[method],
             model_editor=model_editor,
+            partitioner=partitioner,
         )


@@ -213,7 +225,7 @@ def dryrun_task(task, method):
     prob_model = define_prob_model(task, method)
     map_fit_config = fit_config(
         task,
-        restore_path=None,
+        restore_dir=None,
         start_current=None,
         save_dir=None,
         dump_state=False,
@@ -252,7 +264,7 @@ def dryrun_task(task, method):
     with tempfile.TemporaryDirectory() as tmp_dir:
         map_fit_config = fit_config(
             task,
-            restore_path=None,
+            restore_dir=None,
             start_current=None,
             save_dir=None,
             dump_state=False,
@@ -277,14 +289,15 @@ def dryrun_task(task, method):
            train_data_loader,
            val_data_loader,
            calib_data_loader,
-           restore_path=tmp_dir,
+           restore_dir=tmp_dir,
        )

        prob_model = define_prob_model(task, method)
-       prob_model.load_state(tmp_dir)
+       prob_model.load_state(tmp_dir + "/last")
        sample(method, prob_model, train_data_loader)

        prob_model.predictive.log_prob(train_data_loader)
+       prob_model = define_prob_model(task, method)
        if method not in ["laplace", "swag"]:
            train_and_sample(
                task,
@@ -315,7 +328,7 @@ def dryrun_task(task, method):
                calib_data_loader,
                save_dir=tmp_dir,
                dump_state=True,
-               restore_path=tmp_dir,
+               restore_dir=tmp_dir,
                freeze=freeze_fun,
            )
            train_and_sample(
@@ -327,7 +340,7 @@ def dryrun_task(task, method):
                calib_data_loader,
                save_dir=tmp_dir,
                dump_state=True,
-               restore_path=tmp_dir,
+               restore_dir=tmp_dir,
                freeze=freeze_fun,
            )
            train_and_sample(
@@ -339,7 +352,7 @@ def dryrun_task(task, method):
                calib_data_loader,
                map_fit_config=fit_config(
                    task,
-                   restore_path=None,
+                   restore_dir=None,
                    start_current=None,
                    save_dir=None,
                    dump_state=False,
diff --git a/tests/fortuna/test_mixin.py b/tests/fortuna/test_mixin.py
index 642ab8c7..e834ec5d 100755
--- a/tests/fortuna/test_mixin.py
+++ b/tests/fortuna/test_mixin.py
@@ -7,10 +7,9 @@
 from fortuna.prob_model.posterior.posterior_mixin import WithPosteriorCheckpointingMixin
 from fortuna.prob_model.posterior.state import PosteriorState
-from fortuna.training.mixin import (
-    InputValidatorMixin,
-    WithEarlyStoppingMixin,
-)
+from fortuna.training.mixins.checkpointing import WithCheckpointingMixin
+from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin
+from fortuna.training.mixins.input_validator import InputValidatorMixin


 class FakeTrainerWithCheckpointing(
diff --git a/tests/fortuna/test_predictive.py b/tests/fortuna/test_predictive.py
index b4442919..6fb07c02 100755
--- a/tests/fortuna/test_predictive.py
+++ b/tests/fortuna/test_predictive.py
@@ -85,8 +85,7 @@ def test_pred_stats(self):
         assert log_probs.shape == (self.n_inputs,)

         sample = self.prob_class.predictive.sample(
-            self.class_inputs_loader,
-            n_target_samples=self.n_post_samples,
+            self.class_inputs_loader, n_target_samples=self.n_post_samples
         )
         assert sample.shape == (
             self.n_post_samples,
@@ -94,8 +93,7 @@
         )

         sample = self.prob_reg.predictive.sample(
-            self.reg_inputs_loader,
-            n_target_samples=self.n_post_samples,
+            self.reg_inputs_loader, n_target_samples=self.n_post_samples
         )
         assert sample.shape == (self.n_post_samples, self.n_inputs, self.output_dim)

diff --git a/tests/fortuna/test_trainer.py b/tests/fortuna/test_trainer.py
index fd243915..a6b81fa1 100755
--- a/tests/fortuna/test_trainer.py
+++ b/tests/fortuna/test_trainer.py
@@ -568,7 +568,7 @@ def test_should_perform_validation(self):
         self.assertTrue(trainer.should_perform_validation({}, 10))

     def test__validation_loop(self):
-        validation_dataloader = [
+        validation_data_loader = [
             [jnp.array([[0, 0.0, 0.0], [0, 0.0, 0]]), jnp.array([0.0, 0.0])],
             [jnp.array([[0.1, 0.0, 10], [0, 0.0, 0]]), jnp.array([1.0, 0.0])],
         ]
@@ -580,7 +580,7 @@ def test__validation_loop(self):
             observed_validation_epoch_metrics_str,
         ) = trainer._validation_loop(
             state=None,
-            validation_dataloader=validation_dataloader,
+            validation_data_loader=validation_data_loader,
             validation_dataset_size=2,
             loss_fun=lambda x: x,
             rng=jax.random.PRNGKey(0),
@@ -600,7 +600,7 @@ def test__validation_loop(self):
             observed_validation_epoch_metrics_str,
         ) = trainer._validation_loop(
             state=None,
-            validation_dataloader=validation_dataloader,
+            validation_data_loader=validation_data_loader,
             validation_dataset_size=2,
             loss_fun=lambda x: x,
             rng=jax.random.PRNGKey(0),
@@ -618,7 +618,7 @@
         )

     def test__training_loop(self):
-        training_dataloader = [
+        training_data_loader = [
             [jnp.array([[0, 0.0, 0.0], [0, 0.0, 0]]), jnp.array([0.0, 0.0])],
             [jnp.array([[0.1, 0.0, 10], [0, 0.0, 0]]), jnp.array([1.0, 0.0])],
         ]
@@ -635,7 +635,7 @@ def test__training_loop(self):
             metrics=(accuracy,),
             rng=jax.random.PRNGKey(0),
             state=FakeTrainState(),
-            training_dataloader=training_dataloader,
+            training_data_loader=training_data_loader,
             training_dataset_size=2,
             training_kwargs=FrozenDict({}),
             unravel=None,
diff --git a/tests/make_model.py b/tests/make_model.py
index 1469ea63..dec82faf 100644
--- a/tests/make_model.py
+++ b/tests/make_model.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import flax.linen as nn
 import jax.numpy as jnp

@@ -7,6 +9,7 @@ class MyModel(nn.Module):
     output_dim: int
     dense: nn.Module = nn.Dense
+    dtype: str = "float32"

     @nn.compact
     def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray:
@@ -14,8 +17,16 @@ def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray:
             dense = self.spectral_norm(self.dense, train=train)
         else:
             dense = self.dense
+        norm = partial(
+            nn.BatchNorm,
+            use_running_average=not train,
+            momentum=0.9,
+            epsilon=1e-5,
+            dtype=self.dtype,
+        )
         x = x.reshape(x.shape[0], -1)
-        x = dense(2, name="l1")(x)
+        x = dense(4, name="l1")(x)
+        x = norm(name="bn1")(x)
         x = nn.Dropout(rate=0.9)(x, deterministic=not train)
         x = dense(self.output_dim, name="l2")(x)
         return x