@@ -22,10 +22,12 @@
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
+    GroupedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    StackedTransformerLayer,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -37,6 +39,7 @@
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
     RematSpecModifier,
 )
 from axlearn.common.utils import extended_checkpoint_policies
@@ -130,6 +133,16 @@ def get_trainer_kwargs(
 
     rope_theta = ROPE_THETA[version]
 
+    # TRN2-specific model config modifications.
+    trn2_model_modifications = {
+        # The Neuron compiler has a module that detects repeating blocks and reuses them
+        # during compilation, so compile time does not grow with the number of layers.
+        "model.decoder.transformer": StackedTransformerLayer.default_config(),
+        "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
+            None if version == Version.V1 else GroupedQKVLinear.default_config()
+        ),
+    }
+
     offload_dots_saveable_policy = config_for_function(
         extended_checkpoint_policies.offload_dots_saveable
     ).set(offload_src="device", offload_dst="pinned_host")
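For context on this hunk: `ModelConfigModifier` applies these dotted-path overrides to the trainer's model config, swapping the transformer stack to `StackedTransformerLayer` so the Neuron compiler can deduplicate the repeated blocks, and (for versions after V1) replacing the fused grouped QKV projection with the unfused `GroupedQKVLinear`; a `None` value leaves the V1 default in place. A minimal sketch of how dotted-path overrides like these can be applied to a nested config tree (hypothetical `apply_overrides` helper, not axlearn's actual implementation):

```python
# Minimal sketch, assuming attribute-style access into a nested config tree;
# this is not the ModelConfigModifier implementation.
def apply_overrides(cfg, overrides: dict):
    for path, value in overrides.items():
        *parents, leaf = path.split(".")
        target = cfg
        for name in parents:
            target = getattr(target, name)  # descend one level of the config tree
        # Replace the leaf sub-config, e.g. with StackedTransformerLayer.default_config().
        setattr(target, leaf, value)
    return cfg
```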
@@ -174,6 +187,23 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
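The comment in this rule encodes the device math: `mesh_shape_from_axes(fsdp=-1, model=4)` pins the `model` (tensor-parallel) axis to the 4 XLA cores within each TRN2 chip and lets `fsdp=-1` absorb all remaining devices. A rough sketch of how a `-1` axis resolves, using a hypothetical device count for illustration:

```python
import math

def resolve_mesh_shape(mesh_shape: tuple, num_devices: int) -> tuple:
    """Replaces a single -1 entry with whatever device count remains."""
    known = math.prod(d for d in mesh_shape if d != -1)
    return tuple(num_devices // known if d == -1 else d for d in mesh_shape)

# E.g. with 256 devices (assumed, not taken from the diff),
# (fsdp=-1, model=4) resolves to fsdp=64, model=4:
assert resolve_mesh_shape((-1, 4), 256) == (64, 4)
```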
@@ -192,6 +222,23 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
@@ -287,6 +334,21 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     elif model_size == "8B":
@@ -367,6 +429,21 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     elif model_size == "70B":
@@ -417,6 +494,21 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
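Each `mesh_rules` entry pairs a regex over the mesh selector string with either a mesh shape or a config-modifier chain; the first matching rule wins. A sketch of that selection logic, assuming full-match regex semantics (hypothetical `select_mesh_rule`, not the axlearn trainer's code):

```python
import re

def select_mesh_rule(mesh_selector: str, mesh_rules: tuple):
    """Returns the payload of the first rule whose regex matches the selector."""
    for pattern, cfg in mesh_rules:
        if re.fullmatch(pattern, mesh_selector):
            return cfg
    return None

# E.g. the TRN2 pattern from the diff matches a trn2.48xlarge selector:
rules = (("neuron-(trn2|trn2n).48xlarge-64", "trn2-chain"),)
assert select_mesh_rule("neuron-trn2.48xlarge-64", rules) == "trn2-chain"
```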