Commit

[Bugfix] Fix deepseekv3 grouped topk error
Chen-XiaoBing committed Feb 18, 2025
1 parent e2603fe commit 76896ef
Showing 1 changed file with 9 additions and 2 deletions: vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -946,8 +946,15 @@ def grouped_topk(hidden_states: torch.Tensor,
         scores = scores + e_score_correction_bias.unsqueeze(0)
 
     num_token = scores.shape[0]
-    group_scores = scores.view(num_token, num_expert_group,
-                               -1).max(dim=-1).values  # [n, n_group]
+    if e_score_correction_bias is not None:
+        group_scores = (
+            scores.view(num_token, num_expert_group, -1)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
+        )
+    else:
+        group_scores = scores.view(num_token, num_expert_group,
+                                   -1).max(dim=-1).values  # [n, n_group]
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                            sorted=False)[1]  # [n, top_k_group]
     group_mask = torch.zeros_like(group_scores)  # [n, n_group]
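The fix changes how each expert group is scored when an expert-score correction bias is in use: instead of taking the single best expert score per group, it sums the top-2 expert scores in each group, matching DeepSeek-V3's reference routing. The sketch below illustrates the two paths with NumPy standing in for PyTorch; the helper name `compute_group_scores` and the toy scores are illustrative, not from vLLM.

```python
import numpy as np

def compute_group_scores(scores, num_expert_group, has_correction_bias):
    """Score each expert group so the router can keep only the top groups.

    scores: [num_tokens, num_experts] gating scores (bias already applied).
    Returns: [num_tokens, num_expert_group] group scores.
    """
    num_token = scores.shape[0]
    grouped = scores.reshape(num_token, num_expert_group, -1)
    if has_correction_bias:
        # Fixed path: sum of the top-2 expert scores in each group.
        top2 = np.sort(grouped, axis=-1)[..., -2:]
        return top2.sum(axis=-1)
    # Original path: the maximum expert score in each group.
    return grouped.max(axis=-1)

# Two groups of two experts each; note the two paths can rank groups
# differently, which is why the max-based path was a bug here.
scores = np.array([[0.1, 0.9, 0.8, 0.7]])
print(compute_group_scores(scores, 2, True))   # top-2 sums: group 1 wins
print(compute_group_scores(scores, 2, False))  # per-group max: group 0 wins
```

With `[0.1, 0.9]` and `[0.8, 0.7]` as the two groups, the max path scores them `0.9` vs `0.8` and would keep group 0, while the top-2-sum path scores them `1.0` vs `1.5` and keeps group 1, so the choice of reduction changes which experts are routed to.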
