
Commit 6b404f6

TRN2 Meshes and Configurations
1 parent 2d1fb29 commit 6b404f6

File tree: 3 files changed, +178 -1 lines

axlearn/common/trainer_config_modifier.py

Lines changed: 62 additions & 0 deletions
@@ -8,6 +8,7 @@
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import (
     REQUIRED,
+    ConfigBase,
     ConfigModifier,
     ConfigOr,
     Required,
@@ -146,6 +147,67 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         return cfg


+class ModelConfigModifier(ConfigModifier):
+    """Update the model config for the trainer config."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure ModelConfigModifier.
+
+        Attributes:
+            model_cfg_modifications: A mapping from module path
+                (e.g. `model.decoder.transformer.layer`) to a Config.
+        """
+
+        model_cfg_modifications: Required[Dict[str, ConfigBase]] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        cfg = self.config
+        self._model_cfg_modifications = cfg.model_cfg_modifications
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Overwrite the model config of the specified modules.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+
+        for module_name, model_cfg in self._model_cfg_modifications.items():
+            # No modification if None.
+            if not model_cfg:
+                continue
+            # Here we assume an x.y.z path format,
+            # e.g. model.decoder.transformer.layer.
+            target_modules = module_name.split(".")
+            curr_module = cfg
+            target_module_in_parent = None
+            parent_module = None
+
+            for target_module in target_modules:
+                if not hasattr(curr_module, target_module):
+                    raise ValueError(f"{target_module} is not found in {curr_module}.")
+                parent_module = curr_module
+                target_module_in_parent = target_module
+                curr_module = getattr(curr_module, target_module)
+
+            # Copy configurations from the config being replaced on a best-effort basis.
+            for key in model_cfg.keys():
+                if key == "klass":
+                    continue
+                elif hasattr(curr_module, key):
+                    setattr(model_cfg, key, getattr(curr_module, key))
+            # Replace in the parent config.
+            setattr(parent_module, target_module_in_parent, model_cfg)
+        return cfg
+
+
 class ChainConfigModifier(ConfigModifier):
     """Chain multiple config modifiers together."""
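
In practice the new modifier is built from a mapping of dotted config paths to replacement configs, instantiated, and applied to a trainer config. A minimal usage sketch, mirroring the test added below (the path and the replacement config are illustrative choices, not the only supported values):

from axlearn.common import causal_lm
from axlearn.common.attention import RepeatedTransformerLayer
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import ModelConfigModifier

cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
cfg_modifier = (
    ModelConfigModifier.default_config()
    .set(
        model_cfg_modifications={
            # Dotted path into the trainer config -> replacement config.
            "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
        }
    )
    .instantiate()
)
cfg = cfg_modifier(cfg)  # cfg.model.decoder.transformer is now a RepeatedTransformerLayer config.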

axlearn/common/trainer_config_modifier_test.py

Lines changed: 24 additions & 1 deletion
@@ -5,13 +5,15 @@
 import jax
 from absl.testing import absltest

-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
+from axlearn.common.attention import RepeatedTransformerLayer
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.trainer import SpmdTrainer
 from axlearn.common.trainer_config_modifier import (
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
     RematSpecModifier,
 )
 from axlearn.common.trainer_test import DummyModel
@@ -65,6 +67,27 @@ def test_remat_policy_override(self):
         _ = cfg_modifier(cfg)


+class ModelConfigModifierTest(test_utils.TestCase):
+    def test_model_config_override(self):
+        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
+        self.assertRegex(str(cfg.model.decoder), ".*StackedTransformerLayer")
+
+        cfg_modifier = (
+            ModelConfigModifier.default_config()
+            .set(
+                model_cfg_modifications={
+                    "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
+                }
+            )
+            .instantiate()
+        )
+
+        cfg = cfg_modifier(cfg)
+        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
+        self.assertRegex(str(cfg.model.decoder), ".*RepeatedTransformerLayer")
+
+
 class MeshShapeModifierTest(test_utils.TestCase):
     def test_mesh_shape_update(self):
         cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())

axlearn/experiments/text/gpt/fuji.py

Lines changed: 92 additions & 0 deletions
@@ -22,10 +22,12 @@
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
+    GroupedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    StackedTransformerLayer,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -37,6 +39,7 @@
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
     RematSpecModifier,
 )
 from axlearn.common.utils import extended_checkpoint_policies
@@ -130,6 +133,16 @@ def get_trainer_kwargs(

     rope_theta = ROPE_THETA[version]

+    # TRN2-specific model config modifications.
+    trn2_model_modifications = {
+        # The Neuron compiler has a module that detects repeating blocks and reuses them during
+        # compilation, so compile time does not grow with the number of layers.
+        "model.decoder.transformer": StackedTransformerLayer.default_config(),
+        "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
+            None if version == Version.V1 else GroupedQKVLinear.default_config()
+        ),
+    }
+
     offload_dots_saveable_policy = config_for_function(
         extended_checkpoint_policies.offload_dots_saveable
     ).set(offload_src="device", offload_dst="pinned_host")
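
For Version.V1 the second entry resolves to None, and ModelConfigModifier skips falsy values (the `if not model_cfg: continue` branch above), so V1 keeps its existing QKV projection while V2/V3 swap the innermost input_linear to GroupedQKVLinear. A small sketch of that resolution; the helper and the stand-in strings are hypothetical, not part of the commit:

# Illustrative only: how the conditional entry above resolves per version.
# Plain strings stand in for the real layer configs.
def resolve_trn2_modifications(is_v1: bool) -> dict:
    mods = {
        "model.decoder.transformer": "StackedTransformerLayer.default_config()",
        "model.decoder.transformer.layer.self_attention.attention"
        ".input_linear.input_linear": None if is_v1 else "GroupedQKVLinear.default_config()",
    }
    # ModelConfigModifier ignores None entries, so only truthy values are applied.
    return {path: cfg for path, cfg in mods.items() if cfg}

assert len(resolve_trn2_modifications(is_v1=True)) == 1   # only the transformer stack swap
assert len(resolve_trn2_modifications(is_v1=False)) == 2  # also swaps in GroupedQKVLinear
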
@@ -174,6 +187,23 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
@@ -192,6 +222,23 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
trainer_kwargs = dict(
@@ -287,6 +334,21 @@ def get_trainer_kwargs(
287334
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
288335
mesh_shape_from_axes(data=-1, fsdp=8),
289336
),
337+
(
338+
"neuron-(trn2|trn2n).48xlarge-64",
339+
ChainConfigModifier.default_config().set(
340+
config_modifiers=[
341+
MeshShapeModifier.default_config().set(
342+
# TP within the chip, FSDP across chips.
343+
# Each TRN2 chip has 4 XLA cores.
344+
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
345+
),
346+
ModelConfigModifier.default_config().set(
347+
model_cfg_modifications=trn2_model_modifications
348+
),
349+
],
350+
),
351+
),
290352
),
291353
)
292354
elif model_size == "8B":
@@ -367,6 +429,21 @@ def get_trainer_kwargs(
                    "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                    mesh_shape_from_axes(data=-1, fsdp=8),
                ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     elif model_size == "70B":
@@ -417,6 +494,21 @@ def get_trainer_kwargs(
                    "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                    mesh_shape_from_axes(data=-1, fsdp=128),
                ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
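
Each of the TRN2 mesh rules above encodes the same intent: tensor parallelism over the 4 XLA cores within a TRN2 chip (model=4) and FSDP across chips (fsdp=-1, i.e. whatever devices remain). A worked sketch of how such a spec resolves, assuming the common convention that -1 absorbs the leftover device count; the 256-device total and the resolve_mesh_axes helper are illustrative assumptions, not axlearn APIs:

# Sketch only: resolve a -1 mesh axis by assigning it the leftover device count.
import math

def resolve_mesh_axes(total_devices: int, axes: dict) -> dict:
    fixed = math.prod(v for v in axes.values() if v != -1)
    assert total_devices % fixed == 0, "device count must divide evenly across fixed axes"
    return {name: (total_devices // fixed if size == -1 else size) for name, size in axes.items()}

# E.g. assuming 256 XLA cores in total (64 TRN2 chips x 4 cores per chip):
print(resolve_mesh_axes(256, {"fsdp": -1, "model": 4}))  # {'fsdp': 64, 'model': 4}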
