Skip to content

[Feature]Add async tensor parallelism using compilation pass #17882

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

cascade812
Copy link
Contributor

@cascade812 cascade812 commented May 9, 2025

This PR adds torch async tp using compilation pass.
It requires below config to run

config = CompilationConfig(
    level=3,
    compile_sizes=[4, 8, 16],
    splitting_ops=[],
)
config.pass_config.enable_async_tp= True

llm = LLM(model="llama/Llama-3.2-1B-Instruct",
          enforce_eager=False,
          tensor_parallel_size=2,
          dtype=torch.float16,
          compilation_config=config)

If use vllm serve, add -O '{"level":3, "compile_sizes": [4, 8, 16], "pass_config": {"enable_async_tp": true}}'

Some benchmark results on 2 GPUs of A100.

model = unsloth/Meta-Llama-3.1-8B-Instruct
tp_size = 2
batch_size = 1
input_len=2048
output_len=1

Latency is slightly higher when enable async tp with input len is 2048.

python benchmarks/benchmark_latency.py --model unsloth/Meta-Llama-3.1-8B-Instruct --output-len 1 --input-len 2048 --batch-size 1 --tensor-parallel-size 2 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [2048], "pass_config": {"enable_async_tp": true}}' --no-enable-prefix-caching
Avg latency: 0.19084931214650472 seconds
10% percentile latency: 0.18976202309131623 seconds
25% percentile latency: 0.1900753453373909 seconds
50% percentile latency: 0.1905977502465248 seconds
75% percentile latency: 0.19149887561798096 seconds
90% percentile latency: 0.19235587865114212 seconds
99% percentile latency: 0.19316918954253195 seconds
python benchmarks/benchmark_latency.py --model unsloth/Meta-Llama-3.1-8B-Instruct --output-len 1 --input-len 2048 --batch-size 1 --tensor-parallel-size 2 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [2048], "pass_config": {"enable_async_tp": false}}' --no-enable-prefix-caching &> benchmark.log

Avg latency: 0.18629603137572606 seconds
10% percentile latency: 0.18520110249519348 seconds
25% percentile latency: 0.18542765080928802 seconds
50% percentile latency: 0.18612460047006607 seconds
75% percentile latency: 0.1871374361217022 seconds
90% percentile latency: 0.18730796426534652 seconds
99% percentile latency: 0.18765324756503104 seconds

Latency is almost the same when enable async tp with input len is 8192.

python benchmarks/benchmark_latency.py --model unsloth/Meta-Llama-3.1-8B-Instruct --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 2 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [8192], "pass_config": {"enable_async_tp": true}}' --no-enable-prefix-caching &> benchmark.log
Avg latency: 0.7484328990181287 seconds
10% percentile latency: 0.7470755681395531 seconds
25% percentile latency: 0.7476142570376396 seconds
50% percentile latency: 0.7484151422977448 seconds
75% percentile latency: 0.7493306994438171 seconds
90% percentile latency: 0.749718876183033 seconds
99% percentile latency: 0.7499414524435997 seconds
python benchmarks/benchmark_latency.py --model unsloth/Meta-Llama-3.1-8B-Instruct --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 2 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [8192], "pass_config": {"enable_async_tp": false}}' --no-enable-prefix-caching

Avg latency: 0.761994085709254 seconds
10% percentile latency: 0.7605385079979896 seconds
25% percentile latency: 0.7606241554021835 seconds
50% percentile latency: 0.7613658607006073 seconds
75% percentile latency: 0.7634274959564209 seconds
90% percentile latency: 0.7639447942376136 seconds
99% percentile latency: 0.7645426994562149 seconds

I think we can test this feature on a more demanding workload, like a 70B model across 4 GPUs.

Copy link

github-actions bot commented May 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link

mergify bot commented May 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cascade812.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 9, 2025
@cascade812 cascade812 changed the title Add torch async tensor parallelism using compilation pass Add async tensor parallelism using compilation pass May 9, 2025
@cascade812 cascade812 changed the title Add async tensor parallelism using compilation pass [Feature]Add async tensor parallelism using compilation pass May 10, 2025
Signed-off-by: cascade812 <[email protected]>
@tlrmchlsmth
Copy link
Collaborator

This is what I'm seeing on a 4xH200 system:

vLLM main:

python benchmarks/benchmark_latency.py --model meta-llama/Llama-3.1-70B-Instruct --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 4 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [8192]}' --no-enable-prefix-caching

Avg latency: 0.5901695430278778 seconds
10% percentile latency: 0.5880581809207797 seconds
25% percentile latency: 0.5890177926048636 seconds
50% percentile latency: 0.5897573744878173 seconds
75% percentile latency: 0.5918027735315263 seconds
90% percentile latency: 0.5927275052294135 seconds
99% percentile latency: 0.5936225369013846 seconds

This PR:

python benchmarks/benchmark_latency.py --model meta-llama/Llama-3.1-70B-Instruct --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 4 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [8192], "pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true}}' --no-enable-prefix-caching

Avg latency: 0.5260226292535662 seconds
10% percentile latency: 0.5204391019418836 seconds
25% percentile latency: 0.5236838199198246 seconds
50% percentile latency: 0.5270518623292446 seconds
75% percentile latency: 0.5288312714546919 seconds
90% percentile latency: 0.5301465425640345 seconds
99% percentile latency: 0.5310012844949961 seconds

@cascade812
Copy link
Contributor Author

cascade812 commented May 13, 2025

It also shows ~10% latency reduce on 4 X A100 (40GB) for 8B LLM.

python benchmarks/benchmark_latency.py --model unsloth/Meta-Llama-3.1-8B-Instruct --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 4 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [8192]' --no-enable-prefix-caching

Avg latency: 0.19637128477916121 seconds
10% percentile latency: 0.19597571501508354 seconds
25% percentile latency: 0.19614936946891248 seconds
50% percentile latency: 0.1963466382585466 seconds
75% percentile latency: 0.19650844135321677 seconds
90% percentile latency: 0.19685424230992793 seconds
99% percentile latency: 0.19704130258411168 seconds

With async tp enabled

python benchmarks/benchmark_latency.py --model unsloth/Meta-Llama-3.1-8B-Instruct --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 4 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [8192], "pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true}}' --no-enable-prefix-caching

Avg latency: 0.17523012173672517 seconds
10% percentile latency: 0.17482020873576404 seconds
25% percentile latency: 0.17494881455786526 seconds
50% percentile latency: 0.1752709816209972 seconds
75% percentile latency: 0.1755139627493918 seconds
90% percentile latency: 0.1756236758083105 seconds
99% percentile latency: 0.17574628297239542 seconds

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
@JaheimLee
Copy link

Hi! Is it necessary to always set sequence_parallelism to true?

Copy link
Contributor

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great PR! Concise and effective. I only had a few cleanup comments.

@cascade812
Copy link
Contributor Author

Hi! Is it necessary to always set sequence_parallelism to true?

No need. sequence parallelism is enabled by default if enable_async_tp is true.

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Copy link

mergify bot commented May 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cascade812.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label May 17, 2025
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice work

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label May 18, 2025
Signed-off-by: cascade812 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants