TRN2 Meshes and Configurations #916

Open: wants to merge 9 commits into base main.
38 changes: 37 additions & 1 deletion axlearn/common/config.py
@@ -70,7 +70,7 @@ class Config(ConfigBase):
from collections import defaultdict
from collections.abc import Collection, Iterable
from functools import cache
from typing import Any, Callable, Generic, Optional, TypeVar, Union
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union

# attr provides similar features as Python dataclass. Unlike
# dataclass, however, it provides a richer set of features to regulate
@@ -394,6 +394,42 @@ def set(self, **kwargs):
setattr(self, k, v)
return self

def get_recursively(self, path: Sequence[str]) -> Any:
"""Recursively find the target key in the config and return its value.

Args:
path: A sequence of keys for indexing to get the target value.

Raises:
AttributeError: If a key in the path is not found.

Returns:
The value at the path, or self if the path is empty.
"""
current = self

for key in path:
# TODO(markblee): Maybe use cfg.visit instead of getattr.
current = getattr(current, key)

return current

def set_recursively(self, path: Sequence[str], *, value: Any):
"""Recursively find the target key in the config and set its value.

Args:
path: A sequence of keys for indexing to set the target value.
value: New value to replace the target value.

Raises:
ValueError: If path is empty.
AttributeError: If key in path is not found.
"""
if not path:
raise ValueError("Path is empty.")
parent = self.get_recursively(path[:-1])
setattr(parent, path[-1], value)

def clone(self, **kwargs):
"""Returns a clone of the original config with the optional keyword overrides.

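A minimal usage sketch of the two new helpers. The Nested and Test classes are the ones defined in config_test.py below; this snippet is illustrative and not part of the diff:

cfg = Test.default_config()

# get_recursively walks one attribute per path element via getattr.
assert cfg.get_recursively(["nested", "value"]) == 0
# An empty path returns the config itself.
assert cfg.get_recursively([]) is cfg

# set_recursively resolves the parent config, then sets the final key with setattr.
cfg.set_recursively(["nested", "value"], value=10)
assert cfg.nested.value == 10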
65 changes: 65 additions & 0 deletions axlearn/common/config_test.py
@@ -934,6 +934,71 @@ def set(self, **kwargs):
self.assertEqual(123, cfg_clone.a)
self.assertEqual("default", cfg_clone.b)

def test_get_recursively(self):
class Nested(Configurable):
@config_class
class Config(Configurable.Config):
"""A dummy config."""

value: int = 0

class Test(Configurable):
@config_class
class Config(Configurable.Config):
"""Another dummy config that has a nested config."""

nested: Nested.Config = Nested.default_config()
value: int = 1

cfg = Test.default_config()

# Test getting nested value.
self.assertEqual(cfg.get_recursively(["nested", "value"]), 0)

# Test getting top-level value.
self.assertEqual(cfg.get_recursively(["value"]), 1)

# Test getting non-existent value.
with self.assertRaises(AttributeError):
cfg.get_recursively(["non_existent"])

# Test getting empty path, should return self.
self.assertEqual(cfg.get_recursively([]), cfg)

def test_set_recursively(self):
class Nested(Configurable):
@config_class
class Config(Configurable.Config):
"""A dummy config."""

value: int = 0

class Test(Configurable):
@config_class
class Config(Configurable.Config):
"""Another dummy config that has a nested config."""

nested: Nested.Config = Nested.default_config()
value: int = 1

cfg = Test.default_config()

# Test setting nested value.
cfg.set_recursively(["nested", "value"], value=10)
self.assertEqual(cfg.nested.value, 10)

# Test setting top-level value.
cfg.set_recursively(["value"], value=5)
self.assertEqual(cfg.value, 5)

# Test setting non-existent value.
with self.assertRaises(AttributeError):
cfg.set_recursively(["non_existent"], value=20)

# Test setting empty path.
with self.assertRaises(ValueError):
cfg.set_recursively([], value=20)


if __name__ == "__main__":
absltest.main()
124 changes: 111 additions & 13 deletions axlearn/common/trainer_config_modifier.py
@@ -10,14 +10,15 @@
REQUIRED,
ConfigModifier,
ConfigOr,
Configurable,
Required,
config_class,
maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import HybridMeshShape, MeshShape
from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec


class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +100,8 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""

for module_name, remat_spec in self._remat_policies.items():
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
curr_module = cfg
for target_module in target_modules:
if not hasattr(curr_module, target_module):
raise ValueError(f"{target_module} is not found in {curr_module}.")
curr_module = getattr(curr_module, target_module)
# Here we assume all modules have remat_spec attribute.
if not hasattr(curr_module, "remat_spec"):
raise ValueError(f"{curr_module} does not have remat_spec attribute")
curr_module.remat_spec = remat_spec
cfg.set_recursively(module_name.split(".") + ["remat_spec"], value=remat_spec)

return cfg
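The manual attribute walk previously used here collapses into a single set_recursively call. Conceptually, for one policy entry (the dotted path below is hypothetical):

module_name = "model.decoder.transformer.layer"  # hypothetical dotted path
# Equivalent to: cfg.model.decoder.transformer.layer.remat_spec = remat_spec,
# except a missing path element now surfaces as AttributeError (from getattr)
# instead of the ValueError raised by the old manual loop.
cfg.set_recursively(module_name.split(".") + ["remat_spec"], value=remat_spec)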


Expand Down Expand Up @@ -146,6 +137,113 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
return cfg


class ModuleConfigModifier(ConfigModifier):
"""Update the model config for the trainer config."""

@config_class
class Config(ConfigModifier.Config):
"""Configure ModuleConfigModifier.

Attributes:
target_config: Target module path
(e.g. `model.decoder.transformer.layer`) to be modified.
modification: The new config to replace the target module's config.
"""

target_config: Required[str] = REQUIRED
modification: Required[Configurable.Config] = REQUIRED

def __init__(self, cfg: Config):
super().__init__(cfg)
self._target_config = self.config.target_config
self._modification = self.config.modification

def _merge_configs(
self, target_cfg: Configurable.Config, found_module: Configurable.Config
) -> Configurable.Config:
"""Merge configurations from the config being replaced on a best effort basis.
apoorvtintin marked this conversation as resolved.
Show resolved Hide resolved

Merge Rules:
- klass is never merged; the replacement config's class is kept.
- If a field exists in both configs, its value is taken from the config being replaced.
- Otherwise, the value from target_cfg is kept.

Args:
target_cfg: Configuration that will replace found_module.
found_module: Existing configuration whose class will be replaced,
but whose field values will be merged into target_cfg.

Returns:
The modified config.
"""
for key in target_cfg.keys():
if key == "klass":
continue
elif hasattr(found_module, key) and hasattr(target_cfg, key):
setattr(target_cfg, key, getattr(found_module, key))
return target_cfg

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Overwrite the model config of the specified modules.

Args:
cfg: The trainer config to be modified.

Raises:
AttributeError: The target module is not found.

Returns:
The modified trainer config.
"""

found_module = cfg.get_recursively(self._target_config.split("."))
self._modification = self._merge_configs(self._modification, found_module)
cfg.set_recursively(self._target_config.split("."), value=self._modification)
return cfg
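For example, mirroring the test below, a sketch that swaps the decoder's transformer stack; _merge_configs keeps the replacement class but carries over any fields the two configs share:

modifier = (
    ModuleConfigModifier.default_config()
    .set(
        target_config="model.decoder.transformer",
        modification=RepeatedTransformerLayer.default_config(),
    )
    .instantiate()
)
# cfg.model.decoder.transformer now holds a RepeatedTransformerLayer.Config.
cfg = modifier(cfg)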


class PartitionSpecModifier(ConfigModifier):
"""Update the partition spec attribute for the specified modules."""

@config_class
class Config(ConfigModifier.Config):
"""Configure PartitionSpecModifier.

Attributes:
partition_specs: A nested mapping from module path
(e.g. `model.decoder.transformer.layer`) to another
mapping of model attribute to PartitionSpec.
"""

partition_specs: Required[Dict[str, PartitionSpec]] = REQUIRED

def __init__(self, cfg: Config):
super().__init__(cfg)
self._attribute_dicts = self.config.partition_specs

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Update the partition_spec attributes for the specified modules.

Args:
cfg: The trainer config to be modified.

Raises:
AttributeError: The target module is not found.
AttributeError: The partition_spec attribute is not found.

Returns:
The modified trainer config.
"""
for module_name, partition_spec_dict in self._attribute_dicts.items():
for partition_spec_name, partition_spec in partition_spec_dict.items():
cfg.set_recursively(
module_name.split(".") + [partition_spec_name], value=partition_spec
)

return cfg
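For example, mirroring the test below (the axis names are illustrative):

modifier = (
    PartitionSpecModifier.default_config()
    .set(
        partition_specs={
            # module path -> {attribute name: partition spec}
            "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
        },
    )
    .instantiate()
)
cfg = modifier(cfg)  # sets cfg.model.linear.param_partition_spec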


class ChainConfigModifier(ConfigModifier):
"""Chain multiple config modifiers together."""

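The ChainConfigModifier body is truncated in this diff. As a hedged sketch of typical composition, assuming it exposes a config_modifiers sequence (a field name not shown in this excerpt) and applies each modifier in order, together with the mesh_shape and grad_acc_steps fields of the modifiers defined earlier in this file:

chained = (
    ChainConfigModifier.default_config()
    .set(
        config_modifiers=[
            MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 8, 1)),
            GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
        ],
    )
    .instantiate()
)
cfg = chained(cfg)  # each modifier transforms the trainer config in sequence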
87 changes: 85 additions & 2 deletions axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
import jax
from absl.testing import absltest

from axlearn.common import test_utils
from axlearn.common import causal_lm, test_utils
from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
ChainConfigModifier,
GradientAccumulationModifier,
MeshShapeModifier,
ModuleConfigModifier,
PartitionSpecModifier,
RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel
@@ -61,7 +64,87 @@ def test_remat_policy_override(self):
.instantiate()
)
# Ensure that the exception is working.
with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"):
_ = cfg_modifier(cfg)


class ModuleConfigModifierTest(test_utils.TestCase):
def test_model_config_override(self):
cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
self.assertEqual(
str(cfg.model.decoder.transformer), str(StackedTransformerLayer.default_config())
)

cfg_modifier = (
ModuleConfigModifier.default_config()
.set(
target_config="model.decoder.transformer",
modification=RepeatedTransformerLayer.default_config(),
)
.instantiate()
)

cfg = cfg_modifier(cfg)
# The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
self.assertEqual(
str(cfg.model.decoder.transformer), str(RepeatedTransformerLayer.default_config())
)
cfg_modifier = (
ModuleConfigModifier.default_config()
.set(
target_config="model.decoder.unknown",
modification=RepeatedTransformerLayer.default_config(),
)
.instantiate()
)
# Ensure that the exception is working.
with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"):
_ = cfg_modifier(cfg)


class PartitionSpecModifierTest(test_utils.TestCase):
def test_partition_spec_override(self):
cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
cfg_modifier = (
PartitionSpecModifier.default_config()
.set(
partition_specs={
"model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
},
)
.instantiate()
)
cfg = cfg_modifier(cfg)
self.assertEqual(
str(cfg.model.linear.param_partition_spec), "('model', ('expert', 'fsdp', 'seq'))"
)
cfg_modifier = (
PartitionSpecModifier.default_config()
.set(
partition_specs={
"model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
"model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
},
)
.instantiate()
)
# Ensure that the exception is working.
with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"):
_ = cfg_modifier(cfg)

cfg_modifier = (
PartitionSpecModifier.default_config()
.set(
partition_specs={
"model.linear": {
"param_partition_spec": ("model", ("expert", "fsdp", "seq")),
"unknown_partition_spec": ("model", ("expert", "fsdp", "seq")),
},
},
)
.instantiate()
)
with self.assertRaisesRegex(AttributeError, "unknown_partition_spec *"):
_ = cfg_modifier(cfg)

