From 3c832901866e38d0e3582cfd8b16373c400e70c7 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Sun, 2 Jul 2023 23:35:40 +0200 Subject: [PATCH] pre-commit --- .../prob_model_text_classification.py | 10 +- fortuna/calib_model/base.py | 4 +- fortuna/calib_model/calib_model_calibrator.py | 2 +- fortuna/likelihood/base.py | 1 + fortuna/model/llama.py | 486 +++++++++++------- fortuna/model/model_manager/base.py | 2 +- fortuna/output_calib_model/base.py | 4 +- .../output_calib_manager/base.py | 4 +- fortuna/partitioner/base.py | 23 +- fortuna/partitioner/partition_manager/base.py | 68 ++- fortuna/prob_model/base.py | 4 +- fortuna/prob_model/classification.py | 10 +- fortuna/prob_model/fit_config/checkpointer.py | 6 +- fortuna/prob_model/joint/base.py | 14 +- fortuna/prob_model/posterior/base.py | 28 +- .../prob_model/posterior/map/map_posterior.py | 58 ++- .../prob_model/posterior/posterior_mixin.py | 5 +- .../posterior/posterior_state_repository.py | 4 +- fortuna/prob_model/predictive/base.py | 4 +- fortuna/prob_model/regression.py | 10 +- fortuna/training/mixins/checkpointing.py | 66 ++- fortuna/training/output_calibrator.py | 7 +- fortuna/training/train_state_repository.py | 21 +- fortuna/training/trainer.py | 22 +- fortuna/utils/checkpoint.py | 20 +- fortuna/utils/mesh.py | 30 +- fortuna/utils/nested_dicts.py | 13 +- fortuna/utils/partition.py | 38 +- fortuna/utils/port.py | 2 +- 29 files changed, 619 insertions(+), 347 deletions(-) diff --git a/benchmarks/transformers/prob_model_text_classification.py b/benchmarks/transformers/prob_model_text_classification.py index 1493bba8..946e3deb 100644 --- a/benchmarks/transformers/prob_model_text_classification.py +++ b/benchmarks/transformers/prob_model_text_classification.py @@ -400,11 +400,17 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: model_editor = None if args.enable_probit_model_editor: - probit_freeze_fun = lambda p, v: True if "classifier" in p else False if args.probit_last_layer_only else None + probit_freeze_fun = ( + lambda p, v: True + if "classifier" in p + else False + if args.probit_last_layer_only + else None + ) model_editor = ProbitModelEditor( freeze_fun=probit_freeze_fun, init_log_var=args.probit_init_log_var, - stop_gradient=args.probit_stop_gradient + stop_gradient=args.probit_stop_gradient, ) ### TRAINING diff --git a/fortuna/calib_model/base.py b/fortuna/calib_model/base.py index 408dec16..37055efa 100644 --- a/fortuna/calib_model/base.py +++ b/fortuna/calib_model/base.py @@ -176,9 +176,7 @@ def load_state(self, checkpoint_dir: Path) -> None: ) self.predictive.state = CalibStateRepository(checkpoint_dir=checkpoint_dir) - def save_state( - self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: return self.predictive.state.put( self.predictive.state.get(), checkpoint_dir=checkpoint_dir, diff --git a/fortuna/calib_model/calib_model_calibrator.py b/fortuna/calib_model/calib_model_calibrator.py index ab48441d..5425119c 100644 --- a/fortuna/calib_model/calib_model_calibrator.py +++ b/fortuna/calib_model/calib_model_calibrator.py @@ -13,9 +13,9 @@ from optax._src.base import PyTree from fortuna.calib_model.state import CalibState -from fortuna.training.trainer import TrainerABC from fortuna.training.mixins.jitted import JittedMixin from fortuna.training.mixins.multi_device import MultiDeviceMixin +from fortuna.training.trainer import TrainerABC from fortuna.typing import ( Array, Batch, diff 
--git a/fortuna/likelihood/base.py b/fortuna/likelihood/base.py index 67c4af21..fd32b49a 100644 --- a/fortuna/likelihood/base.py +++ b/fortuna/likelihood/base.py @@ -7,6 +7,7 @@ Tuple, Union, ) + from flax.core import FrozenDict from jax import ( jit, diff --git a/fortuna/model/llama.py b/fortuna/model/llama.py index f1bde63b..49e2427d 100644 --- a/fortuna/model/llama.py +++ b/fortuna/model/llama.py @@ -1,101 +1,123 @@ +from functools import partial +import json import os from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union -import json import tempfile -from functools import partial - -import numpy as np +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + Union, +) + +from flax.core.frozen_dict import ( + FrozenDict, + freeze, + unfreeze, +) +import flax.linen as nn +from flax.linen import ( + combine_masks, + make_causal_mask, + partitioning as nn_partitioning, +) +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import ( + flatten_dict, + unflatten_dict, +) import jax -import jax.numpy as jnp from jax import lax +import jax.numpy as jnp from jax.sharding import PartitionSpec as PS -import flax.linen as nn -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from flax.linen import partitioning as nn_partitioning - +import numpy as np from transformers.configuration_utils import PretrainedConfig -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxCausalLMOutput, +) from transformers.modeling_flax_utils import FlaxPreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging - +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) LLAMA_STANDARD_CONFIGS = { - '3b': { - 'vocab_size': 32000, - 'hidden_size': 3200, - 'intermediate_size': 8640, - 'num_hidden_layers': 26, - 'num_attention_heads': 32, - 'max_sequence_length': 2048, - 'initializer_range': 0.02, - 'rms_norm_eps': 1e-6, - 'use_cache': True, - 'tie_word_embeddings': False, + "3b": { + "vocab_size": 32000, + "hidden_size": 3200, + "intermediate_size": 8640, + "num_hidden_layers": 26, + "num_attention_heads": 32, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, }, - '7b': { - 'vocab_size': 32000, - 'hidden_size': 4096, - 'intermediate_size': 11008, - 'num_hidden_layers': 32, - 'num_attention_heads': 32, - 'max_sequence_length': 2048, - 'initializer_range': 0.02, - 'rms_norm_eps': 1e-6, - 'use_cache': True, - 'tie_word_embeddings': False, + "7b": { + "vocab_size": 32000, + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, }, - '13b': { - 'vocab_size': 32000, - 'hidden_size': 5120, - 'intermediate_size': 13824, - 'num_hidden_layers': 40, - 'num_attention_heads': 40, - 'max_sequence_length': 2048, - 'initializer_range': 0.02, - 
'rms_norm_eps': 1e-6, - 'use_cache': True, - 'tie_word_embeddings': False, + "13b": { + "vocab_size": 32000, + "hidden_size": 5120, + "intermediate_size": 13824, + "num_hidden_layers": 40, + "num_attention_heads": 40, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, }, - '30b': { - 'vocab_size': 32000, - 'hidden_size': 6656, - 'intermediate_size': 17920, - 'num_hidden_layers': 60, - 'num_attention_heads': 52, - 'max_sequence_length': 2048, - 'initializer_range': 0.02, - 'rms_norm_eps': 1e-6, - 'use_cache': True, - 'tie_word_embeddings': False, + "30b": { + "vocab_size": 32000, + "hidden_size": 6656, + "intermediate_size": 17920, + "num_hidden_layers": 60, + "num_attention_heads": 52, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, }, - '65b': { - 'vocab_size': 32000, - 'hidden_size': 8192, - 'intermediate_size': 22016, - 'num_hidden_layers': 80, - 'num_attention_heads': 64, - 'max_sequence_length': 2048, - 'initializer_range': 0.02, - 'rms_norm_eps': 1e-5, - 'use_cache': True, - 'tie_word_embeddings': False, + "65b": { + "vocab_size": 32000, + "hidden_size": 8192, + "intermediate_size": 22016, + "num_hidden_layers": 80, + "num_attention_heads": 64, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-5, + "use_cache": True, + "tie_word_embeddings": False, }, - 'debug': { # A small model for debugging - 'vocab_size': 32000, - 'hidden_size': 128, - 'intermediate_size': 256, - 'num_hidden_layers': 2, - 'num_attention_heads': 4, - 'max_sequence_length': 2048, - 'initializer_range': 0.02, - 'rms_norm_eps': 1e-6, - 'use_cache': True, - 'tie_word_embeddings': False, + "debug": { # A small model for debugging + "vocab_size": 32000, + "hidden_size": 128, + "intermediate_size": 256, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "max_sequence_length": 2048, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "tie_word_embeddings": False, }, } @@ -162,9 +184,9 @@ def __init__( embd_pdrop=0.0, attn_pdrop=0.0, tie_word_embeddings=False, - remat_block='nothing_saveable', - remat_attention='', - remat_mlp='', + remat_block="nothing_saveable", + remat_attention="", + remat_mlp="", fcm_min_ratio=0.0, fcm_max_ratio=0.0, **kwargs, @@ -197,13 +219,13 @@ def __init__( class RMSNorm(nn.Module): dim: int - eps: float=1e-6 - dtype: jnp.dtype=jnp.float32 - param_dtype: jnp.dtype=jnp.float32 + eps: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 def setup(self) -> None: self.weight = self.param( - 'kernel', + "kernel", nn.initializers.ones, (self.dim,), self.param_dtype, @@ -219,7 +241,9 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return output * weight -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32) -> jnp.ndarray: +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32 +) -> jnp.ndarray: freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) t = np.arange(end) # type: ignore freqs = np.outer(t, freqs).astype(dtype) # type: ignore @@ -234,7 +258,6 @@ def apply_rotary_emb( freqs_cis: jnp.ndarray, dtype: jnp.dtype = jnp.float32, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) @@ -245,10 
+268,14 @@ def apply_rotary_emb( freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:])) xq_out = xq_ * freqs_cis - xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1) + xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape( + *xq_out.shape[:-1], -1 + ) xk_out = xk_ * freqs_cis - xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1) + xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape( + *xk_out.shape[:-1], -1 + ) return xq_out.astype(dtype), xk_out.astype(dtype) @@ -266,7 +293,7 @@ def setup(self): self.head_dim = self.embed_dim // self.num_heads self.wq = nn.Dense( - config.num_attention_heads*self.head_dim, + config.num_attention_heads * self.head_dim, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, @@ -274,7 +301,7 @@ def setup(self): precision=self.precision, ) self.wk = nn.Dense( - config.num_attention_heads*self.head_dim, + config.num_attention_heads * self.head_dim, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, @@ -282,7 +309,7 @@ def setup(self): precision=self.precision, ) self.wv = nn.Dense( - config.num_attention_heads*self.head_dim, + config.num_attention_heads * self.head_dim, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, @@ -300,7 +327,9 @@ def setup(self): self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) - self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool") + self.causal_mask = make_causal_mask( + jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool" + ) self.freqs_cis = precompute_freqs_cis( self.head_dim, @@ -309,7 +338,9 @@ def setup(self): ) def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + return hidden_states.reshape( + hidden_states.shape[:2] + (self.num_heads, self.head_dim) + ) def _merge_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) @@ -323,9 +354,15 @@ def _concatenate_to_cache(self, key, value, query, attention_mask): """ # detect if we're initializing by absence of existing cache data. 
is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + cached_key = self.variable( + "cache", "cached_key", jnp.zeros, key.shape, key.dtype + ) + cached_value = self.variable( + "cache", "cached_value", jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32) + ) if is_initialized: *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape @@ -356,7 +393,11 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states) + xq, xk, xv = ( + self.wq(hidden_states), + self.wk(hidden_states), + self.wv(hidden_states), + ) xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp")) xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp")) @@ -376,15 +417,21 @@ def __call__( mask_shift = self.variables["cache"]["cache_index"] max_decoder_length = self.variables["cache"]["cached_key"].shape[1] causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + self.causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length), ) else: causal_mask = self.causal_mask[:, :, :query_length, :key_length] batch_size = hidden_states.shape[0] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + causal_mask = jnp.broadcast_to( + causal_mask, (batch_size,) + causal_mask.shape[1:] + ) - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape + ) attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) dropout_rng = None @@ -394,13 +441,17 @@ def __call__( # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. 
if self.has_variable("cache", "cached_key") or init_cache: - xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask) + xk, xv, attention_mask = self._concatenate_to_cache( + xk, xv, xq, attention_mask + ) # transform boolean mask into float mask attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype( + self.dtype + ), ) # usual dot product attention @@ -414,9 +465,13 @@ def __call__( dtype=jnp.promote_types(self.dtype, jnp.float32), precision=self.precision, ) - attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None)) + attn_weights = with_sharding_constraint( + attn_weights, PS(("dp", "fsdp"), "mp", None, None) + ) - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision) + attn_output = jnp.einsum( + "...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision + ) attn_output = self._merge_heads(attn_output) attn_output = self.wo(attn_output) attn_output = self.resid_dropout(attn_output, deterministic=deterministic) @@ -426,9 +481,9 @@ def __call__( class FlaxLLaMAMLP(nn.Module): config: LLaMAConfig - dtype: jnp.dtype=jnp.float32 - param_dtype: jnp.dtype=jnp.float32 - precision: Optional[Union[jax.lax.Precision, str]]=None + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self) -> None: config = self.config @@ -467,22 +522,24 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: class FlaxLLaMABlock(nn.Module): config: LLaMAConfig - dtype: jnp.dtype=jnp.float32 - param_dtype: jnp.dtype=jnp.float32 - precision: Optional[Union[jax.lax.Precision, str]]=None + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self) -> None: attention_module = FlaxLLaMAAttention mlp_module = FlaxLLaMAMLP - if self.config.remat_attention != '': + if self.config.remat_attention != "": attention_module = remat( - FlaxLLaMAAttention, static_argnums=(3, 4, 5), - policy=get_gradient_checkpoint_policy(self.config.remat_attention) + FlaxLLaMAAttention, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_attention), ) - if self.config.remat_mlp != '': + if self.config.remat_mlp != "": mlp_module = remat( - FlaxLLaMAMLP, static_argnums=(1,), - policy=get_gradient_checkpoint_policy(self.config.remat_mlp) + FlaxLLaMAMLP, + static_argnums=(1,), + policy=get_gradient_checkpoint_policy(self.config.remat_mlp), ) self.attention = attention_module( @@ -561,13 +618,24 @@ def __init__( **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + super().__init__( + config, + module, + input_shape=input_shape, + seed=seed, + dtype=dtype, + _do_init=_do_init, + ) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + def init_weights( + self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None + ) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) - position_ids = 
jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape + ) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} @@ -584,7 +652,9 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz return_dict=False, ) else: - module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, position_ids, return_dict=False + ) random_params = module_init_outputs["params"] @@ -610,10 +680,17 @@ def init_cache(self, batch_size, max_length): # init input variables to retrieve cache input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape + ) init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + jax.random.PRNGKey(0), + input_ids, + attention_mask, + position_ids, + return_dict=False, + init_cache=True, ) return init_variables["cache"] @@ -631,19 +708,31 @@ def __call__( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict ) - return_dict = return_dict if return_dict is not None else self.config.return_dict batch_size, sequence_length = input_ids.shape if position_ids is None: if past_key_values is not None: - raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") + raise ValueError( + "Make sure to provide `position_ids` when passing `past_key_values`." 
+ ) - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -691,15 +780,16 @@ def __call__( class FlaxLLaMABlockCollection(nn.Module): config: LLaMAConfig dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype=jnp.float32 - precision: Optional[Union[jax.lax.Precision, str]]=None + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self): block = FlaxLLaMABlock - if self.config.remat_block != '': + if self.config.remat_block != "": block = remat( - FlaxLLaMABlock, static_argnums=(3, 4, 5), - policy=get_gradient_checkpoint_policy(self.config.remat_block) + FlaxLLaMABlock, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_block), ) self.blocks = [ block( @@ -707,8 +797,9 @@ def setup(self): name=str(i), dtype=self.dtype, param_dtype=self.param_dtype, - precision=self.precision - ) for i in range(self.config.num_hidden_layers) + precision=self.precision, + ) + for i in range(self.config.num_hidden_layers) ] def __call__( @@ -729,16 +820,19 @@ def __call__( # Apply forgetful causal mask batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] fcm_ratio = jax.random.uniform( - self.make_rng('fcm'), shape=(batch_size, 1, 1, 1), + self.make_rng("fcm"), + shape=(batch_size, 1, 1, 1), minval=self.config.fcm_min_ratio, - maxval=self.config.fcm_max_ratio + maxval=self.config.fcm_max_ratio, + ) + fcm_mask = ( + jax.random.uniform( + self.make_rng("fcm"), shape=(batch_size, 1, seq_length, seq_length) + ) + > fcm_ratio ) - fcm_mask = jax.random.uniform( - self.make_rng('fcm'), - shape=(batch_size, 1, seq_length, seq_length) - ) > fcm_ratio fcm_mask = fcm_mask.at[:, :, :, 0].set(True) - fcm_mask = fcm_mask.astype('bool') + fcm_mask = fcm_mask.astype("bool") else: fcm_mask = None @@ -769,8 +863,8 @@ def __call__( class FlaxLLaMAModule(nn.Module): config: LLaMAConfig dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype=jnp.float32 - precision: Optional[Union[jax.lax.Precision, str]]=None + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self): self.embed_dim = self.config.hidden_size @@ -778,13 +872,25 @@ def setup(self): self.wte = nn.Embed( self.config.vocab_size, self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + embedding_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), dtype=self.dtype, param_dtype=self.param_dtype, ) self.dropout = nn.Dropout(rate=self.config.embd_pdrop) - self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) - self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype) + self.h = FlaxLLaMABlockCollection( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.ln_f = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) def __call__( self, @@ -830,10 +936,12 @@ def __call__( attentions=outputs[-1], ) + @add_start_docstrings("", "") class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel): module_class = FlaxLLaMAModule + # append_call_sample_docstring( # 
FlaxLLaMAModel, # _TOKENIZER_FOR_DOC, @@ -842,11 +950,12 @@ class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel): # _CONFIG_FOR_DOC, # ) + class FlaxLLaMAForCausalLMModule(nn.Module): config: LLaMAConfig dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype=jnp.float32 - precision: Optional[Union[jax.lax.Precision, str]]=None + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self): self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype) @@ -855,7 +964,9 @@ def setup(self): dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + kernel_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), precision=self.precision, ) @@ -876,7 +987,7 @@ def __call__( if position_ids is None: position_ids = jnp.broadcast_to( jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), - (batch_size, seq_length) + (batch_size, seq_length), ) outputs = self.transformer( input_ids, @@ -893,21 +1004,29 @@ def __call__( if self.config.tie_word_embeddings: shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T - lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) + lm_logits = self.lm_head.apply( + {"params": {"kernel": shared_kernel}}, hidden_states + ) else: lm_logits = self.lm_head(hidden_states) if not return_dict: return (lm_logits,) + outputs[1:] - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + return FlaxCausalLMOutput( + logits=lm_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings("", "") class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel): module_class = FlaxLLaMAForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation( + self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None + ): # initializing the cache batch_size, seq_length = input_ids.shape @@ -918,9 +1037,13 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0) + ) else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) return { "past_key_values": past_key_values, @@ -933,6 +1056,7 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 return model_kwargs + # append_call_sample_docstring( # FlaxGPTJForCausalLM, # _TOKENIZER_FOR_DOC, @@ -942,7 +1066,6 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): # ) - VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} PRETRAINED_VOCAB_FILES_MAP = {} @@ -972,24 +1095,28 @@ def __init__( **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) + 
super().__init__( + bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs + ) self.vocab_file = vocab_file self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) with tempfile.NamedTemporaryFile() as tfile: - with open_file(self.vocab_file, 'rb') as fin: + with open_file(self.vocab_file, "rb") as fin: tfile.write(fin.read()) tfile.flush() tfile.seek(0) self.sp_model.Load(tfile.name) """ Initialisation""" - self.add_special_tokens(dict( - unk_token=unk_token, - bos_token=bos_token, - eos_token=eos_token, - )) + self.add_special_tokens( + dict( + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + ) + ) self.pad_token_id = self.unk_token_id @property @@ -1043,7 +1170,9 @@ def convert_tokens_to_string(self, tokens): out_string += self.sp_model.decode(current_sub_tokens) return out_string.strip() - def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + def save_vocabulary( + self, save_directory, filename_prefix: Optional[str] = None + ) -> Tuple[str]: """ Save the vocabulary and special tokens file to a directory. Args: @@ -1056,10 +1185,14 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) logger.error(f"Vocabulary path ({save_directory}) should be a directory") return out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -1085,7 +1218,10 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): return output def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. 
This method is called when adding @@ -1102,7 +1238,9 @@ def get_special_tokens_mask( """ if already_has_special_tokens: return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, ) if token_ids_1 is None: diff --git a/fortuna/model/model_manager/base.py b/fortuna/model/model_manager/base.py index 7c15cc3e..f8a8cf63 100755 --- a/fortuna/model/model_manager/base.py +++ b/fortuna/model/model_manager/base.py @@ -17,7 +17,7 @@ InputData, Mutable, Params, - Shape + Shape, ) from fortuna.utils.random import WithRNG diff --git a/fortuna/output_calib_model/base.py b/fortuna/output_calib_model/base.py index 7f2d3380..d5dd4ec5 100644 --- a/fortuna/output_calib_model/base.py +++ b/fortuna/output_calib_model/base.py @@ -153,9 +153,7 @@ def load_state(self, checkpoint_dir: Path) -> None: checkpoint_dir=checkpoint_dir ) - def save_state( - self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the calibration state as a checkpoint. diff --git a/fortuna/output_calibrator/output_calib_manager/base.py b/fortuna/output_calibrator/output_calib_manager/base.py index 510e0e6b..146473e4 100644 --- a/fortuna/output_calibrator/output_calib_manager/base.py +++ b/fortuna/output_calibrator/output_calib_manager/base.py @@ -106,7 +106,9 @@ def init( rng, params_key, dropout_key = random.split(rng, 3) rngs = {"params": params_key, "dropout": dropout_key} return ( - FrozenDict(self.output_calibrator.init(rngs, jnp.zeros((1, output_dim)), **kwargs)) + FrozenDict( + self.output_calibrator.init(rngs, jnp.zeros((1, output_dim)), **kwargs) + ) if self.output_calibrator is not None else None ) diff --git a/fortuna/partitioner/base.py b/fortuna/partitioner/base.py index d7f01e61..55fe1338 100644 --- a/fortuna/partitioner/base.py +++ b/fortuna/partitioner/base.py @@ -1,22 +1,31 @@ -from typing import Dict, Tuple, Optional +from typing import ( + Dict, + Optional, + Tuple, +) + from jax.sharding import PartitionSpec + from fortuna.utils.mesh import get_mesh from fortuna.utils.port import is_port_in_use class Partitioner: def __init__( - self, - axis_dims: Optional[Dict[str, int]] = None, - rules: Optional[Dict[str, Tuple[str, ...]]] = None, - coordinator_address: Optional[str] = None, - n_devices: Optional[int] = None + self, + axis_dims: Optional[Dict[str, int]] = None, + rules: Optional[Dict[str, Tuple[str, ...]]] = None, + coordinator_address: Optional[str] = None, + n_devices: Optional[int] = None, ): if axis_dims is None: axis_dims = {"dp": 1, "fsdp": 1, "mp": 1} if rules is None: rules = {} - self.specs = {k: PartitionSpec(*v) if v is not None else PartitionSpec(None) for k, v in rules.items()} + self.specs = { + k: PartitionSpec(*v) if v is not None else PartitionSpec(None) + for k, v in rules.items() + } self.mesh = get_mesh(axis_dims) if coordinator_address is None: diff --git a/fortuna/partitioner/partition_manager/base.py b/fortuna/partitioner/partition_manager/base.py index 3c56cc6a..4a2b5c9a 100644 --- a/fortuna/partitioner/partition_manager/base.py +++ b/fortuna/partitioner/partition_manager/base.py @@ -1,23 +1,33 @@ -from typing import Callable, Optional, List +from typing import ( + Callable, + List, + Optional, +) + +from jax import ( + device_put, + eval_shape, + random, +) +from jax._src.prng import PRNGKeyArray +from jax.experimental.pjit import 
pjit +from jax.sharding import ( + NamedSharding, + PartitionSpec, +) +from jax.tree_util import ( + tree_map, + tree_map_with_path, +) -from fortuna.utils.random import WithRNG from fortuna.partitioner.base import Partitioner from fortuna.training.train_state import TrainState -from jax import eval_shape, device_put -from jax.tree_util import tree_map -from jax.sharding import NamedSharding, PartitionSpec from fortuna.utils.partition import match_partition_specs -from jax.experimental.pjit import pjit -from jax._src.prng import PRNGKeyArray -from jax.tree_util import tree_map_with_path -from jax import random +from fortuna.utils.random import WithRNG class PartitionManager(WithRNG): - def __init__( - self, - partitioner: Partitioner - ): + def __init__(self, partitioner: Partitioner): self.partitioner = partitioner self._init_state_fn = None self._shapes_dtypes = None @@ -34,7 +44,9 @@ def init_state_fn(self, init_state_fn: Callable[[PRNGKeyArray], TrainState]): self._init_state_fn = init_state_fn self._shapes_dtypes = eval_shape(self.init_state_fn, random.PRNGKey(0)) partitions = match_partition_specs(self.partitioner.specs, self._shapes_dtypes) - self._shardings = tree_map(lambda p: NamedSharding(mesh=self.partitioner.mesh, spec=p), partitions) + self._shardings = tree_map( + lambda p: NamedSharding(mesh=self.partitioner.mesh, spec=p), partitions + ) @property def shapes_dtypes(self): @@ -46,7 +58,9 @@ def shapes_dtypes(self): def shapes_dtypes(self, shapes_dtypes: TrainState): self._shapes_dtypes = shapes_dtypes partitions = match_partition_specs(self.partitioner.specs, self._shapes_dtypes) - self._shardings = tree_map(lambda p: NamedSharding(mesh=self.partitioner.mesh, spec=p), partitions) + self._shardings = tree_map( + lambda p: NamedSharding(mesh=self.partitioner.mesh, spec=p), partitions + ) @property def shardings(self): @@ -58,20 +72,22 @@ def shardings(self, shardings: Optional[TrainState]): def init_sharded_state_fn(self, rng: PRNGKeyArray): with self.partitioner.mesh: - return pjit(self._init_state_fn, in_shardings=PartitionSpec(), out_shardings=self.shardings)(rng) - - def reshard(self, state: TrainState, exclude: Optional[List[str]] = None) -> TrainState: + return pjit( + self._init_state_fn, + in_shardings=PartitionSpec(), + out_shardings=self.shardings, + )(rng) + + def reshard( + self, state: TrainState, exclude: Optional[List[str]] = None + ) -> TrainState: if self.shardings is not None: if exclude is None: exclude = [] return tree_map_with_path( - lambda p, _v, s: device_put(_v, s) if _v is not None and p[0].name not in exclude else _v, + lambda p, _v, s: device_put(_v, s) + if _v is not None and p[0].name not in exclude + else _v, state, - self.shardings + self.shardings, ) - - - - - - diff --git a/fortuna/prob_model/base.py b/fortuna/prob_model/base.py index b21a0d96..1dd5d256 100644 --- a/fortuna/prob_model/base.py +++ b/fortuna/prob_model/base.py @@ -247,9 +247,7 @@ def load_state(self, checkpoint_dir: Path) -> None: """ return self.posterior.load_state(checkpoint_dir) - def save_state( - self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the posterior distribution state as a checkpoint. 
diff --git a/fortuna/prob_model/classification.py b/fortuna/prob_model/classification.py
index d83b6267..e41300f3 100644
--- a/fortuna/prob_model/classification.py
+++ b/fortuna/prob_model/classification.py
@@ -20,6 +20,8 @@
 from fortuna.model_editor.base import ModelEditor
 from fortuna.output_calibrator.classification import ClassificationTemperatureScaler
 from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager
+from fortuna.partitioner.base import Partitioner
+from fortuna.partitioner.partition_manager.base import PartitionManager
 from fortuna.prob_model.base import ProbModel
 from fortuna.prob_model.calib_config.base import CalibConfig
 from fortuna.prob_model.fit_config.base import FitConfig
@@ -43,8 +45,6 @@
     get_input_shape,
     get_inputs_from_shape,
 )
-from fortuna.partitioner.base import Partitioner
-from fortuna.partitioner.partition_manager.base import PartitionManager


 class ProbClassifier(ProbModel):
@@ -138,7 +138,11 @@ def __init__(
         self.partition_manager = PartitionManager(partitioner)
         self.posterior = getattr(
             PosteriorApproximations, posterior_approximator.__str__()
-        ).value(joint=self.joint, posterior_approximator=posterior_approximator, partition_manager=self.partition_manager)
+        ).value(
+            joint=self.joint,
+            posterior_approximator=posterior_approximator,
+            partition_manager=self.partition_manager,
+        )
         self.predictive = ClassificationPredictive(self.posterior)

         super().__init__(seed=seed)
diff --git a/fortuna/prob_model/fit_config/checkpointer.py b/fortuna/prob_model/fit_config/checkpointer.py
index 80f0ba32..ea3b448e 100644
--- a/fortuna/prob_model/fit_config/checkpointer.py
+++ b/fortuna/prob_model/fit_config/checkpointer.py
@@ -52,6 +52,8 @@ def __init__(

         allowed_checkpoint_types = ["last", "best"]
         if checkpoint_type not in allowed_checkpoint_types:
-            raise ValueError(f"`checkpoint_type={checkpoint_type}` not recognised. "
-                             f"Pleas select one of the following options: {allowed_checkpoint_types}.")
+            raise ValueError(
+                f"`checkpoint_type={checkpoint_type}` not recognised. "
+                f"Please select one of the following options: {allowed_checkpoint_types}."
+            )
         self.checkpoint_type = checkpoint_type
diff --git a/fortuna/prob_model/joint/base.py b/fortuna/prob_model/joint/base.py
index e315e87b..3a34ec8a 100755
--- a/fortuna/prob_model/joint/base.py
+++ b/fortuna/prob_model/joint/base.py
@@ -6,9 +6,10 @@
 )

 from flax.core import FrozenDict
+from jax import random
 from jax._src.prng import PRNGKeyArray
 import jax.numpy as jnp
-from jax import random
+
 from fortuna.likelihood.base import Likelihood
 from fortuna.model.model_manager.state import ModelManagerState
 from fortuna.output_calibrator.output_calib_manager.state import OutputCalibManagerState
@@ -154,7 +155,9 @@ def _batched_negative_log_joint_prob(
                 return loss, aux

         return -outs

-    def init(self, input_shape: Shape, rng: Optional[PRNGKeyArray] = None, **kwargs) -> JointState:
+    def init(
+        self, input_shape: Shape, rng: Optional[PRNGKeyArray] = None, **kwargs
+    ) -> JointState:
         """
         Initialize the state of the joint distribution.
@@ -174,9 +177,7 @@ def init(self, input_shape: Shape, rng: Optional[PRNGKeyArray] = None, **kwargs) key1, key2 = random.split(rng) oms = ModelManagerState.init_from_dict( - self.likelihood.model_manager.init( - input_shape, rng=key1, **kwargs - ) + self.likelihood.model_manager.init(input_shape, rng=key1, **kwargs) ) inputs = get_inputs_from_shape(input_shape) outputs = self.likelihood.model_manager.apply( @@ -190,8 +191,7 @@ def init(self, input_shape: Shape, rng: Optional[PRNGKeyArray] = None, **kwargs) ocms = OutputCalibManagerState.init_from_dict( FrozenDict( output_calibrator=self.likelihood.output_calib_manager.init( - output_dim=output_dim, - rng=key2 + output_dim=output_dim, rng=key2 ) ) ) diff --git a/fortuna/prob_model/posterior/base.py b/fortuna/prob_model/posterior/base.py index 0e46eb2d..39e48ff6 100755 --- a/fortuna/prob_model/posterior/base.py +++ b/fortuna/prob_model/posterior/base.py @@ -10,22 +10,22 @@ from flax.core import FrozenDict from jax._src.prng import PRNGKeyArray +from orbax.checkpoint.checkpoint_manager import CheckpointManager from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint from fortuna.prob_model.joint.state import JointState -from fortuna.utils.checkpoint import get_checkpoint_manager from fortuna.prob_model.posterior.posterior_state_repository import ( PosteriorStateRepository, ) -from orbax.checkpoint.checkpoint_manager import CheckpointManager from fortuna.prob_model.posterior.state import PosteriorState from fortuna.typing import ( Path, Status, ) -from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.utils.checkpoint import get_checkpoint_manager from fortuna.utils.freeze import get_trainable_paths from fortuna.utils.nested_dicts import ( nested_get, @@ -51,7 +51,12 @@ def posterior_method_kwargs(self) -> Dict[str, Any]: class Posterior(WithRNG): state = None - def __init__(self, joint: Joint, posterior_approximator: PosteriorApproximator, partition_manager: PartitionManager): + def __init__( + self, + joint: Joint, + posterior_approximator: PosteriorApproximator, + partition_manager: PartitionManager, + ): r""" Posterior distribution class. This refers to :math:`p(w|\mathcal{D}, \phi)`, where :math:`w` are the random model parameters, :math:`\mathcal{D}` is a training data set and :math:`\phi` are calibration parameters. 
@@ -75,12 +80,12 @@ def _restore_state_from_somewhere( fit_config: FitConfig, allowed_states: Optional[Tuple[Type[PosteriorState], ...]] = None, checkpoint_manager: Optional[CheckpointManager] = None, - shapes_dtypes_only: bool = False + shapes_dtypes_only: bool = False, ) -> PosteriorState: if checkpoint_manager is not None: repo = PosteriorStateRepository( partition_manager=self.partition_manager, - checkpoint_manager=checkpoint_manager + checkpoint_manager=checkpoint_manager, ) state = repo.get(optimizer=fit_config.optimizer.method) elif fit_config.checkpointer.start_from_current_state is not None: @@ -94,7 +99,9 @@ def _restore_state_from_somewhere( return state - def _init_joint_state(self, data_loader: DataLoader, rng: Optional[PRNGKeyArray] = None) -> JointState: + def _init_joint_state( + self, data_loader: DataLoader, rng: Optional[PRNGKeyArray] = None + ) -> JointState: return self.joint.init(input_shape=data_loader.input_shape, rng=rng) @staticmethod @@ -184,12 +191,11 @@ def load_state(self, checkpoint_dir: Path) -> None: """ self.state = PosteriorStateRepository( partition_manager=self.partition_manager, - checkpoint_manager=get_checkpoint_manager(checkpoint_dir=checkpoint_dir)) + checkpoint_manager=get_checkpoint_manager(checkpoint_dir=checkpoint_dir), + ) self.partition_manager.shapes_dtypes = self.state.get_shapes_dtypes_checkpoint() - def save_state( - self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1 - ) -> None: + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: """ Save the state of the posterior distribution to a checkpoint directory. diff --git a/fortuna/prob_model/posterior/map/map_posterior.py b/fortuna/prob_model/posterior/map/map_posterior.py index 1f470a8e..2c02e76b 100755 --- a/fortuna/prob_model/posterior/map/map_posterior.py +++ b/fortuna/prob_model/posterior/map/map_posterior.py @@ -1,9 +1,11 @@ import logging +from pathlib import Path from typing import Optional from jax._src.prng import PRNGKeyArray from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.joint.base import Joint from fortuna.prob_model.joint.state import JointState @@ -16,15 +18,13 @@ MAPTrainer, MultiDeviceMAPTrainer, ) -from fortuna.utils.checkpoint import get_checkpoint_manager -from pathlib import Path from fortuna.prob_model.posterior.posterior_state_repository import ( PosteriorStateRepository, ) from fortuna.typing import Status from fortuna.utils.builtins import get_dynamic_scale_instance_from_model_dtype +from fortuna.utils.checkpoint import get_checkpoint_manager from fortuna.utils.device import select_trainer_given_devices -from fortuna.partitioner.partition_manager.base import PartitionManager logger = logging.getLogger(__name__) @@ -34,7 +34,7 @@ def __init__( self, joint: Joint, posterior_approximator: MAPPosteriorApproximator, - partition_manager: PartitionManager + partition_manager: PartitionManager, ): """ Maximum-a-Posteriori (MAP) approximate posterior class. @@ -48,7 +48,11 @@ def __init__( partition_manager: PartitionManager An object to manage partitions. 
""" - super().__init__(joint=joint, posterior_approximator=posterior_approximator, partition_manager=partition_manager) + super().__init__( + joint=joint, + posterior_approximator=posterior_approximator, + partition_manager=partition_manager, + ) def __str__(self): return MAP_NAME @@ -74,7 +78,10 @@ def fit( trainer = trainer_cls( predict_fn=self.joint.likelihood.prob_output_layer.predict, partition_manager=self.partition_manager, - checkpoint_manager=get_checkpoint_manager(fit_config.checkpointer.save_checkpoint_dir, keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints), + checkpoint_manager=get_checkpoint_manager( + fit_config.checkpointer.save_checkpoint_dir, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ), save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, save_every_n_steps=fit_config.checkpointer.save_every_n_steps, keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, @@ -86,10 +93,17 @@ def fit( freeze_fun=fit_config.optimizer.freeze_fun, ) - checkpoint_restorer = get_checkpoint_manager( - str(Path(fit_config.checkpointer.restore_checkpoint_dir) / fit_config.checkpointer.checkpoint_type), - keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, - ) if fit_config.checkpointer.restore_checkpoint_dir is not None else None + checkpoint_restorer = ( + get_checkpoint_manager( + str( + Path(fit_config.checkpointer.restore_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.restore_checkpoint_dir is not None + else None + ) def init_state_fn(rng): if self._is_state_available_somewhere(fit_config): @@ -130,15 +144,18 @@ def init_state_fn(rng): self.state = PosteriorStateRepository( partition_manager=self.partition_manager, checkpoint_manager=get_checkpoint_manager( - checkpoint_dir=str(Path(fit_config.checkpointer.save_checkpoint_dir) / fit_config.checkpointer.checkpoint_type), - keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints - ) if fit_config.checkpointer.save_checkpoint_dir is not None and fit_config.checkpointer.dump_state else None + checkpoint_dir=str( + Path(fit_config.checkpointer.save_checkpoint_dir) + / fit_config.checkpointer.checkpoint_type + ), + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + ) + if fit_config.checkpointer.save_checkpoint_dir is not None + and fit_config.checkpointer.dump_state + else None, ) if self.state.checkpoint_manager is None: - self.state.put( - state, - keep=fit_config.checkpointer.keep_top_n_checkpoints - ) + self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints) logging.info("Fit completed.") return status @@ -151,7 +168,12 @@ def sample(self, rng: Optional[PRNGKeyArray] = None, **kwargs) -> JointState: calib_mutable=state.calib_mutable, ) - def _init_state(self, data_loader: DataLoader, fit_config: FitConfig, rng: Optional[PRNGKeyArray] = None) -> MAPState: + def _init_state( + self, + data_loader: DataLoader, + fit_config: FitConfig, + rng: Optional[PRNGKeyArray] = None, + ) -> MAPState: state = super()._init_joint_state(data_loader=data_loader, rng=rng) return MAPState.init( diff --git a/fortuna/prob_model/posterior/posterior_mixin.py b/fortuna/prob_model/posterior/posterior_mixin.py index f6d10f66..d4ca6c22 100644 --- a/fortuna/prob_model/posterior/posterior_mixin.py +++ b/fortuna/prob_model/posterior/posterior_mixin.py @@ -1,12 +1,11 @@ from 
fortuna.prob_model.posterior.name_to_posterior_state import NameToPosteriorState -from fortuna.training.name_to_train_state import NameToTrainState from fortuna.training.mixins.checkpointing import WithCheckpointingMixin +from fortuna.training.name_to_train_state import NameToTrainState class WithPosteriorCheckpointingMixin(WithCheckpointingMixin): def get_shapes_dtypes_checkpoint( - self, - name_to_train_state: NameToTrainState = NameToPosteriorState + self, name_to_train_state: NameToTrainState = NameToPosteriorState ): return super().get_shapes_dtypes_checkpoint( name_to_train_state=name_to_train_state diff --git a/fortuna/prob_model/posterior/posterior_state_repository.py b/fortuna/prob_model/posterior/posterior_state_repository.py index 053d9c00..6496c3a9 100644 --- a/fortuna/prob_model/posterior/posterior_state_repository.py +++ b/fortuna/prob_model/posterior/posterior_state_repository.py @@ -13,6 +13,4 @@ def extract_calib_keys( self, checkpoint_dir: Optional[Path] = None, ) -> Dict: - return super().extract( - ["calib_params", "calib_mutable"], checkpoint_dir - ) + return super().extract(["calib_params", "calib_mutable"], checkpoint_dir) diff --git a/fortuna/prob_model/predictive/base.py b/fortuna/prob_model/predictive/base.py index 2a3d1526..c22be74b 100755 --- a/fortuna/prob_model/predictive/base.py +++ b/fortuna/prob_model/predictive/base.py @@ -15,11 +15,11 @@ pmap, random, ) -from jax.experimental.pjit import pjit from jax._src.prng import PRNGKeyArray -from jax.sharding import PartitionSpec +from jax.experimental.pjit import pjit import jax.numpy as jnp import jax.scipy as jsp +from jax.sharding import PartitionSpec from jax.tree_util import tree_map from fortuna.data.loader import ( diff --git a/fortuna/prob_model/regression.py b/fortuna/prob_model/regression.py index b6b7670d..e121ea4c 100755 --- a/fortuna/prob_model/regression.py +++ b/fortuna/prob_model/regression.py @@ -12,6 +12,8 @@ from fortuna.model_editor.base import ModelEditor from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager from fortuna.output_calibrator.regression import RegressionTemperatureScaler +from fortuna.partitioner.base import Partitioner +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.base import ProbModel from fortuna.prob_model.calib_config.base import CalibConfig from fortuna.prob_model.fit_config.base import FitConfig @@ -20,8 +22,6 @@ from fortuna.prob_model.posterior.posterior_approximations import ( PosteriorApproximations, ) -from fortuna.partitioner.base import Partitioner -from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.prob_model.posterior.swag.swag_approximator import ( SWAGPosteriorApproximator, ) @@ -127,7 +127,11 @@ def __init__( self.partition_manager = PartitionManager(partitioner) self.posterior = getattr( PosteriorApproximations, posterior_approximator.__str__() - ).value(joint=self.joint, posterior_approximator=posterior_approximator, partition_manager=self.partition_manager) + ).value( + joint=self.joint, + posterior_approximator=posterior_approximator, + partition_manager=self.partition_manager, + ) self.predictive = RegressionPredictive(self.posterior) super().__init__(seed=seed) diff --git a/fortuna/training/mixins/checkpointing.py b/fortuna/training/mixins/checkpointing.py index ff0bdfd9..75e261f2 100644 --- a/fortuna/training/mixins/checkpointing.py +++ b/fortuna/training/mixins/checkpointing.py @@ -1,22 +1,29 @@ import logging -from typing import ( - 
Optional, -) +from typing import Optional from flax.training.orbax_utils import save_args_from_target +from jax import ( + ShapeDtypeStruct, + pure_callback, +) +from jax.tree_util import ( + tree_map, + tree_map_with_path, +) +from orbax.checkpoint import ( + ArrayRestoreArgs, + CheckpointManager, +) +from fortuna.partitioner.partition_manager.base import PartitionManager +from fortuna.training.name_to_train_state import NameToTrainState from fortuna.training.train_state import TrainState from fortuna.typing import ( OptaxOptimizer, Path, ) -from jax.tree_util import tree_map_with_path, tree_map -from fortuna.training.name_to_train_state import NameToTrainState -from jax import pure_callback -from orbax.checkpoint import CheckpointManager, ArrayRestoreArgs -from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.utils.checkpoint import get_checkpoint_manager -from jax import ShapeDtypeStruct + logger = logging.getLogger(__name__) @@ -50,7 +57,13 @@ def save_checkpoint( keep: int = 1, force_save: bool = False, ) -> None: - checkpoint_manager = get_checkpoint_manager(checkpoint_dir=save_checkpoint_dir, keep_top_n_checkpoints=keep) if save_checkpoint_dir is not None else self.checkpoint_manager + checkpoint_manager = ( + get_checkpoint_manager( + checkpoint_dir=save_checkpoint_dir, keep_top_n_checkpoints=keep + ) + if save_checkpoint_dir is not None + else self.checkpoint_manager + ) if self.checkpoint_manager: save_args = save_args_from_target(state) @@ -59,8 +72,9 @@ def save_ckpt_fn(_state): _state.step, _state, force=force_save, - save_kwargs={'save_args': save_args} + save_kwargs={"save_args": save_args}, ) + if ( hasattr(state, "grad_accumulated") and state.grad_accumulated is not None @@ -74,7 +88,10 @@ def restore_checkpoint( restore_checkpoint_dir: Path, optimizer: Optional[OptaxOptimizer] = None, ) -> TrainState: - if self.partition_manager.shardings is not None and self.partition_manager.shapes_dtypes is not None: + if ( + self.partition_manager.shardings is not None + and self.partition_manager.shapes_dtypes is not None + ): ref = self._get_ref_from_shardings() else: ref = self._get_ref_without_shardings() @@ -83,29 +100,27 @@ def restore_checkpoint( lambda: self.checkpoint_manager.restore( self.checkpoint_manager.latest_step(), items=ref, - restore_kwargs={'restore_args': ref}, - directory=restore_checkpoint_dir + restore_kwargs={"restore_args": ref}, + directory=restore_checkpoint_dir, ), - ref + ref, ) if optimizer is not None: restored = restored.replace( - tx=optimizer, - opt_state=optimizer.init(restored.params) + tx=optimizer, opt_state=optimizer.init(restored.params) ) return restored def get_shapes_dtypes_checkpoint( - self, - name_to_train_state: NameToTrainState = NameToTrainState + self, name_to_train_state: NameToTrainState = NameToTrainState ): ref = self._get_ref_without_shardings() state = self.checkpoint_manager.restore( self.checkpoint_manager.latest_step(), items=ref, - restore_kwargs=dict(restore_args=ref) + restore_kwargs=dict(restore_args=ref), ) name = "".join([chr(n) for n in state["encoded_name"].get().tolist()]) state = name_to_train_state[name].value.init_from_dict(state) @@ -114,15 +129,16 @@ def get_shapes_dtypes_checkpoint( def _get_ref_from_shardings(self): return tree_map_with_path( lambda p, sharding, shape_dtype: ArrayRestoreArgsWithShape( - sharding=sharding, - dtype=shape_dtype.dtype, - shape=shape_dtype.shape + sharding=sharding, dtype=shape_dtype.dtype, shape=shape_dtype.shape ), - 
self.partition_manager.shardings, self.partition_manager.shapes_dtypes + self.partition_manager.shardings, + self.partition_manager.shapes_dtypes, ) def _get_ref_without_shardings(self): - return tree_map(lambda v: ArrayRestoreArgs(lazy=True), self.checkpoint_manager.structure()) + return tree_map( + lambda v: ArrayRestoreArgs(lazy=True), self.checkpoint_manager.structure() + ) class ArrayRestoreArgsWithShape(ArrayRestoreArgs): diff --git a/fortuna/training/output_calibrator.py b/fortuna/training/output_calibrator.py index f8176021..8a598528 100644 --- a/fortuna/training/output_calibrator.py +++ b/fortuna/training/output_calibrator.py @@ -25,12 +25,13 @@ from jax.tree_util import tree_map from tqdm import trange from tqdm.std import tqdm as TqdmDecorator -from fortuna.partitioner.partition_manager.base import PartitionManager + from fortuna.data.loader import ( DataLoader, TargetsLoader, ) from fortuna.output_calib_model.state import OutputCalibState +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin from fortuna.training.mixins.input_validator import InputValidatorMixin @@ -67,7 +68,9 @@ def __init__( eval_every_n_epochs: int = 1, **kwargs, ): - super(OutputCalibratorABC, self).__init__(*args, partition_manager=partition_manager, **kwargs) + super(OutputCalibratorABC, self).__init__( + *args, partition_manager=partition_manager, **kwargs + ) self._calib_outputs_loader = calib_outputs_loader self._val_outputs_loader = val_outputs_loader self.predict_fn = predict_fn diff --git a/fortuna/training/train_state_repository.py b/fortuna/training/train_state_repository.py index 1b7521d4..5bf27737 100644 --- a/fortuna/training/train_state_repository.py +++ b/fortuna/training/train_state_repository.py @@ -1,22 +1,28 @@ +from shutil import rmtree from typing import ( Dict, List, Optional, Union, ) + +from orbax.checkpoint import CheckpointManager + +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.training.train_state import TrainState -from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.typing import ( OptaxOptimizer, Path, ) -from orbax.checkpoint import CheckpointManager -from shutil import rmtree class TrainStateRepository(WithCheckpointingMixin): - def __init__(self, partition_manager: PartitionManager, checkpoint_manager: Optional[CheckpointManager] = None): + def __init__( + self, + partition_manager: PartitionManager, + checkpoint_manager: Optional[CheckpointManager] = None, + ): super().__init__(partition_manager=partition_manager) self.checkpoint_manager = checkpoint_manager self._state = None @@ -30,14 +36,11 @@ def get( raise ValueError("No state available.") if checkpoint_dir or self.checkpoint_manager: return self.restore_checkpoint( - restore_checkpoint_dir=checkpoint_dir, - optimizer=optimizer + restore_checkpoint_dir=checkpoint_dir, optimizer=optimizer ) if optimizer is not None: state = self.partition_manager.reshard(self._state) - state = state.replace( - tx=optimizer, opt_state=optimizer.init(state.params) - ) + state = state.replace(tx=optimizer, opt_state=optimizer.init(state.params)) return state return self._state diff --git a/fortuna/training/trainer.py b/fortuna/training/trainer.py index 7a1a58d9..7010edf5 100755 --- a/fortuna/training/trainer.py +++ 
b/fortuna/training/trainer.py @@ -2,6 +2,7 @@ import collections from functools import partial import logging +from pathlib import Path as _Path from typing import ( Any, Callable, @@ -25,10 +26,12 @@ import jax.numpy as jnp from jax.tree_util import tree_map from optax._src.base import PyTree +from orbax.checkpoint import CheckpointManager from tqdm import trange from tqdm.std import tqdm as TqdmDecorator -from orbax.checkpoint import CheckpointManager + from fortuna.data.loader import DataLoader +from fortuna.partitioner.partition_manager.base import PartitionManager from fortuna.training.callback import Callback from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.training.mixins.early_stopping import WithEarlyStoppingMixin @@ -45,7 +48,6 @@ Path, Status, ) -from pathlib import Path as _Path from fortuna.utils.builtins import HashableMixin from fortuna.utils.freeze import ( get_frozen_paths, @@ -57,7 +59,6 @@ nested_update, ) from fortuna.utils.training import clip_grandients_by_norm -from fortuna.partitioner.partition_manager.base import PartitionManager class TrainerABC( @@ -82,7 +83,12 @@ def __init__( freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None, **kwargs, ): - super(TrainerABC, self).__init__(*args, partition_manager=partition_manager, checkpoint_manager=checkpoint_manager, **kwargs) + super(TrainerABC, self).__init__( + *args, + partition_manager=partition_manager, + checkpoint_manager=checkpoint_manager, + **kwargs, + ) self.predict_fn = predict_fn self.uncertainty_fn = uncertainty_fn self.save_checkpoint_dir = save_checkpoint_dir @@ -355,7 +361,7 @@ def validation_epoch_end( state, str(_Path(self.save_checkpoint_dir) / "best"), force_save=True, - prefix="" + prefix="", ) return validation_losses_and_metrics_current_epoch @@ -651,10 +657,12 @@ def on_train_start( def on_train_end(self, state: TrainState) -> TrainState: self.save_checkpoint( state, - save_checkpoint_dir=str(_Path(self.save_checkpoint_dir) / "last") if self.save_checkpoint_dir is not None else None, + save_checkpoint_dir=str(_Path(self.save_checkpoint_dir) / "last") + if self.save_checkpoint_dir is not None + else None, keep=self.keep_top_n_checkpoints, force_save=True, - prefix="" + prefix="", ) if self.freeze_fun is not None: diff --git a/fortuna/utils/checkpoint.py b/fortuna/utils/checkpoint.py index c6db13b6..f8ad6214 100644 --- a/fortuna/utils/checkpoint.py +++ b/fortuna/utils/checkpoint.py @@ -1,13 +1,21 @@ -from orbax.checkpoint import CheckpointManager, PyTreeCheckpointHandler, Checkpointer, CheckpointManagerOptions from typing import Optional +from orbax.checkpoint import ( + Checkpointer, + CheckpointManager, + CheckpointManagerOptions, + PyTreeCheckpointHandler, +) -def get_checkpoint_manager(checkpoint_dir: str, keep_top_n_checkpoints: Optional[int] = None): + +def get_checkpoint_manager( + checkpoint_dir: str, keep_top_n_checkpoints: Optional[int] = None +): if checkpoint_dir is not None: - options = CheckpointManagerOptions(create=True, max_to_keep=keep_top_n_checkpoints) + options = CheckpointManagerOptions( + create=True, max_to_keep=keep_top_n_checkpoints + ) return CheckpointManager( - checkpoint_dir, - Checkpointer(PyTreeCheckpointHandler()), - options + checkpoint_dir, Checkpointer(PyTreeCheckpointHandler()), options ) return None diff --git a/fortuna/utils/mesh.py b/fortuna/utils/mesh.py index d498c1a0..905e6697 100644 --- a/fortuna/utils/mesh.py +++ b/fortuna/utils/mesh.py @@ -1,10 +1,15 @@ -from jax.sharding import PartitionSpec, 
Mesh +from typing import Dict + from jax import device_count from jax.experimental.mesh_utils import create_device_mesh -from jax.lax import with_sharding_constraint from jax.interpreters import pxla +from jax.lax import with_sharding_constraint +from jax.sharding import ( + Mesh, + PartitionSpec, +) import numpy as np -from typing import Dict + from fortuna.utils.partition import get_names_from_partition_spec @@ -14,7 +19,9 @@ def get_mesh(axis_dims: Dict[str, int]): allowed_keys = ("dp", "fsdp", "mp") if set(keys) != set(allowed_keys): - raise ValueError(f"`axis_dims` must contain exactly the following keys: {allowed_keys}.") + raise ValueError( + f"`axis_dims` must contain exactly the following keys: {allowed_keys}." + ) for v in dims: if type(v) != int: raise ValueError("All values in `axis_dims` must be integers or `-1`.") @@ -26,12 +33,16 @@ def get_mesh(axis_dims: Dict[str, int]): fixed_prod = np.prod([v for v in dims if v != -1]) reminder = n_devices % fixed_prod if fixed_prod > n_devices: - raise ValueError(f"The product of the specified axis dimensions cannot be greater than {n_devices}, " - f"the number of available devices.") + raise ValueError( + f"The product of the specified axis dimensions cannot be greater than {n_devices}, " + f"the number of available devices." + ) if reminder != 0: - raise ValueError("The product of the axis dimensions must divide the number of available devices. " - f"However, {n_devices} were found, and {fixed_prod} to be the product of the specified axis " - f"dimensions.") + raise ValueError( + "The product of the axis dimensions must divide the number of available devices. " + f"However, {n_devices} were found, and {fixed_prod} to be the product of the specified axis " + f"dimensions." + ) dims = tuple([dims[np.where(np.array(keys) == k)[0][0]] for k in allowed_keys]) mesh_shape = np.arange(n_devices).reshape(dims).shape @@ -76,4 +87,3 @@ def with_conditional_sharding_constraint(x, partition_specs): if names_in_current_mesh(*axis_names): x = with_sharding_constraint(x, partition_specs) return x - diff --git a/fortuna/utils/nested_dicts.py b/fortuna/utils/nested_dicts.py index d2e41ee4..754255ce 100644 --- a/fortuna/utils/nested_dicts.py +++ b/fortuna/utils/nested_dicts.py @@ -6,9 +6,15 @@ Tuple, Union, ) -from jax.tree_util import tree_map_with_path, SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey from flax.core import FrozenDict +from jax.tree_util import ( + DictKey, + FlattenedIndexKey, + GetAttrKey, + SequenceKey, + tree_map_with_path, +) from fortuna.typing import AnyKey @@ -216,7 +222,10 @@ def nested_update( return updated_mapping -def path_to_string(path: Tuple[Union[DictKey, SequenceKey, GetAttrKey, FlattenedIndexKey, AnyKey]], separator: str = None) -> Union[str, Tuple[str]]: +def path_to_string( + path: Tuple[Union[DictKey, SequenceKey, GetAttrKey, FlattenedIndexKey, AnyKey]], + separator: str = None, +) -> Union[str, Tuple[str]]: """ Transform a sequence of keys into a string. 
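Note for reviewers, on the `get_mesh` hunk in fortuna/utils/mesh.py above: `axis_dims` must supply exactly the keys "dp", "fsdp" and "mp", at most one of which may be -1, and the product of the fixed entries has to divide the number of available devices. The snippet below is a minimal stand-alone sketch of that resolution logic, written only to illustrate the checks being reformatted here; `resolve_mesh_shape` is a hypothetical helper, not part of this patch, and it stops short of constructing the actual jax.sharding.Mesh.

from typing import Dict, Tuple

import numpy as np


def resolve_mesh_shape(axis_dims: Dict[str, int], n_devices: int) -> Tuple[int, ...]:
    # Hypothetical helper for illustration only; mirrors the validation in the
    # get_mesh hunk above but does not build a jax.sharding.Mesh.
    allowed_keys = ("dp", "fsdp", "mp")
    if set(axis_dims) != set(allowed_keys):
        raise ValueError(f"`axis_dims` must contain exactly the keys {allowed_keys}.")
    dims = [axis_dims[k] for k in allowed_keys]  # fixed order: (dp, fsdp, mp)
    fixed_prod = int(np.prod([v for v in dims if v != -1]))
    if fixed_prod > n_devices or n_devices % fixed_prod != 0:
        raise ValueError(
            f"The product of the specified axis dimensions ({fixed_prod}) must "
            f"divide the number of available devices ({n_devices})."
        )
    # A single -1 entry is resolved by numpy's reshape, as in the diff above.
    return np.arange(n_devices).reshape(dims).shape


# On 8 devices, {"dp": 1, "fsdp": -1, "mp": 2} resolves to the mesh shape (1, 4, 2).
print(resolve_mesh_shape({"dp": 1, "fsdp": -1, "mp": 2}, n_devices=8))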
diff --git a/fortuna/utils/partition.py b/fortuna/utils/partition.py index a7ab30f8..a181b599 100644 --- a/fortuna/utils/partition.py +++ b/fortuna/utils/partition.py @@ -1,23 +1,35 @@ import re -from jax.sharding import PartitionSpec -import numpy as np -from typing import Dict, Tuple -from fortuna.utils.nested_dicts import path_to_string +from typing import ( + Dict, + Tuple, +) + +import jax.numpy as jnp +from jax.sharding import ( + PartitionSpec, + Sharding, +) from jax.tree_util import tree_map_with_path +import numpy as np from optax._src.base import PyTree -from jax.sharding import Sharding -import jax.numpy as jnp + +from fortuna.utils.nested_dicts import path_to_string def named_tree_map(f, tree, *rest, is_leaf=None, separator=None): return tree_map_with_path( - lambda string_path, x, *r: f(path_to_string(string_path, separator=separator), x, *r), - tree, *rest, - is_leaf=is_leaf + lambda string_path, x, *r: f( + path_to_string(string_path, separator=separator), x, *r + ), + tree, + *rest, + is_leaf=is_leaf, ) -def match_partition_specs(partition_specs: Dict[str, PartitionSpec], tree: PyTree) -> PyTree: +def match_partition_specs( + partition_specs: Dict[str, PartitionSpec], tree: PyTree +) -> PyTree: """ Match partition specifics to a tree structure. @@ -31,6 +43,7 @@ def match_partition_specs(partition_specs: Dict[str, PartitionSpec], tree: PyTre PyTree A tree of partition specifics. """ + def get_partition_spec(path, shape_leaf): if len(shape_leaf.shape) == 0 or np.prod(shape_leaf.shape) == 1: # do not partition scalar values @@ -40,11 +53,12 @@ def get_partition_spec(path, shape_leaf): return ps # raise ValueError(f"A partition rule for the following path was not found: `{path}`") return PartitionSpec() - return named_tree_map(get_partition_spec, tree, separator='/') + + return named_tree_map(get_partition_spec, tree, separator="/") def get_names_from_partition_spec(partition_specs): - """ Return axis names from partition specs. """ + """Return axis names from partition specs.""" names = set() if isinstance(partition_specs, dict): partition_specs = partition_specs.values() diff --git a/fortuna/utils/port.py b/fortuna/utils/port.py index 8849e948..1bd99bef 100644 --- a/fortuna/utils/port.py +++ b/fortuna/utils/port.py @@ -3,4 +3,4 @@ def is_port_in_use(port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 + return s.connect_ex(("localhost", port)) == 0
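A short illustration of the behaviour behind the `match_partition_specs` reformatting in fortuna/utils/partition.py above: partition rules are a dict mapping regular-expression patterns to PartitionSpec objects, each leaf's key path is joined with "/" into a string, the first rule whose pattern matches that string wins, and scalar or unmatched leaves fall back to an empty PartitionSpec (full replication). The sketch below re-implements that matching under those assumptions, for illustration only; the names `match`, `rules` and `toy_shapes` are made up, and the real code goes through `named_tree_map` and `path_to_string`.

import re

import jax.numpy as jnp
import numpy as np
from jax import ShapeDtypeStruct
from jax.sharding import PartitionSpec
from jax.tree_util import tree_map_with_path


def match(rules, shape_tree):
    # Hypothetical re-implementation for illustration; the patched code does this
    # via named_tree_map + path_to_string with separator="/".
    def pick(path, leaf):
        path_str = "/".join(str(getattr(k, "key", k)) for k in path)
        if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
            return PartitionSpec()  # never shard scalar leaves
        for pattern, spec in rules.items():
            if re.search(pattern, path_str) is not None:
                return spec  # first matching rule wins
        return PartitionSpec()  # unmatched leaves stay replicated

    return tree_map_with_path(pick, shape_tree)


# Made-up parameter shapes and a single made-up rule, for illustration only.
toy_shapes = {
    "dense": {
        "kernel": ShapeDtypeStruct((1024, 4096), jnp.float32),
        "bias": ShapeDtypeStruct((4096,), jnp.float32),
    },
    "scale": ShapeDtypeStruct((), jnp.float32),
}
rules = {r"dense/kernel": PartitionSpec("fsdp", "mp")}
print(match(rules, toy_shapes))
# -> dense/kernel gets PartitionSpec('fsdp', 'mp'); bias and scale get PartitionSpec().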