Skip to content

Commit 45c7df1

Browse files
apoorvtintin and gourab-aws authored and committed
TRN2 Meshes and Configurations
1 parent 3405a6e commit 45c7df1

29 files changed

+2013
-14
lines changed

axlearn/common/trainer_config_modifier.py

Lines changed: 160 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
"""Defines trainer config modifiers, which will be used in model definitions."""
44

5-
from typing import Dict, Sequence, Union
5+
from typing import Callable, Dict, NamedTuple, Sequence, Union
66

77
from axlearn.common import config
88
from axlearn.common.base_layer import RematSpec
99
from axlearn.common.config import (
1010
REQUIRED,
11+
ConfigBase,
1112
ConfigModifier,
1213
ConfigOr,
1314
Required,
@@ -17,7 +18,53 @@
1718
from axlearn.common.gradient_accumulation import with_minibatch_steps
1819
from axlearn.common.metrics import MetricAccumulator
1920
from axlearn.common.trainer import SpmdTrainer
20-
from axlearn.common.utils import HybridMeshShape, MeshShape
21+
from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
22+
23+
24+
class _FoundModule(NamedTuple):
    """Module config located by a recursive search of a trainer config.

    Produced by `_find_target_module`.
    """

    # The module config that was found.
    module: ConfigModifier.Config
    # The config that directly contains `module`.
    parent_module: ConfigModifier.Config
    # The attribute name under which `module` is stored in `parent_module`.
    # NOTE: this is always a string path component (from `module_name.split(".")`),
    # not a config, so it is annotated as `str`.
    key_in_parent: str
33+
34+
35+
def _find_target_module(module_name: str, cfg: SpmdTrainer.Config) -> _FoundModule:
    """Recursively searches for the module matching `module_name` in `cfg`.

    Args:
        module_name: Dot-separated path of the target module, assumed to be in
            x.y.z format, e.g. `model.decoder.transformer.layer`.
        cfg: The trainer config to search.

    Raises:
        ValueError: If any component of `module_name` is not found.

    Returns:
        A `_FoundModule` whose fields are:
            module: The module config that was found.
            parent_module: The config that directly contains `module`.
            key_in_parent: The attribute name of `module` within `parent_module`.
    """
    curr_module = cfg
    parent_module = None
    key_in_parent = None

    # Walk the dotted path one attribute at a time, tracking the parent and the
    # key so callers can replace the found module in place.
    for part in module_name.split("."):
        if not hasattr(curr_module, part):
            raise ValueError(f"{part} is not found in {curr_module}.")
        parent_module = curr_module
        key_in_parent = part
        curr_module = getattr(curr_module, part)
    return _FoundModule(
        module=curr_module, parent_module=parent_module, key_in_parent=key_in_parent
    )
2168

2269

2370
class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +147,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
100147
"""
101148

102149
for module_name, remat_spec in self._remat_policies.items():
103-
# Here we assume x.y.z format.
104-
# One example would be model.decoder.transformer.layer.
105-
target_modules = module_name.split(".")
106-
curr_module = cfg
107-
for target_module in target_modules:
108-
if not hasattr(curr_module, target_module):
109-
raise ValueError(f"{target_module} is not found in {curr_module}.")
110-
curr_module = getattr(curr_module, target_module)
150+
found_module = _find_target_module(module_name, cfg)
111151
# Here we assume all modules have remat_spec attribute.
112-
if not hasattr(curr_module, "remat_spec"):
113-
raise ValueError(f"{curr_module} does not have remat_spec attribute")
114-
curr_module.remat_spec = remat_spec
152+
if not hasattr(found_module.module, "remat_spec"):
153+
raise ValueError(f"{found_module.module} does not have remat_spec attribute")
154+
found_module.module.remat_spec = remat_spec
115155
return cfg
116156

117157

@@ -146,6 +186,113 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
146186
return cfg
147187

148188

189+
class ModelConfigModifier(ConfigModifier):
    """Update the model config for the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ModelConfigModifier.

        Attributes:
            model_cfg_modifications: A mapping from module path
                (e.g. `model.decoder.transformer.layer`) to a Config.
        """

        model_cfg_modifications: Dict[str, Callable[[ConfigBase], ConfigBase]] = {}

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._model_cfg_modifications = self.config.model_cfg_modifications

    def _merge_configs(self, target_cfg: ConfigBase, found_module: _FoundModule) -> ConfigBase:
        """Merge configurations from the config being replaced on a best effort basis.

        Fields present on both configs are carried over from the existing module;
        `klass` is kept from `target_cfg` since it determines the replacement class.

        Args:
            target_cfg: Configuration that will replace `found_module.module`.
            found_module: Existing configuration whose class will be replaced
                but whose configuration will be merged with `target_cfg`.

        Returns:
            The modified config.
        """
        for key in target_cfg.keys():
            if key == "klass":
                continue
            elif hasattr(found_module.module, key) and hasattr(target_cfg, key):
                setattr(target_cfg, key, getattr(found_module.module, key))
        return target_cfg

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the model config of the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.

        Returns:
            The modified trainer config.
        """

        # Iterate over modules in the mapping, sorted by module path length.
        # This ensures a parent is modified before its children, so child paths
        # resolve against the already-replaced parent and no modification is missed.
        for module_name, model_cfg in sorted(
            self._model_cfg_modifications.items(), key=lambda item: len(item[0])
        ):
            found_module = _find_target_module(module_name, cfg)

            model_cfg = self._merge_configs(model_cfg, found_module)
            # Replace in the parent config.
            setattr(found_module.parent_module, found_module.key_in_parent, model_cfg)
        return cfg
250+
251+
252+
class PartitionSpecModifier(ConfigModifier):
    """Update the partition spec attribute for the specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure PartitionSpecModifier.

        Attributes:
            partition_specs: A nested mapping from module path
                (e.g. `model.decoder.transformer.layer`) to another
                mapping of model attribute to PartitionSpec.
        """

        # NOTE: the value type is a nested mapping (attribute name -> PartitionSpec),
        # matching the docstring and how `__call__` iterates it.
        partition_specs: Required[Dict[str, Dict[str, PartitionSpec]]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._attribute_dicts = self.config.partition_specs

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the partition_spec attributes for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.
            ValueError: The partition_spec attribute is not found.

        Returns:
            The modified trainer config.
        """
        for module_name, partition_spec_dict in self._attribute_dicts.items():
            found_module = _find_target_module(module_name, cfg)
            for partition_spec_name, partition_spec in partition_spec_dict.items():
                if not hasattr(found_module.module, partition_spec_name):
                    raise ValueError(
                        f"{found_module.module} does not have {partition_spec_name} attribute"
                    )
                setattr(found_module.module, partition_spec_name, partition_spec)

        return cfg
294+
295+
149296
class ChainConfigModifier(ConfigModifier):
150297
"""Chain multiple config modifiers together."""
151298

axlearn/common/trainer_config_modifier_test.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import jax
66
from absl.testing import absltest
77

8-
from axlearn.common import test_utils
8+
from axlearn.common import causal_lm, test_utils
9+
from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
910
from axlearn.common.base_layer import RematSpec
1011
from axlearn.common.trainer import SpmdTrainer
1112
from axlearn.common.trainer_config_modifier import (
1213
ChainConfigModifier,
1314
GradientAccumulationModifier,
1415
MeshShapeModifier,
16+
ModelConfigModifier,
17+
PartitionSpecModifier,
1518
RematSpecModifier,
1619
)
1720
from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,88 @@ def test_remat_policy_override(self):
6568
_ = cfg_modifier(cfg)
6669

6770

71+
class ModelConfigModifierTest(test_utils.TestCase):
    def test_model_config_override(self):
        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
        # Use assertEqual (not assertTrue on a == comparison) so failures show both values.
        self.assertEqual(
            str(cfg.model.decoder.transformer), str(StackedTransformerLayer.default_config())
        )

        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                model_cfg_modifications={
                    "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
                }
            )
            .instantiate()
        )

        cfg = cfg_modifier(cfg)
        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
        self.assertEqual(
            str(cfg.model.decoder.transformer), str(RepeatedTransformerLayer.default_config())
        )
        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                model_cfg_modifications={
                    "model.decoder.unknown": RepeatedTransformerLayer.default_config(),
                }
            )
            .instantiate()
        )
        # Ensure that the exception is working.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)
105+
106+
107+
class PartitionSpecModifierTest(test_utils.TestCase):
    def test_partition_spec_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        # BUG FIX: the original used two-arg `assertTrue(str(...), "...")`, where the
        # second argument is only a failure message — the assertion was vacuous.
        # Assert the attribute was actually set to the configured partition spec.
        self.assertEqual(
            cfg.model.linear.param_partition_spec, ("model", ("expert", "fsdp", "seq"))
        )
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                    "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        # Ensure that the exception is working.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)

        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {
                        "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                        "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")),
                    },
                },
            )
            .instantiate()
        )
        with self.assertRaisesRegex(ValueError, ".*does not have unknown_partition_spec attribute"):
            _ = cfg_modifier(cfg)
151+
152+
68153
class MeshShapeModifierTest(test_utils.TestCase):
69154
def test_mesh_shape_update(self):
70155
cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())

0 commit comments

Comments
 (0)