TRN2 Meshes and Configurations
apoorvtintin committed Jan 10, 2025
1 parent 2d1fb29 commit d481132
Showing 3 changed files with 294 additions and 14 deletions.
132 changes: 120 additions & 12 deletions axlearn/common/trainer_config_modifier.py
@@ -8,6 +8,7 @@
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import (
    REQUIRED,
    ConfigBase,
    ConfigModifier,
    ConfigOr,
    Required,
@@ -17,7 +18,27 @@
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec


def find_target_module(
    module_name: str, cfg: SpmdTrainer.Config
) -> tuple[ConfigModifier.Config, str, ConfigModifier.Config]:
    """Recursively searches for the target module matching `module_name` in the provided config.

    Returns a tuple of (target module config, the target's key in its parent, parent module config).
    """
    # Here we assume x.y.z format.
    # One example would be model.decoder.transformer.layer.
    target_modules = module_name.split(".")
    curr_module = cfg
    key_in_parent = None
    parent_module = None

    for target_module in target_modules:
        if not hasattr(curr_module, target_module):
            raise ValueError(f"{target_module} is not found in {curr_module}.")
        parent_module = curr_module
        key_in_parent = target_module
        curr_module = getattr(curr_module, target_module)
    return curr_module, key_in_parent, parent_module
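

# Hypothetical usage sketch: `find_target_module` walks the dotted path one attribute at
# a time and returns the resolved child config, its key in the parent, and the parent
# itself, so a caller can either mutate the child in place or swap it out on the parent.
# `trainer_cfg` and `replacement_cfg` are placeholder names; the path assumes a model
# config exposing a `decoder.transformer` chain (e.g. a causal_lm.Model config).
def _find_target_module_sketch(trainer_cfg: SpmdTrainer.Config, replacement_cfg: ConfigBase):
    found, key, parent = find_target_module("model.decoder.transformer", trainer_cfg)
    # `found` is trainer_cfg.model.decoder.transformer, `key` is "transformer", and
    # `parent` is trainer_cfg.model.decoder.
    setattr(parent, key, replacement_cfg)  # Replace the child config on its parent.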


class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +121,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""

for module_name, remat_spec in self._remat_policies.items():
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
curr_module = cfg
for target_module in target_modules:
if not hasattr(curr_module, target_module):
raise ValueError(f"{target_module} is not found in {curr_module}.")
curr_module = getattr(curr_module, target_module)
found_module, _, _ = find_target_module(module_name, cfg)
# Here we assume all modules have remat_spec attribute.
if not hasattr(curr_module, "remat_spec"):
raise ValueError(f"{curr_module} does not have remat_spec attribute")
curr_module.remat_spec = remat_spec
if not hasattr(found_module, "remat_spec"):
raise ValueError(f"{found_module} does not have remat_spec attribute")
found_module.remat_spec = remat_spec
return cfg


@@ -146,6 +160,100 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        return cfg


class ModelConfigModifier(ConfigModifier):
    """Update the model config for the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ModelConfigModifier.

        Attributes:
            model_cfg_modifications: A mapping from module path
                (e.g. `model.decoder.transformer.layer`) to a Config.
        """

        model_cfg_modifications: Required[Dict[str, ConfigBase]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._model_cfg_modifications = cfg.model_cfg_modifications

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the model config of the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.

        Returns:
            The modified trainer config.
        """

        for module_name, model_cfg in self._model_cfg_modifications.items():
            # No modification if None.
            if not model_cfg:
                continue

            found_module, key_in_parent, parent_module = find_target_module(module_name, cfg)

            # Copy configurations from the config being replaced on a best-effort basis.
            for key in model_cfg.keys():
                if key == "klass":
                    continue
                elif hasattr(found_module, key) and hasattr(model_cfg, key):
                    setattr(model_cfg, key, getattr(found_module, key))
            # Replace in the parent config.
            setattr(parent_module, key_in_parent, model_cfg)
        return cfg
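

# Hypothetical usage sketch, mirroring the test in trainer_config_modifier_test.py below:
# replace the decoder's transformer stack with RepeatedTransformerLayer. Attribute values
# shared between the old and new configs (other than `klass`) are carried over by the
# best-effort copy in `__call__` above. `trainer_cfg` is a placeholder name.
def _model_config_modifier_sketch(trainer_cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
    from axlearn.common.attention import RepeatedTransformerLayer

    modifier = ModelConfigModifier.default_config().set(
        model_cfg_modifications={
            "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
        }
    ).instantiate()
    return modifier(trainer_cfg)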


class ParameterPartitionSpecModifier(ConfigModifier):
    """Update the parameter partition spec for specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ParameterPartitionSpecModifier.

        Attributes:
            partition_specs: A mapping from module path
                (e.g. `model.decoder.transformer.layer`) to a PartitionSpec.
        """

        partition_specs: Required[Dict[str, PartitionSpec]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._partition_specs = cfg.partition_specs

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the param_partition_spec for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.
            ValueError: The param_partition_spec attribute is not found.

        Returns:
            The modified trainer config.
        """

        for module_name, param_partition_spec in self._partition_specs.items():
            found_module, _, _ = find_target_module(module_name, cfg)

            # Here we assume all modules have a param_partition_spec attribute.
            if not hasattr(found_module, "param_partition_spec"):
                raise ValueError(f"{found_module} does not have param_partition_spec attribute")

            found_module.param_partition_spec = param_partition_spec
        return cfg
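

# Hypothetical usage sketch, mirroring the test below: shard the parameters of the module
# at path `model.linear` over the "model" axis on dim 0 and the ("expert", "fsdp", "seq")
# axes on dim 1. `trainer_cfg` is a placeholder for an SpmdTrainer.Config whose
# `model.linear` exposes a `param_partition_spec` attribute.
def _parameter_partition_spec_sketch(trainer_cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
    modifier = ParameterPartitionSpecModifier.default_config().set(
        partition_specs={
            "model.linear": ("model", ("expert", "fsdp", "seq")),
        },
    ).instantiate()
    return modifier(trainer_cfg)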


class ChainConfigModifier(ConfigModifier):
"""Chain multiple config modifiers together."""

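# Hypothetical sketch of composing the modifiers above (assumption: ChainConfigModifier
# exposes a `config_modifiers` sequence and applies the modifiers in order; that field is
# not shown in this diff). The dotted paths come from the separate tests below, and both
# must exist in the `trainer_cfg` that is passed in.
def _chained_modifier_sketch(trainer_cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
    from axlearn.common.attention import RepeatedTransformerLayer

    chained = ChainConfigModifier.default_config().set(
        config_modifiers=[
            ModelConfigModifier.default_config().set(
                model_cfg_modifications={
                    "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
                }
            ),
            ParameterPartitionSpecModifier.default_config().set(
                partition_specs={"model.linear": ("model", ("expert", "fsdp", "seq"))},
            ),
        ]
    ).instantiate()
    return chained(trainer_cfg)
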
72 changes: 71 additions & 1 deletion axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
import jax
from absl.testing import absltest

-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulationModifier,
    MeshShapeModifier,
    ModelConfigModifier,
    ParameterPartitionSpecModifier,
    RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,73 @@ def test_remat_policy_override(self):
            _ = cfg_modifier(cfg)


class ModelConfigModifierTest(test_utils.TestCase):
    def test_model_config_override(self):
        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
        self.assertTrue(
            str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config())
        )

        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                model_cfg_modifications={
                    "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
                }
            )
            .instantiate()
        )

        cfg = cfg_modifier(cfg)
        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
        self.assertTrue(
            str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config())
        )
        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                model_cfg_modifications={
                    "model.decoder.unknown": RepeatedTransformerLayer.default_config(),
                }
            )
            .instantiate()
        )
        # Ensure that the exception is working.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)


class ParameterPartitionSpecModifierTest(test_utils.TestCase):
    def test_parameter_partition_spec_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            ParameterPartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": ("model", ("expert", "fsdp", "seq")),
                },
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertEqual(
            cfg.model.linear.param_partition_spec, ("model", ("expert", "fsdp", "seq"))
        )
        cfg_modifier = (
            ParameterPartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": ("model", ("expert", "fsdp", "seq")),
                    "model.unknown": ("model", ("expert", "fsdp", "seq")),
                },
            )
            .instantiate()
        )
        # Ensure that the exception is working.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)


class MeshShapeModifierTest(test_utils.TestCase):
    def test_mesh_shape_update(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())