@@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert'
122
122
mesh_axis_names[3]: 'fsdp'
123
123
mesh_axis_names[4]: 'seq'
124
124
mesh_axis_names[5]: 'model'
125
+ mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64'
126
+ mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
127
+ mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
128
+ mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1
129
+ mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
130
+ mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1
131
+ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
132
+ mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
133
+ mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
134
+ mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
135
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
136
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
137
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
138
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True
139
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
140
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None
141
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
142
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True
143
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
144
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
145
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None
146
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08
147
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
148
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
149
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0
150
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
151
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row'
152
+ mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm'
153
+ mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer'
154
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
155
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
156
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True
157
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
158
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
159
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
160
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
161
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
162
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
163
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True
164
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
165
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None
166
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
167
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
168
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
169
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
170
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
171
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08
172
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
173
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
174
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
175
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row'
176
+ mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm'
177
+ mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer'
178
+ mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
179
+ mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear'
180
+ mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True
181
+ mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
182
+ mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None
183
+ mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model'
184
+ mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None
185
+ mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'
186
+ mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
187
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
188
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
189
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
190
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
191
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
192
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
193
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
194
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
195
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
196
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
197
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
198
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
199
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
200
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
201
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
202
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
203
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
204
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
205
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
206
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
207
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
208
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
209
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
210
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
211
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
212
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
213
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
214
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
215
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
216
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
217
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
218
+ mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
219
+ mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
125
220
mesh_shape[0]: 1
126
221
mesh_shape[1]: -1
127
222
mesh_shape[2]: 1
0 commit comments