File tree 6 files changed +17
-16
lines changed
6 files changed +17
-16
lines changed Original file line number Diff line number Diff line change 1
1
# Changes in this file should match with requiredReviewers in file .github/workflows/AddLabel.yml
2
- * @ gobbleturk @ khatwanimohit @ bvandermoon @ vipannalla @ RissyRan @ richjames0 @ rni418 @ gagika @ shralex @ yangyuwei @ SurbhiJainUSC @ hengtaoguo @ A9isha @ wang2yn84 @ wyzhang @ mitalisi @ gpolovets1 @ mailvijayasingh @ jrplatin @ patemotter @ lumosis
2
+ * @ gobbleturk @ khatwanimohit @ bvandermoon @ vipannalla @ RissyRan @ richjames0 @ rni418 @ gagika @ shralex @ yangyuwei @ SurbhiJainUSC @ hengtaoguo @ A9isha @ wang2yn84 @ wyzhang @ mitalisi @ gpolovets1 @ mailvijayasingh @ jrplatin @ patemotter @ lumosis @ aireenmei
Original file line number Diff line number Diff line change 74
74
jrplatin: "",
75
75
patemotter: "",
76
76
lumosis: "",
77
+ aireenmei: "",
77
78
}
78
79
const reviews = await github.rest.pulls.listReviews({
79
80
owner,
Original file line number Diff line number Diff line change @@ -428,6 +428,7 @@ def __call__(
428
428
else :
429
429
return layer_output
430
430
431
+
431
432
class Llama4ScannableBlock (nn .Module ):
432
433
'''
433
434
A repeatable block given nope_layer_interval and interleave_moe_layer_step
@@ -470,12 +471,12 @@ def __call__(
470
471
nope_layer = determine_is_nope_layer (layer_id , self .nope_layer_interval )
471
472
moe_layer = determine_is_moe_layer (layer_id , self .interleave_moe_layer_step )
472
473
layer = Llama4DecoderLayer (
473
- config = cfg ,
474
- mesh = mesh ,
475
- name = f"layers_{ layer_id } " ,
476
- quant = self .quant ,
477
- is_nope_layer = nope_layer ,
478
- is_moe_layer = moe_layer ,
474
+ config = cfg ,
475
+ mesh = mesh ,
476
+ name = f"layers_{ layer_id } " ,
477
+ quant = self .quant ,
478
+ is_nope_layer = nope_layer ,
479
+ is_moe_layer = moe_layer ,
479
480
)
480
481
y = layer (
481
482
y ,
@@ -488,9 +489,8 @@ def __call__(
488
489
slot = slot ,
489
490
)
490
491
if cfg .scan_layers :
491
- y = y [0 ]
492
+ y = y [0 ]
492
493
if cfg .scan_layers :
493
494
return y , None
494
495
else :
495
496
return y
496
-
Original file line number Diff line number Diff line change @@ -362,6 +362,7 @@ def get_decoder_layers(self):
362
362
return [simple_layer .SimpleMlpDecoderLayer ]
363
363
elif self .config .decoder_block == DecoderBlockType .LLAMA4 :
364
364
from MaxText .layers import llama4 # pylint: disable=import-outside-toplevel
365
+
365
366
if self .config .scan_layers :
366
367
return [llama4 .Llama4ScannableBlock ]
367
368
else :
@@ -544,8 +545,8 @@ def __call__(
544
545
layer_call_kwargs = {"bidirectional_mask" : bidirectional_mask }
545
546
elif cfg .decoder_block == DecoderBlockType .LLAMA4 :
546
547
layer_kwargs = {
547
- "nope_layer_interval" : self .config .nope_layer_interval ,
548
- "interleave_moe_layer_step" : self .config .interleave_moe_layer_step ,
548
+ "nope_layer_interval" : self .config .nope_layer_interval ,
549
+ "interleave_moe_layer_step" : self .config .interleave_moe_layer_step ,
549
550
}
550
551
RemattedBlockLayer = RemattedBlockLayers [0 ]
551
552
scan_length = int (cfg .num_decoder_layers / cfg .inhomogeneous_layer_cycle_interval )
Original file line number Diff line number Diff line change @@ -152,10 +152,9 @@ def maybe_initialize_jax_distributed_system(raw_keys):
152
152
max_logging .log ("Attempting to initialize the jax distributed system for CPU backend..." )
153
153
initialize_jax_for_cpu (raw_keys )
154
154
max_logging .log ("Jax distributed system initialized on CPUs!" )
155
- elif (
156
- raw_keys ["enable_checkpointing" ]
157
- and raw_keys ["compile_topology_num_slices" ] == - 1
158
- ) or raw_keys ["hardware" ] == "gpu_multiprocess" :
155
+ elif (raw_keys ["enable_checkpointing" ] and raw_keys ["compile_topology_num_slices" ] == - 1 ) or raw_keys [
156
+ "hardware"
157
+ ] == "gpu_multiprocess" :
159
158
max_logging .log ("Attempting to initialize the jax distributed system..." )
160
159
if not raw_keys ["enable_emergency_checkpoint" ]:
161
160
jax .distributed .initialize (initialization_timeout = raw_keys ["jax_distributed_initialization_timeout" ])
Original file line number Diff line number Diff line change @@ -135,7 +135,7 @@ def setUp(self):
135
135
)
136
136
self .tokenizer_model .add_special_tokens ({"pad_token" : "<pad>" })
137
137
138
- @pytest .mark .tpu_only # ATTENTION: Only run on TPU V4-8
138
+ @pytest .mark .tpu_only # ATTENTION: Only run on TPU V4-8
139
139
def test_grpo_trainer_correctness (self ):
140
140
# Get the expected (golden) data.
141
141
golden_data = get_golden_data (self .config )
You can’t perform that action at this time.
0 commit comments