Skip to content

Commit

Permalink
mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Jun 25, 2023
1 parent 734f597 commit 404840e
Showing 1 changed file with 356 additions and 0 deletions.
356 changes: 356 additions & 0 deletions fortuna/utils/mesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,356 @@
from functools import partial
import re
import random
from ml_collections import ConfigDict
from ml_collections.config_dict.config_dict import placeholder

import flax
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
from jax.experimental.pjit import pjit
from jax.interpreters import pxla
import numpy as np


class DistributedConfig(object):
""" Utility class for initializing JAX distributed. """

@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.initialize_jax_distributed = False
config.coordinator_address = placeholder(str)
config.num_processes = placeholder(int)
config.process_id = placeholder(int)
config.local_device_ids = placeholder(str)

if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config

@classmethod
def initialize(cls, config):
config = cls.get_default_config(config)
if config.initialize_jax_distributed:
if config.local_device_ids is not None:
local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
else:
local_device_ids = None

jax.distributed.initialize(
coordinator_address=config.coordinator_address,
num_processes=config.num_processes,
process_id=config.process_id,
local_device_ids=local_device_ids,
)


def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
""" Create pytree of sharding and gathering functions from pytree of
partition specs.
"""
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)

def make_to_dtype_fn(dtype_spec):
def to_dtype(tensor):
if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
# Convert all float tensors to the same dtype
return tensor.astype(dtype_specs)
elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
return tensor.astype(dtype_spec.dtype)
return tensor
return to_dtype

def make_shard_fn(partition_spec, dtype_spec=None):
jax_shard_function = pjit(
make_to_dtype_fn(dtype_spec),
in_shardings=None,
out_shardings=partition_spec
)
def shard_fn(tensor):
return jax_shard_function(tensor).block_until_ready()
return shard_fn

def make_gather_fn(partition_spec, dtype_spec=None):
jax_gather_fn = pjit(
make_to_dtype_fn(dtype_spec),
in_shardings=partition_spec,
out_shardings=None
)
def gather_fn(tensor):
return jax.device_get(jax_gather_fn(tensor))
return gather_fn

if dtype_specs is None or dtype_specs in float_dtypes:
shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
else:
shard_fns = jax.tree_util.tree_map(
make_shard_fn, partition_specs, dtype_specs
)
gather_fns = jax.tree_util.tree_map(
make_gather_fn, partition_specs, dtype_specs
)
return shard_fns, gather_fns


def get_jax_mesh(axis_dims, names):
if axis_dims.startswith('!'):
# Allow splitting a physical mesh axis if needed
mesh_axis_splitting = True
axis_dims = axis_dims[1:]
else:
mesh_axis_splitting = False

if ':' in axis_dims:
dims = []
dim_names = []
for axis in axis_dims.split(','):
name, dim = axis.split(':')
assert name in names
dims.append(int(dim))
dim_names.append(name)
assert(set(dim_names) == set(names))
else:
dims = [int(x) for x in axis_dims.split(',')]
dim_names = names
assert len(dims) == len(names)
mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
if mesh_axis_splitting:
physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
else:
physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
return Mesh(physical_mesh, dim_names)


def names_in_current_mesh(*names):
""" Check if current mesh axes contain these names. """
mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
return set(names) <= set(mesh_axis_names)


def get_names_from_partition_spec(partition_specs):
""" Return axis names from partition specs. """
names = set()
if isinstance(partition_specs, dict):
partition_specs = partition_specs.values()
for item in partition_specs:
if item is None:
continue
elif isinstance(item, str):
names.add(item)
else:
names.update(get_names_from_partition_spec(item))

return list(names)


def with_sharding_constraint(x, partition_specs):
""" A smarter version of with_sharding_constraint that only applies the
constraint if the current mesh contains the axes in the partition specs.
"""
axis_names = get_names_from_partition_spec(partition_specs)
if names_in_current_mesh(*axis_names):
x = _with_sharding_constraint(x, partition_specs)
return x


def wrap_function_with_rng(rng):
""" To be used as decorator, automatically bookkeep a RNG for the wrapped function. """
def wrap_function(function):
def wrapped(*args, **kwargs):
nonlocal rng
rng, split_rng = jax.random.split(rng)
return function(split_rng, *args, **kwargs)
return wrapped
return wrap_function


def init_rng(seed):
global jax_utils_rng
jax_utils_rng = JaxRNG.from_seed(seed)


def next_rng(*args, **kwargs):
global jax_utils_rng
return jax_utils_rng(*args, **kwargs)


def get_metrics(metrics, unreplicate=False, stack=False):
if unreplicate:
metrics = flax.jax_utils.unreplicate(metrics)
metrics = jax.device_get(metrics)
if stack:
return jax.tree_map(lambda *args: np.stack(args), *metrics)
else:
return {key: float(val) for key, val in metrics.items()}


def mse_loss(val, target, valid=None):
if valid is None:
valid = jnp.ones((*target.shape[:2], 1))
valid = valid.astype(jnp.float32)
loss = jnp.mean(
jnp.where(
valid > 0.0,
jnp.square(val - target),
0.0
)
)
return loss


def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
if valid is None:
valid = jnp.ones(tokens.shape[:2])
valid = valid.astype(jnp.float32)
valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
logits = logits.astype(jnp.float32) # for numerical stability
token_log_prob = jnp.squeeze(
jnp.take_along_axis(
jax.nn.log_softmax(logits, axis=-1),
jnp.expand_dims(tokens, -1),
axis=-1,
),
-1,
)
token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
correct = jnp.where(
valid > 0.0,
jnp.argmax(logits, axis=-1) == tokens,
jnp.array(False)
)
accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
return loss, accuracy


def global_norm(tree):
""" Return the global L2 norm of a pytree. """
squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
flattened, _ = jax.flatten_util.ravel_pytree(squared)
return jnp.sqrt(jnp.sum(flattened))


def average_metrics(metrics):
return jax.tree_map(
lambda *args: jnp.mean(jnp.stack(args)),
*metrics
)


def get_float_dtype_by_name(dtype):
return {
'bf16': jnp.bfloat16,
'bfloat16': jnp.bfloat16,
'fp16': jnp.float16,
'float16': jnp.float16,
'fp32': jnp.float32,
'float32': jnp.float32,
'fp64': jnp.float64,
'float64': jnp.float64,
}[dtype]


def float_tensor_to_dtype(tensor, dtype):
if dtype is None or dtype == '':
return tensor
if isinstance(dtype, str):
dtype = get_float_dtype_by_name(dtype)
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
if getattr(tensor, 'dtype', None) in float_dtypes:
tensor = tensor.astype(dtype)
return tensor


def float_to_dtype(tree, dtype):
return jax.tree_util.tree_map(
partial(float_tensor_to_dtype, dtype=dtype), tree
)


def get_gradient_checkpoint_policy(name):
return {
'everything_saveable': jax.checkpoint_policies.everything_saveable,
'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
}[name]


def tree_path_to_string(path, sep=None):
keys = []
for key in path:
if isinstance(key, jax.tree_util.SequenceKey):
keys.append(str(key.idx))
elif isinstance(key, jax.tree_util.DictKey):
keys.append(str(key.key))
elif isinstance(key, jax.tree_util.GetAttrKey):
keys.append(str(key.name))
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
keys.append(str(key.key))
else:
keys.append(str(key))
if sep is None:
return tuple(keys)
return sep.join(keys)


def flatten_tree(xs, is_leaf=None, sep=None):
flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
output = {}
for key, val in flattened:
output[tree_path_to_string(key, sep=sep)] = val
return output


def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
""" An extended version of jax.tree_util.tree_map, where the mapped function
f takes both the name (path) and the tree leaf as input.
"""
return jax.tree_util.tree_map_with_path(
lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
tree, *rest,
is_leaf=is_leaf
)


def match_partition_rules(rules, params):
""" Returns a pytree of PartitionSpec according to rules. Supports handling
Flax TrainState and Optax optimizer state.
"""
def get_partition_spec(name, leaf):
if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
""" Don't partition scalar values. """
return PS()
for rule, ps in rules:
if re.search(rule, name) is not None:
return ps
raise ValueError(f'Partition rule not found for param: {name}')
return named_tree_map(get_partition_spec, params, sep='/')


def get_weight_decay_mask(exclusions):
""" Return a weight decay mask function that computes the pytree masks
according to the given exclusion rules.
"""
def decay(name, _):
for rule in exclusions:
if re.search(rule, name) is not None:
return False
return True

def weight_decay_mask(params):
return named_tree_map(decay, params, sep='/')

return weight_decay_mask


def tree_apply(fns, tree):
""" Apply a pytree of functions to the pytree. """
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)

0 comments on commit 404840e

Please sign in to comment.