[Bugfix] Fix deepseekv3 grouped topk error #13474

Merged
merged 3 commits into from
Feb 20, 2025

Conversation

@Chen-XiaoBing (Contributor) commented Feb 18, 2025

Fix the grouped top-k computation logic in fused MoE.

There is a slight difference between vLLM's grouped_topk and the official implementation. When the newly introduced bias term (e_score_correction_bias in vLLM) is not None, we should first take the top-2 scores within each group and use their sum to select the top-k groups. You can also check the compute logic in DeepSeek V3's official inference code.

In addition, the masked scores are currently set to 0. If the scores in the unmasked groups are negative, experts from masked groups can then be selected, which leads to incorrect or suboptimal routing in some scenarios. DeepSeek V3's official code resets masked scores to -inf instead.
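
For reference, here is a minimal sketch of the selection logic described above, assuming sigmoid gating scores and a per-expert e_score_correction_bias as in vLLM. The function name, argument names, and shapes are illustrative only, not the actual vLLM or DeepSeek implementation.

```python
import torch

def grouped_topk_sketch(
    scores: torch.Tensor,                   # [num_tokens, num_experts] gating scores
    e_score_correction_bias: torch.Tensor,  # [num_experts]
    num_expert_group: int,
    topk_group: int,
    topk: int,
):
    num_tokens, num_experts = scores.shape

    # The bias-corrected scores are used only for *selection*;
    # the original scores are still used as routing weights.
    scores_for_choice = scores + e_score_correction_bias

    # Rank groups by the sum of each group's top-2 corrected scores,
    # rather than by the single best score per group.
    group_scores = (
        scores_for_choice.view(num_tokens, num_expert_group, -1)
        .topk(2, dim=-1).values
        .sum(dim=-1)
    )  # [num_tokens, num_expert_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices

    # Build an expert-level mask for the selected groups and fill the rest with
    # -inf, so experts in masked groups can never win the final top-k even when
    # all unmasked scores are negative.
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, -1)
        .bool()
    )
    masked_scores = scores_for_choice.masked_fill(~score_mask, float("-inf"))

    topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False).indices
    topk_weights = scores.gather(1, topk_ids)
    return topk_weights, topk_ids
```

With a DeepSeek-V3-style routing configuration (e.g. 256 routed experts, 8 groups, 4 groups selected, 8 experts per token), the sketch would be called with num_expert_group=8, topk_group=4, topk=8; the exact values should be taken from the model config.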


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they run only fastcheck CI, which covers a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mgoin (Member) commented Feb 18, 2025

Thank you for your contribution!

I ran a quick eval on GSM8K with DeepSeek R1 and saw that this PR seems to regress performance slightly, though the difference is within stderr.

(vllm) ➜  vllm git:(main) lm_eval --model vllm --model_args pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
vllm (pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.956|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.956|±  |0.0056|

(vllm) ➜  vllm git:(fix-dsv3-grouped-topk) lm_eval --model vllm --model_args pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
vllm (pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9507|±  | 0.006|
|     |       |strict-match    |     5|exact_match|↑  |0.9507|±  | 0.006|

I will try to run harder evals to determine improvement. Would you have an example of bad performance from DeepSeek on main due to this issue?

@simon-mo (Collaborator)

@mgoin, we might need to test the V3 model, as it is trained more for GSM and MMLU.

@mgoin (Member) commented Feb 18, 2025

Here is MMLU for R1, with similar small drops in each category. I'll try to get V3 going on another machine.

(vllm) ➜  vllm git:(main) lm_eval --model vllm --model_args pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.8 --trust_remote_code --tasks mmlu --batch_size 16
vllm (pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 16

|      Groups      |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------------------|------:|------|------|------|---|-----:|---|-----:|
|mmlu              |      2|none  |      |acc   |↑  |0.8514|±  |0.0029|
| - humanities     |      2|none  |      |acc   |↑  |0.7845|±  |0.0057|
| - other          |      2|none  |      |acc   |↑  |0.8822|±  |0.0055|
| - social sciences|      2|none  |      |acc   |↑  |0.9227|±  |0.0047|
| - stem           |      2|none  |      |acc   |↑  |0.8513|±  |0.0062|

(vllm) ➜  vllm git:(fix-dsv3-grouped-topk) lm_eval --model vllm --model_args pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.8 --trust_remote_code --tasks mmlu --batch_size 16
vllm (pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 16
|      Groups      |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------------------|------:|------|------|------|---|-----:|---|-----:|
|mmlu              |      2|none  |      |acc   |↑  |0.8492|±  |0.0029|
| - humanities     |      2|none  |      |acc   |↑  |0.7824|±  |0.0057|
| - other          |      2|none  |      |acc   |↑  |0.8806|±  |0.0055|
| - social sciences|      2|none  |      |acc   |↑  |0.9194|±  |0.0048|
| - stem           |      2|none  |      |acc   |↑  |0.8493|±  |0.0062|

@simon-mo simon-mo added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 20, 2025
@simon-mo simon-mo enabled auto-merge (squash) February 20, 2025 06:46
@simon-mo (Collaborator)

I'll release v0.7.3 after this PR is merged.

auto-merge was automatically disabled February 20, 2025 06:54

Head branch was pushed to by a user without write access

@Chen-XiaoBing Chen-XiaoBing force-pushed the fix-dsv3-grouped-topk branch 2 times, most recently from 1dbfe7b to cb53786 Compare February 20, 2025 07:08
@simon-mo simon-mo enabled auto-merge (squash) February 20, 2025 07:09
auto-merge was automatically disabled February 20, 2025 07:14

Head branch was pushed to by a user without write access

@Chen-XiaoBing (Contributor, Author)

@simon-mo Some checks in the CI pipeline have failed. Would you kindly assist with merging the code?

Wei-Lin-Intel added a commit to yangulei/vllm-fork that referenced this pull request Feb 20, 2025
Wei-Lin-Intel added a commit to yangulei/vllm-fork that referenced this pull request Feb 20, 2025
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Feb 20, 2025
jikunshang added a commit to jikunshang/vllm that referenced this pull request Feb 20, 2025
Signed-off-by: Kunshang Ji <[email protected]>
@simon-mo simon-mo merged commit ed6e907 into vllm-project:main Feb 20, 2025
41 of 44 checks passed
xuechendi pushed a commit to xuechendi/vllm-fork that referenced this pull request Feb 20, 2025
@Chen-XiaoBing Chen-XiaoBing deleted the fix-dsv3-grouped-topk branch February 20, 2025 16:27
kerthcet pushed a commit to kerthcet/vllm that referenced this pull request Feb 21, 2025