diff --git a/fortuna/utils/mesh.py b/fortuna/utils/mesh.py
new file mode 100644
index 00000000..bb845669
--- /dev/null
+++ b/fortuna/utils/mesh.py
@@ -0,0 +1,356 @@
+from functools import partial
+import re
+import random
+from ml_collections import ConfigDict
+from ml_collections.config_dict.config_dict import placeholder
+
+import flax
+import jax
+import jax.flatten_util  # used by global_norm below
+import jax.numpy as jnp
+from jax.sharding import PartitionSpec as PS
+from jax.sharding import Mesh
+from jax.experimental import mesh_utils
+from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
+from jax.experimental.pjit import pjit
+from jax.interpreters import pxla
+import numpy as np
+
+
+class DistributedConfig(object):
+    """ Utility class for initializing JAX distributed. """
+
+    @staticmethod
+    def get_default_config(updates=None):
+        config = ConfigDict()
+        config.initialize_jax_distributed = False
+        config.coordinator_address = placeholder(str)
+        config.num_processes = placeholder(int)
+        config.process_id = placeholder(int)
+        config.local_device_ids = placeholder(str)
+
+        if updates is not None:
+            config.update(ConfigDict(updates).copy_and_resolve_references())
+        return config
+
+    @classmethod
+    def initialize(cls, config):
+        config = cls.get_default_config(config)
+        if config.initialize_jax_distributed:
+            if config.local_device_ids is not None:
+                local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
+            else:
+                local_device_ids = None
+
+            jax.distributed.initialize(
+                coordinator_address=config.coordinator_address,
+                num_processes=config.num_processes,
+                process_id=config.process_id,
+                local_device_ids=local_device_ids,
+            )
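For context, a minimal sketch (not part of the patch) of how DistributedConfig might be driven from a plain dict; the coordinator address, process count, and process id are placeholder values.

from fortuna.utils.mesh import DistributedConfig

# Fields left unset fall back to the defaults from get_default_config().
DistributedConfig.initialize(dict(
    initialize_jax_distributed=True,
    coordinator_address='10.0.0.1:1234',  # placeholder address
    num_processes=2,
    process_id=0,
))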
+
+
+def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
+    """ Create a pytree of sharding and gathering functions from a pytree of
+    partition specs.
+    """
+    float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
+
+    def make_to_dtype_fn(dtype_spec):
+        def to_dtype(tensor):
+            if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
+                # A single target dtype was given for the whole tree: cast every
+                # float tensor to it.
+                return tensor.astype(dtype_specs)
+            elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
+                # Per-leaf dtype spec: cast the tensor to the dtype of its spec.
+                return tensor.astype(dtype_spec.dtype)
+            return tensor
+        return to_dtype
+
+    def make_shard_fn(partition_spec, dtype_spec=None):
+        jax_shard_function = pjit(
+            make_to_dtype_fn(dtype_spec),
+            in_shardings=None,
+            out_shardings=partition_spec
+        )
+        def shard_fn(tensor):
+            return jax_shard_function(tensor).block_until_ready()
+        return shard_fn
+
+    def make_gather_fn(partition_spec, dtype_spec=None):
+        jax_gather_fn = pjit(
+            make_to_dtype_fn(dtype_spec),
+            in_shardings=partition_spec,
+            out_shardings=None
+        )
+        def gather_fn(tensor):
+            return jax.device_get(jax_gather_fn(tensor))
+        return gather_fn
+
+    if dtype_specs is None or dtype_specs in float_dtypes:
+        shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
+        gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
+    else:
+        shard_fns = jax.tree_util.tree_map(
+            make_shard_fn, partition_specs, dtype_specs
+        )
+        gather_fns = jax.tree_util.tree_map(
+            make_gather_fn, partition_specs, dtype_specs
+        )
+    return shard_fns, gather_fns
+
+
+def get_jax_mesh(axis_dims, names):
+    """ Build a device Mesh from an axis-dimension string, either positional
+    (e.g. '1,-1,1') or named (e.g. 'dp:1,fsdp:-1,mp:1'). A leading '!' allows
+    splitting physical mesh axes.
+    """
+    if axis_dims.startswith('!'):
+        # Allow splitting a physical mesh axis if needed
+        mesh_axis_splitting = True
+        axis_dims = axis_dims[1:]
+    else:
+        mesh_axis_splitting = False
+
+    if ':' in axis_dims:
+        dims = []
+        dim_names = []
+        for axis in axis_dims.split(','):
+            name, dim = axis.split(':')
+            assert name in names
+            dims.append(int(dim))
+            dim_names.append(name)
+        assert set(dim_names) == set(names)
+    else:
+        dims = [int(x) for x in axis_dims.split(',')]
+        dim_names = names
+        assert len(dims) == len(names)
+    mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
+    if mesh_axis_splitting:
+        physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
+    else:
+        physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
+    return Mesh(physical_mesh, dim_names)
+
+
+def names_in_current_mesh(*names):
+    """ Check if the current mesh axes contain these names. """
+    mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
+    return set(names) <= set(mesh_axis_names)
+
+
+def get_names_from_partition_spec(partition_specs):
+    """ Return axis names used in the given partition specs. """
+    names = set()
+    if isinstance(partition_specs, dict):
+        partition_specs = partition_specs.values()
+    for item in partition_specs:
+        if item is None:
+            continue
+        elif isinstance(item, str):
+            names.add(item)
+        else:
+            names.update(get_names_from_partition_spec(item))
+
+    return list(names)
+
+
+def with_sharding_constraint(x, partition_specs):
+    """ A smarter version of with_sharding_constraint that only applies the
+    constraint if the current mesh contains the axes in the partition specs.
+    """
+    axis_names = get_names_from_partition_spec(partition_specs)
+    if names_in_current_mesh(*axis_names):
+        x = _with_sharding_constraint(x, partition_specs)
+    return x
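A brief sketch of the two axis-string formats accepted by get_jax_mesh; the axis names and sizes below are illustrative and assume they fit the devices visible to the process.

from fortuna.utils.mesh import get_jax_mesh, names_in_current_mesh

# Positional form: one size per axis name, in order; -1 absorbs the remaining devices.
mesh = get_jax_mesh('1,-1,1', ('dp', 'fsdp', 'mp'))

# Named form: order-independent 'name:size' pairs describing the same mesh.
mesh = get_jax_mesh('dp:1,fsdp:-1,mp:1', ('dp', 'fsdp', 'mp'))

with mesh:
    # with_sharding_constraint above becomes a no-op whenever this check fails.
    assert names_in_current_mesh('dp', 'fsdp', 'mp')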
""" + def wrap_function(function): + def wrapped(*args, **kwargs): + nonlocal rng + rng, split_rng = jax.random.split(rng) + return function(split_rng, *args, **kwargs) + return wrapped + return wrap_function + + +def init_rng(seed): + global jax_utils_rng + jax_utils_rng = JaxRNG.from_seed(seed) + + +def next_rng(*args, **kwargs): + global jax_utils_rng + return jax_utils_rng(*args, **kwargs) + + +def get_metrics(metrics, unreplicate=False, stack=False): + if unreplicate: + metrics = flax.jax_utils.unreplicate(metrics) + metrics = jax.device_get(metrics) + if stack: + return jax.tree_map(lambda *args: np.stack(args), *metrics) + else: + return {key: float(val) for key, val in metrics.items()} + + +def mse_loss(val, target, valid=None): + if valid is None: + valid = jnp.ones((*target.shape[:2], 1)) + valid = valid.astype(jnp.float32) + loss = jnp.mean( + jnp.where( + valid > 0.0, + jnp.square(val - target), + 0.0 + ) + ) + return loss + + +def cross_entropy_loss_and_accuracy(logits, tokens, valid=None): + if valid is None: + valid = jnp.ones(tokens.shape[:2]) + valid = valid.astype(jnp.float32) + valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10) + logits = logits.astype(jnp.float32) # for numerical stability + token_log_prob = jnp.squeeze( + jnp.take_along_axis( + jax.nn.log_softmax(logits, axis=-1), + jnp.expand_dims(tokens, -1), + axis=-1, + ), + -1, + ) + token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0)) + loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length) + correct = jnp.where( + valid > 0.0, + jnp.argmax(logits, axis=-1) == tokens, + jnp.array(False) + ) + accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length) + return loss, accuracy + + +def global_norm(tree): + """ Return the global L2 norm of a pytree. 
""" + squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree) + flattened, _ = jax.flatten_util.ravel_pytree(squared) + return jnp.sqrt(jnp.sum(flattened)) + + +def average_metrics(metrics): + return jax.tree_map( + lambda *args: jnp.mean(jnp.stack(args)), + *metrics + ) + + +def get_float_dtype_by_name(dtype): + return { + 'bf16': jnp.bfloat16, + 'bfloat16': jnp.bfloat16, + 'fp16': jnp.float16, + 'float16': jnp.float16, + 'fp32': jnp.float32, + 'float32': jnp.float32, + 'fp64': jnp.float64, + 'float64': jnp.float64, + }[dtype] + + +def float_tensor_to_dtype(tensor, dtype): + if dtype is None or dtype == '': + return tensor + if isinstance(dtype, str): + dtype = get_float_dtype_by_name(dtype) + float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) + if getattr(tensor, 'dtype', None) in float_dtypes: + tensor = tensor.astype(dtype) + return tensor + + +def float_to_dtype(tree, dtype): + return jax.tree_util.tree_map( + partial(float_tensor_to_dtype, dtype=dtype), tree + ) + + +def get_gradient_checkpoint_policy(name): + return { + 'everything_saveable': jax.checkpoint_policies.everything_saveable, + 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, + 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, + 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + }[name] + + +def tree_path_to_string(path, sep=None): + keys = [] + for key in path: + if isinstance(key, jax.tree_util.SequenceKey): + keys.append(str(key.idx)) + elif isinstance(key, jax.tree_util.DictKey): + keys.append(str(key.key)) + elif isinstance(key, jax.tree_util.GetAttrKey): + keys.append(str(key.name)) + elif isinstance(key, jax.tree_util.FlattenedIndexKey): + keys.append(str(key.key)) + else: + keys.append(str(key)) + if sep is None: + return tuple(keys) + return sep.join(keys) + + +def flatten_tree(xs, is_leaf=None, sep=None): + flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf) + output = {} + for key, val in flattened: + output[tree_path_to_string(key, sep=sep)] = val + return output + + +def named_tree_map(f, tree, *rest, is_leaf=None, sep=None): + """ An extended version of jax.tree_util.tree_map, where the mapped function + f takes both the name (path) and the tree leaf as input. + """ + return jax.tree_util.tree_map_with_path( + lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r), + tree, *rest, + is_leaf=is_leaf + ) + + +def match_partition_rules(rules, params): + """ Returns a pytree of PartitionSpec according to rules. Supports handling + Flax TrainState and Optax optimizer state. + """ + def get_partition_spec(name, leaf): + if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1: + """ Don't partition scalar values. """ + return PS() + for rule, ps in rules: + if re.search(rule, name) is not None: + return ps + raise ValueError(f'Partition rule not found for param: {name}') + return named_tree_map(get_partition_spec, params, sep='/') + + +def get_weight_decay_mask(exclusions): + """ Return a weight decay mask function that computes the pytree masks + according to the given exclusion rules. + """ + def decay(name, _): + for rule in exclusions: + if re.search(rule, name) is not None: + return False + return True + + def weight_decay_mask(params): + return named_tree_map(decay, params, sep='/') + + return weight_decay_mask + + +def tree_apply(fns, tree): + """ Apply a pytree of functions to the pytree. """ + return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree) +