Skip to content

Commit

Permalink
[Bugfix] Fix deepseekv3 grouped topk error (vllm-project#13474)
Browse files Browse the repository at this point in the history
Signed-off-by: Chen-XiaoBing <[email protected]>
  • Loading branch information
Chen-XiaoBing authored and kerthcet committed Feb 21, 2025
1 parent d69082f commit ef8cc6f
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,23 +939,26 @@ def grouped_topk(hidden_states: torch.Tensor,
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")

num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)

num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]

if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
Expand Down

0 comments on commit ef8cc6f

Please sign in to comment.