From 5be50d76336181d629e7fe16e33b2a18d5d14c25 Mon Sep 17 00:00:00 2001
From: Apoorv Gupta
Date: Thu, 9 Jan 2025 16:40:49 -0800
Subject: [PATCH] TRN2 Meshes and Configurations

---
 axlearn/common/trainer_config_modifier.py | 132 ++++++++++++++++--
 .../common/trainer_config_modifier_test.py | 72 +++++++++-
 .../fuji-1B-v3-flash-single-host.txt | 59 ++++++++
 .../fuji-1B-v3-flash.txt | 59 ++++++++
 .../fuji-1B-v3-single-host.txt | 59 ++++++++
 .../fuji-1B-v3.txt | 59 ++++++++
 .../fuji-3B-v3-flash-single-host.txt | 59 ++++++++
 .../fuji-3B-v3-flash.txt | 59 ++++++++
 .../fuji-3B-v3-single-host.txt | 59 ++++++++
 .../fuji-3B-v3.txt | 59 ++++++++
 .../fuji-70B-v1-flash.txt | 65 ++++++++-
 .../fuji-70B-v1.txt | 65 ++++++++-
 .../fuji-70B-v2-flash.txt | 70 +++++++++-
 .../fuji-70B-v2.txt | 70 +++++++++-
 .../fuji-70B-v3-flash.txt | 70 +++++++++-
 .../fuji-70B-v3.txt | 70 +++++++++-
 .../fuji-7B-v1-flash-single-host.txt | 54 +++++++
 .../fuji-7B-v1-flash.txt | 54 +++++++
 .../fuji-7B-v1-single-host.txt | 54 +++++++
 .../fuji-7B-v1.txt | 54 +++++++
 .../fuji-7B-v2-flash-single-host.txt | 59 ++++++++
 .../fuji-7B-v2-flash.txt | 59 ++++++++
 .../fuji-7B-v2-single-host.txt | 59 ++++++++
 .../fuji-7B-v2.txt | 59 ++++++++
 .../fuji-8B-v3-flash-single-host.txt | 59 ++++++++
 .../fuji-8B-v3-flash.txt | 59 ++++++++
 .../fuji-8B-v3-single-host.txt | 59 ++++++++
 .../fuji-8B-v3.txt | 59 ++++++++
 axlearn/experiments/text/gpt/fuji.py | 105 +++++++++++++-
 29 files changed, 1859 insertions(+), 20 deletions(-)

diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py
index d647e1a06..f5137f022 100644
--- a/axlearn/common/trainer_config_modifier.py
+++ b/axlearn/common/trainer_config_modifier.py
@@ -8,6 +8,7 @@
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import (
     REQUIRED,
+    ConfigBase,
     ConfigModifier,
     ConfigOr,
     Required,
@@ -17,7 +18,27 @@
 from axlearn.common.gradient_accumulation import with_minibatch_steps
 from axlearn.common.metrics import MetricAccumulator
 from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
+
+
+def find_target_module(
+    module_name: str, cfg: SpmdTrainer.Config
+) -> tuple[ConfigBase, str, ConfigBase]:
+    """Recursively search for the target module matching `module_name` in the provided config."""
+    # Here we assume x.y.z format.
+    # One example would be model.decoder.transformer.layer.
+    target_modules = module_name.split(".")
+    curr_module = cfg
+    key_in_parent = None
+    parent_module = None
+
+    for target_module in target_modules:
+        if not hasattr(curr_module, target_module):
+            raise ValueError(f"{target_module} is not found in {curr_module}.")
+        parent_module = curr_module
+        key_in_parent = target_module
+        curr_module = getattr(curr_module, target_module)
+    return curr_module, key_in_parent, parent_module
 
 
 class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +121,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         """
 
         for module_name, remat_spec in self._remat_policies.items():
-            # Here we assume x.y.z format.
-            # One example would be model.decoder.transformer.layer.
-            target_modules = module_name.split(".")
-            curr_module = cfg
-            for target_module in target_modules:
-                if not hasattr(curr_module, target_module):
-                    raise ValueError(f"{target_module} is not found in {curr_module}.")
-                curr_module = getattr(curr_module, target_module)
+            found_module, _, _ = find_target_module(module_name, cfg)
             # Here we assume all modules have remat_spec attribute.
-            if not hasattr(curr_module, "remat_spec"):
-                raise ValueError(f"{curr_module} does not have remat_spec attribute")
-            curr_module.remat_spec = remat_spec
+            if not hasattr(found_module, "remat_spec"):
+                raise ValueError(f"{found_module} does not have remat_spec attribute")
+            found_module.remat_spec = remat_spec
         return cfg
 
 
@@ -146,6 +160,100 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         return cfg
 
 
+class ModelConfigModifier(ConfigModifier):
+    """Update the model config within the trainer config."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure ModelConfigModifier.
+
+        Attributes:
+            model_cfg_modifications: A mapping from module path
+                (e.g. `model.decoder.transformer.layer`) to a Config.
+        """
+
+        model_cfg_modifications: Required[Dict[str, ConfigBase]] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        cfg = self.config
+        self._model_cfg_modifications = cfg.model_cfg_modifications
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Overwrite the model config of the specified modules.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+
+        for module_name, model_cfg in self._model_cfg_modifications.items():
+            # No modification if the new config is None.
+            if not model_cfg:
+                continue
+
+            found_module, key_in_parent, parent_module = find_target_module(module_name, cfg)
+
+            # Copy configurations from the config being replaced on a best-effort basis.
+            for key in model_cfg.keys():
+                if key == "klass":
+                    continue
+                elif hasattr(found_module, key) and hasattr(model_cfg, key):
+                    setattr(model_cfg, key, getattr(found_module, key))
+            # Replace the config in the parent module.
+            setattr(parent_module, key_in_parent, model_cfg)
+        return cfg
+
+
+class ParameterPartitionSpecModifier(ConfigModifier):
+    """Update the parameter partition spec for specified modules."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure ParameterPartitionSpecModifier.
+
+        Attributes:
+            partition_specs: A mapping from module path
+                (e.g. `model.decoder.transformer.layer`) to a PartitionSpec.
+        """
+
+        partition_specs: Required[Dict[str, PartitionSpec]] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        cfg = self.config
+        self._partition_specs = cfg.partition_specs
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Update the `param_partition_spec` for the specified modules.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+            ValueError: The `param_partition_spec` attribute is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+
+        for module_name, param_partition_spec in self._partition_specs.items():
+            found_module, _, _ = find_target_module(module_name, cfg)
+
+            # Here we assume all modules have a param_partition_spec attribute.
+            if not hasattr(found_module, "param_partition_spec"):
+                raise ValueError(f"{found_module} does not have param_partition_spec attribute")
+
+            found_module.param_partition_spec = param_partition_spec
+        return cfg
+
+
 class ChainConfigModifier(ConfigModifier):
     """Chain multiple config modifiers together."""
 
diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py
index ccfe00823..717d60693 100644
--- a/axlearn/common/trainer_config_modifier_test.py
+++ b/axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
 import jax
 from absl.testing import absltest
 
-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
+from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.trainer import SpmdTrainer
 from axlearn.common.trainer_config_modifier import (
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
+    ParameterPartitionSpecModifier,
     RematSpecModifier,
 )
 from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,73 @@ def test_remat_policy_override(self):
             _ = cfg_modifier(cfg)
 
 
+class ModelConfigModifierTest(test_utils.TestCase):
+    def test_model_config_override(self):
+        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
+        self.assertEqual(
+            str(cfg.model.decoder.transformer), str(StackedTransformerLayer.default_config())
+        )
+
+        cfg_modifier = (
+            ModelConfigModifier.default_config()
+            .set(
+                model_cfg_modifications={
+                    "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
+                }
+            )
+            .instantiate()
+        )
+
+        cfg = cfg_modifier(cfg)
+        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
+        self.assertEqual(
+            str(cfg.model.decoder.transformer), str(RepeatedTransformerLayer.default_config())
+        )
+        cfg_modifier = (
+            ModelConfigModifier.default_config()
+            .set(
+                model_cfg_modifications={
+                    "model.decoder.unknown": RepeatedTransformerLayer.default_config(),
+                }
+            )
+            .instantiate()
+        )
+        # Ensure that the exception is raised for an unknown module path.
+        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
+            _ = cfg_modifier(cfg)
+
+
+class ParameterPartitionSpecModifierTest(test_utils.TestCase):
+    def test_parameter_partition_spec_override(self):
+        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
+        cfg_modifier = (
+            ParameterPartitionSpecModifier.default_config()
+            .set(
+                partition_specs={
+                    "model.linear": ("model", ("expert", "fsdp", "seq")),
+                },
+            )
+            .instantiate()
+        )
+        cfg = cfg_modifier(cfg)
+        self.assertEqual(
+            str(cfg.model.linear.param_partition_spec), """('model', ('expert', 'fsdp', 'seq'))"""
+        )
+        cfg_modifier = (
+            ParameterPartitionSpecModifier.default_config()
+            .set(
+                partition_specs={
+                    "model.linear": ("model", ("expert", "fsdp", "seq")),
+                    "model.unknown": ("model", ("expert", "fsdp", "seq")),
+                },
+            )
+            .instantiate()
+        )
+        # Ensure that the exception is raised for an unknown module path.
+ with self.assertRaisesRegex(ValueError, "unknown is not found in.*"): + _ = cfg_modifier(cfg) + + class MeshShapeModifierTest(test_utils.TestCase): def test_mesh_shape_update(self): cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt index f8e50909a..5fc7b8fac 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt index 2a831b28c..c5ab8dfea 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt index 0a470ccea..40da2c204 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: 
True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt index c4c6eed38..83a02b693 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt index d06cfb3c7..933129583 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt
index 8d5dc4e92..2e28dde83 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt
@@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert'
 mesh_axis_names[3]: 'fsdp'
 mesh_axis_names[4]: 'seq'
 mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
index 8d7e8f710..cf0e75159 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
@@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert'
 mesh_axis_names[3]: 'fsdp'
 mesh_axis_names[4]: 'seq'
 mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
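(The -1 in the mesh shape above is a wildcard that MeshShapeModifier leaves to be inferred from the device count. A quick sanity check of the arithmetic, assuming the 64 cores implied by the 'neuron-(trn2|trn2n).48xlarge-64' rule; resolve_mesh_shape below is an illustrative helper, not the axlearn one:

    import math

    def resolve_mesh_shape(mesh_shape, num_devices):
        # Replace the single -1 entry with whatever remains after the
        # fixed axes are accounted for.
        fixed = math.prod(d for d in mesh_shape if d != -1)
        assert num_devices % fixed == 0, "device count must divide evenly"
        return tuple(num_devices // fixed if d == -1 else d for d in mesh_shape)

    # (1, 1, 1, -1, 1, 4) on 64 cores -> the fsdp axis gets 64 / 4 = 16.
    assert resolve_mesh_shape((1, 1, 1, -1, 1, 4), 64) == (1, 1, 1, 16, 1, 4)
)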
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
index 53ef5d052..f9734ac6c 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
@@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert'
 mesh_axis_names[3]: 'fsdp'
 mesh_axis_names[4]: 'seq'
 mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
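(For orientation, the rule serialized in these dumps roughly corresponds to a construction like the following in fuji.py. This is a sketch only, assuming the usual default_config().set(...) pattern for the modifiers named in the dump; the actual wiring may differ:

    from axlearn.common.attention import StackedTransformerLayer
    from axlearn.common.trainer_config_modifier import (
        ChainConfigModifier,
        MeshShapeModifier,
        ModelConfigModifier,
    )

    trn2_rule = (
        "neuron-(trn2|trn2n).48xlarge-64",
        ChainConfigModifier.default_config().set(
            config_modifiers=[
                # Six mesh axes; the -1 (fsdp) axis is inferred from the device count.
                MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 1, -1, 1, 4)),
                ModelConfigModifier.default_config().set(
                    model_cfg_modifications={
                        # Swap the repeated transformer stack for a plain stacked one.
                        "model.decoder.transformer": StackedTransformerLayer.default_config(),
                    }
                ),
            ],
        ),
    )
)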
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
index ade5f1af2..3a2dcd41c 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
@@ -135,6 +135,69 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear']: None
+mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ParameterPartitionSpecModifier'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][0]: 'model'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][0]: 'expert'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][1]: 'fsdp'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][2]: 'seq'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][0]: 'model'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][0]: 'expert'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][1]: 'fsdp'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][2]: 'seq'
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -285,7 +348,7 @@ model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layer
 model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
 model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.self_attention.structure: 'prenorm'
-model.decoder.transformer.num_layers: 80
+model.decoder.transformer.num_layers: 1
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
index a986f1d08..0712dfbab 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
@@ -135,6 +135,69 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
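(The third modifier in the chain above overrides param_partition_spec on the embedding and LM head: first weight dimension on 'model', second sharded over ('expert', 'fsdp', 'seq'). A minimal sketch of the equivalent construction, again assuming the default_config().set(...) pattern:

    from axlearn.common.trainer_config_modifier import ParameterPartitionSpecModifier

    emb_and_head_specs = ParameterPartitionSpecModifier.default_config().set(
        partition_specs={
            "model.decoder.emb.token_emb": ("model", ("expert", "fsdp", "seq")),
            "model.decoder.lm_head": ("model", ("expert", "fsdp", "seq")),
        },
    )
)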
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear']: None
+mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ParameterPartitionSpecModifier'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][0]: 'model'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][0]: 'expert'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][1]: 'fsdp'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][2]: 'seq'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][0]: 'model'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][0]: 'expert'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][1]: 'fsdp'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][2]: 'seq'
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -252,7 +315,7 @@ model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layer
 model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
 model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.self_attention.structure: 'prenorm'
-model.decoder.transformer.num_layers: 80
+model.decoder.transformer.num_layers: 1
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
index 03fc3428a..d04e473f4 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
@@ -135,6 +135,74 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ParameterPartitionSpecModifier'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][0]: 'model'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][0]: 'expert'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][1]: 'fsdp'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][2]: 'seq'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][0]: 'model'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][0]: 'expert'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][1]: 'fsdp'
+mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][2]: 'seq'
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -286,7 +354,7 @@ model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layer
 model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
 model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.self_attention.structure: 'prenorm'
-model.decoder.transformer.num_layers: 80
+model.decoder.transformer.num_layers: 1
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
index 1ecf7529f..629759d23 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
@@ -135,6 +135,74 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ParameterPartitionSpecModifier' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][0]: 'model' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][0]: 'expert' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][1]: 'fsdp' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][2]: 'seq' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][0]: 'model' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][0]: 'expert' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][1]: 'fsdp' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][2]: 'seq' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 @@ -286,7 +354,7 @@ model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layer model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 +model.decoder.transformer.num_layers: 1 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt index 1ecf7529f..629759d23 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt @@ -135,6 +135,74 @@ mesh_rules[1][1][2]: 1 mesh_rules[1][1][3]: 128 mesh_rules[1][1][4]: 1 mesh_rules[1][1][5]: 1 +mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' 
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True 
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ParameterPartitionSpecModifier' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][0]: 'model' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][0]: 'expert' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][1]: 'fsdp' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][2]: 'seq' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][0]: 'model' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][0]: 'expert' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][1]: 'fsdp' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][2]: 'seq' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 @@ -253,7 +321,7 @@ model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layer model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 +model.decoder.transformer.num_layers: 1 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt index 76193c0db..316e557b3 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt @@ -135,6 +135,74 @@ mesh_rules[1][1][2]: 1 mesh_rules[1][1][3]: 128 mesh_rules[1][1][4]: 1 mesh_rules[1][1][5]: 1 +mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[2][1].config_modifiers[1].klass: 
'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True 
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' 
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ParameterPartitionSpecModifier' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][0]: 'model' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][0]: 'expert' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][1]: 'fsdp' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][2]: 'seq' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][0]: 'model' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][0]: 'expert' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][1]: 'fsdp' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][2]: 'seq' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 @@ -286,7 +354,7 @@ model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layer model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 +model.decoder.transformer.num_layers: 1 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt index 45bdb8e66..43fef87cf 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt @@ -135,6 +135,74 @@ mesh_rules[1][1][2]: 1 mesh_rules[1][1][3]: 128 mesh_rules[1][1][4]: 1 mesh_rules[1][1][5]: 1 +mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1 
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' 
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' 
+mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[2][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ParameterPartitionSpecModifier' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][0]: 'model' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][0]: 'expert' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][1]: 'fsdp' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb'][1][2]: 'seq' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][0]: 'model' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][0]: 'expert' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][1]: 'fsdp' +mesh_rules[2][1].config_modifiers[2].partition_specs['model.decoder.lm_head'][1][2]: 'seq' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 @@ -253,7 +321,7 @@ model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layer model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 +model.decoder.transformer.num_layers: 1 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt index 98cd9261c..f73a6e23f 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt @@ -178,6 +178,60 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 
'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear']: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt index 0a62cc2b1..39134ba97 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt @@ -178,6 +178,60 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: 
None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear']: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt index a9e1f38ed..0258dc98f 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt @@ -178,6 +178,60 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear']: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt index 87736a6f5..adeaa9ace 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt @@ -178,6 +178,60 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' 
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear']: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
index e01051cac..3108bb7d8 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
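The rule taking shape in these dumps pairs an instance-type regex with a ChainConfigModifier that first rewrites the mesh and then swaps model sub-configs. Below is a minimal sketch of how such a rule could be assembled; the bare tuple form, the axis ordering (pipeline, data, expert, fsdp, seq, model), and the use of default_config() are assumptions for illustration, not the verbatim fuji.py change:

    from axlearn.common.attention import StackedTransformerLayer
    from axlearn.common.trainer_config_modifier import (
        ChainConfigModifier,
        MeshShapeModifier,
        ModelConfigModifier,
    )

    trn2_rule = (
        # Regex matched against the instance type (see mesh_rules[6][0] above).
        "neuron-(trn2|trn2n).48xlarge-64",
        ChainConfigModifier.default_config().set(
            config_modifiers=[
                # (pipeline, data, expert, fsdp, seq, model) = (1, 1, 1, -1, 1, 4):
                # fsdp=-1 absorbs all remaining devices; model=4 gives 4-way
                # tensor parallelism.
                MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 1, -1, 1, 4)),
                # The dumps above show StackedTransformerLayer as the target
                # klass for the transformer stack on TRN2.
                ModelConfigModifier.default_config().set(
                    model_cfg_modifications={
                        "model.decoder.transformer": StackedTransformerLayer.default_config(),
                    }
                ),
            ]
        ),
    )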
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
index 964f23e23..97daeee4a 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
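The v2 dumps above also retarget the nested QKV projection: the path 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' is set to a GroupedQKVLinear config, while the earlier fuji-7B-v1 dump carries None at the same path (no modification). A hedged sketch of that per-version switch; the version test and the model_version flag are assumptions for illustration:

    from axlearn.common.attention import GroupedQKVLinear
    from axlearn.common.trainer_config_modifier import ModelConfigModifier

    model_version = "v2"  # Hypothetical flag, for illustration only.
    uses_gqa = model_version in ("v2", "v3")  # Assumption: v1 keeps plain MHA.

    input_linear_path = (
        "model.decoder.transformer.layer.self_attention.attention"
        ".input_linear.input_linear"
    )
    modifier = ModelConfigModifier.default_config().set(
        model_cfg_modifications={
            # None leaves the module untouched (cf. the fuji-7B-v1 dump);
            # a config replaces it, here swapping in an unfused GroupedQKVLinear.
            input_linear_path: GroupedQKVLinear.default_config() if uses_gqa else None,
        }
    )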
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
index 17f97ab30..372ddfc5a 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
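The param_partition_spec entries in these dumps are ordinary JAX PartitionSpecs over the mesh axes, and the (None, 'model') / ('model', None) pairing on linear1/linear2 is the standard Megatron-style column/row split. A short sketch of what those tuples denote in plain JAX terms:

    from jax.sharding import PartitionSpec

    # linear1 weight (hidden, ffn_dim): the output features are split over the
    # 'model' axis (column parallel), so each core holds a slice of the FFN
    # expansion.
    linear1_spec = PartitionSpec(None, "model")

    # linear2 weight (ffn_dim, hidden): the input features are split over
    # 'model' (row parallel); pairing the two keeps communication down to one
    # all-reduce per feed-forward block.
    linear2_spec = PartitionSpec("model", None)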
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
index 438da62a1..730838ca2 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
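Throughout these rules the mesh shape is (1, 1, 1, -1, 1, 4), with the -1 entry left for the trainer to infer from the device count. A small self-contained sketch of that inference; the 1024-device example (16 nodes of 64 Neuron cores, per the '48xlarge-64' suffix) and the helper name are assumptions for illustration:

    import math

    def resolve_mesh_shape(mesh_shape, device_count):
        """Replaces a single -1 entry with the residual device count."""
        known = math.prod(d for d in mesh_shape if d != -1)
        assert device_count % known == 0, "devices must divide the fixed axes"
        return tuple(device_count // known if d == -1 else d for d in mesh_shape)

    # 16 nodes x 64 Neuron cores = 1024 devices -> fsdp axis of size 256.
    print(resolve_mesh_shape((1, 1, 1, -1, 1, 4), device_count=1024))
    # (1, 1, 1, 256, 1, 4)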
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt
index a5b50a240..6d4254e41 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
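mesh_rules[6][0] is a regular expression, so a single rule covers both the trn2 and trn2n flavors of the 48xlarge instance. A sketch of the first-match lookup this implies; it mirrors, but is not, axlearn's actual rule-selection code:

    import re

    def select_mesh_rule(mesh_selector, mesh_rules):
        """Returns the modifier of the first rule whose regex matches."""
        for pattern, modifier in mesh_rules:
            if re.fullmatch(pattern, mesh_selector):
                return modifier
        return None  # No match: fall back to the default mesh_shape.

    rules = [("neuron-(trn2|trn2n).48xlarge-64", "trn2-chain")]
    assert select_mesh_rule("neuron-trn2.48xlarge-64", rules) == "trn2-chain"
    assert select_mesh_rule("neuron-trn2n.48xlarge-64", rules) == "trn2-chain"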
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
index da5826693..e85af7de3 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
index 811b565e5..5ee582a67 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
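Each golden file records the same final key: the matched rule is a ChainConfigModifier. The plumbing is presumably that the matched rule's config is instantiated and applied to the trainer config before these dumps are regenerated; a usage sketch under that assumption (the helper below is illustrative, not axlearn's exact entry point):

    from axlearn.common.trainer import SpmdTrainer
    from axlearn.common.trainer_config_modifier import ChainConfigModifier

    def apply_mesh_rule(
        trainer_cfg: SpmdTrainer.Config, rule_cfg: ChainConfigModifier.Config
    ) -> SpmdTrainer.Config:
        """Applies a matched mesh rule; each chained modifier rewrites the config in turn."""
        return rule_cfg.instantiate()(trainer_cfg)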
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt index b71e46c9d..d36011068 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt @@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 
'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff 
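The two golden files above record the fully resolved trainer config after the TRN2 rule fires, which is why every default of StackedTransformerLayer, QKVLinear, and the GroupedQKVLinear override is spelled out key by key. For readers tracing how a rule such as 'neuron-(trn2|trn2n).48xlarge-64' takes effect, the sketch below shows the select-and-apply flow; apply_mesh_rules and the instance_type argument are illustrative assumptions (axlearn's own rule selection lives elsewhere), while ChainConfigModifier and the regex come from this patch.

# Minimal sketch, not part of the patch: apply the first matching mesh rule.
import re

from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import ChainConfigModifier


def apply_mesh_rules(cfg: SpmdTrainer.Config, mesh_rules, instance_type: str):
    for regex, rule in mesh_rules:
        if re.fullmatch(regex, instance_type):
            if isinstance(rule, ChainConfigModifier.Config):
                # A chained rule (the TRN2 case): instantiate the chain and let
                # each modifier rewrite the trainer config in order, here
                # MeshShapeModifier followed by ModelConfigModifier.
                return rule.instantiate()(cfg)
            # A bare tuple rule is just a mesh shape.
            cfg.mesh_shape = rule
            return cfg
    return cfg


# e.g. apply_mesh_rules(trainer_cfg, trainer_cfg.mesh_rules, "neuron-trn2.48xlarge-64")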
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b1102..6963dfb86 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -22,10 +22,12 @@
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
+    GroupedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    StackedTransformerLayer,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -37,6 +39,8 @@
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
+    ParameterPartitionSpecModifier,
     RematSpecModifier,
 )
 from axlearn.common.utils import extended_checkpoint_policies
@@ -130,6 +134,16 @@ def get_trainer_kwargs(
 
     rope_theta = ROPE_THETA[version]
 
+    # TRN2-specific model config modifications.
+    trn2_model_modifications = {
+        # The Neuron compiler detects repeating blocks and reuses them during compilation,
+        # so compile time does not grow with the number of layers.
+        "model.decoder.transformer": StackedTransformerLayer.default_config(),
+        "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
+            None if version == Version.V1 else GroupedQKVLinear.default_config()
+        ),
+    }
+
     offload_dots_saveable_policy = config_for_function(
         extended_checkpoint_policies.offload_dots_saveable
     ).set(offload_src="device", offload_dst="pinned_host")
@@ -174,6 +188,23 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
@@ -192,6 +223,23 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
@@ -287,6 +335,21 @@ def get_trainer_kwargs(
                 "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                 mesh_shape_from_axes(data=-1, fsdp=8),
             ),
+            (
+                "neuron-(trn2|trn2n).48xlarge-64",
+                ChainConfigModifier.default_config().set(
+                    config_modifiers=[
+                        MeshShapeModifier.default_config().set(
+                            # TP within the chip, FSDP across chips.
+                            # Each TRN2 chip has 4 XLA cores.
+                            mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                        ),
+                        ModelConfigModifier.default_config().set(
+                            model_cfg_modifications=trn2_model_modifications
+                        ),
+                    ],
+                ),
+            ),
         ),
     )
 elif model_size == "8B":
@@ -367,12 +430,27 @@ def get_trainer_kwargs(
                 "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                 mesh_shape_from_axes(data=-1, fsdp=8),
             ),
+            (
+                "neuron-(trn2|trn2n).48xlarge-64",
+                ChainConfigModifier.default_config().set(
+                    config_modifiers=[
+                        MeshShapeModifier.default_config().set(
+                            # TP within the chip, FSDP across chips.
+                            # Each TRN2 chip has 4 XLA cores.
+                            mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                        ),
+                        ModelConfigModifier.default_config().set(
+                            model_cfg_modifications=trn2_model_modifications
+                        ),
+                    ],
+                ),
+            ),
         ),
     )
 elif model_size == "70B":
     trainer_kwargs = dict(
         model_kwargs=dict(
             num_layers=80,
             hidden_dim=128 * 64,
             num_heads=64,
             # No GQA support in V1 models, so num_kv_heads is the same as num_heads.
@@ -417,6 +495,31 @@ def get_trainer_kwargs(
                 "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                 mesh_shape_from_axes(data=-1, fsdp=128),
             ),
+            (
+                "neuron-(trn2|trn2n).48xlarge-64",
+                ChainConfigModifier.default_config().set(
+                    config_modifiers=[
+                        MeshShapeModifier.default_config().set(
+                            # TP within the chip, FSDP across chips.
+                            # Each TRN2 chip has 4 XLA cores.
+                            mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                        ),
+                        ModelConfigModifier.default_config().set(
+                            model_cfg_modifications=trn2_model_modifications
+                        ),
+                        ParameterPartitionSpecModifier.default_config().set(
+                            partition_specs={
+                                # Vocab-parallel embedding sharding from Megatron-LM.
+                                "model.decoder.emb.token_emb": (
+                                    "model",
+                                    ("expert", "fsdp", "seq"),
+                                ),
+                                "model.decoder.lm_head": ("model", ("expert", "fsdp", "seq")),
+                            },
+                        ),
+                    ],
+                ),
+            ),
         ),
     )
 else:
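A note on the 70B-only ParameterPartitionSpecModifier entries: the spec ('model', ('expert', 'fsdp', 'seq')) shards dimension 0 (the vocab axis) of the token embedding and LM head across the tensor-parallel 'model' axis, and dimension 1 (the hidden axis) jointly across the remaining axes, mirroring Megatron-LM's vocab-parallel embedding so the large embedding table is never replicated within a chip. The following self-contained JAX sketch shows what that spec means on the TRN2 mesh; the 16-device count, the toy array sizes, and the axis ordering (pipeline, data, expert, fsdp, seq, model), inferred from the six-element mesh shapes in the golden files above, are assumptions for illustration.

# Standalone illustration, not part of the patch. On CPU, run with
# XLA_FLAGS=--xla_force_host_platform_device_count=16 to get 16 fake devices.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# mesh_shape_from_axes(fsdp=-1, model=4) resolves to (1, 1, 1, -1, 1, 4);
# with 16 devices, the -1 (fsdp) axis becomes 4.
devices = np.asarray(jax.devices()[:16]).reshape(1, 1, 1, 4, 1, 4)
mesh = Mesh(devices, ("pipeline", "data", "expert", "fsdp", "seq", "model"))

# Vocab-parallel embedding: dim 0 (vocab) is split over "model" (TP inside a
# TRN2 chip), dim 1 (hidden) jointly over ("expert", "fsdp", "seq").
spec = PartitionSpec("model", ("expert", "fsdp", "seq"))
emb = jax.device_put(
    np.zeros((128, 32), np.float32),  # toy stand-in for the real weight
    NamedSharding(mesh, spec),
)
print(emb.sharding.shard_shape(emb.shape))  # -> (32, 8): 128/4 by 32/(1*4*1)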