Skip to content

Commit dd57b8b

Browse files
committed
Add new code owner for multimodal workgroup
1 parent b6be338 commit dd57b8b

File tree

6 files changed

+17
-16
lines changed

6 files changed

+17
-16
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Changes in this file should match with requiredReviewers in file .github/workflows/AddLabel.yml
2-
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @rni418 @gagika @shralex @yangyuwei @SurbhiJainUSC @hengtaoguo @A9isha @wang2yn84 @wyzhang @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis
2+
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @rni418 @gagika @shralex @yangyuwei @SurbhiJainUSC @hengtaoguo @A9isha @wang2yn84 @wyzhang @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @aireenmei

.github/workflows/AddLabel.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ jobs:
7474
jrplatin: "",
7575
patemotter: "",
7676
lumosis: "",
77+
aireenmei: "",
7778
}
7879
const reviews = await github.rest.pulls.listReviews({
7980
owner,

MaxText/layers/llama4.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def __call__(
428428
else:
429429
return layer_output
430430

431+
431432
class Llama4ScannableBlock(nn.Module):
432433
'''
433434
A repeatable block given nope_layer_interval and interleave_moe_layer_step
@@ -470,12 +471,12 @@ def __call__(
470471
nope_layer = determine_is_nope_layer(layer_id, self.nope_layer_interval)
471472
moe_layer = determine_is_moe_layer(layer_id, self.interleave_moe_layer_step)
472473
layer = Llama4DecoderLayer(
473-
config=cfg,
474-
mesh=mesh,
475-
name=f"layers_{layer_id}",
476-
quant=self.quant,
477-
is_nope_layer=nope_layer,
478-
is_moe_layer=moe_layer,
474+
config=cfg,
475+
mesh=mesh,
476+
name=f"layers_{layer_id}",
477+
quant=self.quant,
478+
is_nope_layer=nope_layer,
479+
is_moe_layer=moe_layer,
479480
)
480481
y = layer(
481482
y,
@@ -488,9 +489,8 @@ def __call__(
488489
slot=slot,
489490
)
490491
if cfg.scan_layers:
491-
y=y[0]
492+
y = y[0]
492493
if cfg.scan_layers:
493494
return y, None
494495
else:
495496
return y
496-

MaxText/layers/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def get_decoder_layers(self):
362362
return [simple_layer.SimpleMlpDecoderLayer]
363363
elif self.config.decoder_block == DecoderBlockType.LLAMA4:
364364
from MaxText.layers import llama4 # pylint: disable=import-outside-toplevel
365+
365366
if self.config.scan_layers:
366367
return [llama4.Llama4ScannableBlock]
367368
else:
@@ -544,8 +545,8 @@ def __call__(
544545
layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}
545546
elif cfg.decoder_block == DecoderBlockType.LLAMA4:
546547
layer_kwargs = {
547-
"nope_layer_interval": self.config.nope_layer_interval,
548-
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
548+
"nope_layer_interval": self.config.nope_layer_interval,
549+
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
549550
}
550551
RemattedBlockLayer = RemattedBlockLayers[0]
551552
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)

MaxText/max_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,9 @@ def maybe_initialize_jax_distributed_system(raw_keys):
152152
max_logging.log("Attempting to initialize the jax distributed system for CPU backend...")
153153
initialize_jax_for_cpu(raw_keys)
154154
max_logging.log("Jax distributed system initialized on CPUs!")
155-
elif (
156-
raw_keys["enable_checkpointing"]
157-
and raw_keys["compile_topology_num_slices"] == -1
158-
) or raw_keys["hardware"] == "gpu_multiprocess":
155+
elif (raw_keys["enable_checkpointing"] and raw_keys["compile_topology_num_slices"] == -1) or raw_keys[
156+
"hardware"
157+
] == "gpu_multiprocess":
159158
max_logging.log("Attempting to initialize the jax distributed system...")
160159
if not raw_keys["enable_emergency_checkpoint"]:
161160
jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"])

MaxText/tests/grpo_trainer_correctness_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def setUp(self):
135135
)
136136
self.tokenizer_model.add_special_tokens({"pad_token": "<pad>"})
137137

138-
@pytest.mark.tpu_only # ATTENTION: Only run on TPU V4-8
138+
@pytest.mark.tpu_only # ATTENTION: Only run on TPU V4-8
139139
def test_grpo_trainer_correctness(self):
140140
# Get the expected (golden) data.
141141
golden_data = get_golden_data(self.config)

0 commit comments

Comments
 (0)