Skip to content

feat: Enable EPLB to existing MoE models #5203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions tensorrt_llm/_torch/models/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,18 @@ def __init__(
class Llama4MoE(nn.Module):

def __init__(
self,
*,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
shared_expert_intermediate_size: int,
aux_stream: torch.cuda.Stream,
dtype: Optional[torch.dtype] = None,
tune_max_num_tokens: int = 8192,
model_config: ModelConfig = ModelConfig(),
self,
*,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
shared_expert_intermediate_size: int,
aux_stream: torch.cuda.Stream,
dtype: Optional[torch.dtype] = None,
tune_max_num_tokens: int = 8192,
model_config: ModelConfig = ModelConfig(),
layer_idx: Optional[int] = None,
):
from tensorrt_llm._torch.distributed import AllReduce

Expand Down Expand Up @@ -273,7 +274,8 @@ def __init__(
False, # In both low latency and max-throughput scenarios, FusedMoE needs not to do allreduce inside op.
weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
model_config=model_config,
apply_router_weight_on_input=True)
apply_router_weight_on_input=True,
layer_idx=layer_idx)

self.router = Linear(hidden_size,
num_experts,
Expand Down Expand Up @@ -403,7 +405,8 @@ def __init__(
shared_expert_intermediate_size=config.intermediate_size,
model_config=model_config,
aux_stream=aux_stream,
dtype=config.torch_dtype)
dtype=config.torch_dtype,
layer_idx=layer_idx)

# self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
# )
Expand Down
8 changes: 6 additions & 2 deletions tensorrt_llm/_torch/models/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
self,
model_config: ModelConfig[PretrainedConfig],
aux_stream: torch.cuda.Stream,
layer_idx: Optional[int] = None,
):
super().__init__()
config = model_config.pretrained_config
Expand All @@ -51,7 +52,8 @@ def __init__(
aux_stream=aux_stream,
dtype=config.torch_dtype,
reduce_results=reduce_results,
model_config=model_config)
model_config=model_config,
layer_idx=layer_idx)

def forward(
self,
Expand Down Expand Up @@ -108,7 +110,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],

self.self_attn = MixtralAttention(model_config, layer_idx=layer_idx)

self.block_sparse_moe = MixtralMoE(model_config, aux_stream)
self.block_sparse_moe = MixtralMoE(model_config,
aux_stream,
layer_idx=layer_idx)

self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/models/modeling_qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(
self,
model_config: ModelConfig[Qwen3MoeConfig],
aux_stream: torch.cuda.Stream,
layer_idx: int,
layer_idx: Optional[int] = None,
):
super().__init__()
config = model_config.pretrained_config
Expand Down Expand Up @@ -115,6 +115,7 @@ def __init__(
dtype=config.torch_dtype,
reduce_results=False,
model_config=model_config,
layer_idx=layer_idx,
)

@staticmethod
Expand Down
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
self,
model_config: ModelConfig[Qwen2MoeConfig],
aux_stream: torch.cuda.Stream,
layer_idx: Optional[int] = None,
):
super().__init__()
config = model_config.pretrained_config
Expand Down Expand Up @@ -57,7 +58,8 @@ def __init__(
aux_stream=aux_stream,
dtype=config.torch_dtype,
reduce_results=reduce_results,
model_config=model_config)
model_config=model_config,
layer_idx=layer_idx)

self.shared_expert = GatedMLP(
hidden_size=config.hidden_size,
Expand Down Expand Up @@ -143,7 +145,7 @@ def __init__(self, model_config: ModelConfig[Qwen2MoeConfig],
layer_idx=layer_idx,
)

self.mlp = QwenMoE(model_config, aux_stream)
self.mlp = QwenMoE(model_config, aux_stream, layer_idx=layer_idx)

self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):

moe_model_arch_list = [
'DeepseekV3ForCausalLM',
'MixtralForCausalLM',
'Llama4ForConditionalGeneration',
'Qwen2MoeForCausalLM',
'Qwen3MoeForCausalLM',
]


Expand Down