Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tensor parallel MOE implementation #2293

Closed
wants to merge 50 commits into from
Closed

Conversation

scv119
Copy link
Contributor

@scv119 scv119 commented Dec 28, 2023

This PR implements tensor parallel MOE by sharding each expert across all ranks.

concretely, it does following:

  1. column parallel each expert's w1 and w3 weights
  2. row parallel each expert's w2 weight
  3. for each batch, it groups the per request's hidden state according to routing decision.
  4. apply all experts mlp using grouped gemm
  5. apply routing weights
  6. all reduce to collect the result across TP ranks
  7. merge the per request result across different experts.

benchmark result:
A100 80G * 8, input_len=32, output_len=128

baseline:
batch_size 1: 2.1385579633448892 seconds
batch_size 8: 2.428515106982862 seconds
batch_size 32: 2.9776507209753618 seconds
batch_size 64: 3.7744668100300864 seconds

this PR
batch_size 1: 1.6442222506545174 seconds (77%)
batch_size 8: 2.3404843776564426 seconds (96%)
batch_size 32: 3.0149446266586892 seconds (101%)
batch_size 64: 3.878694705994955 seconds (103%)

A100 80G * 4, input_len=32, output_len=128

baseline:
batch_size 1: 2.9904473346929685 seconds
batch_size 8: 3.2857296433260976 seconds
batch_size 32: 3.917926660312029 seconds
batch_size 64: 4.401127053638144 seconds

this PR
batch_size 1: 1.6416094843492222 seconds (55%)
batch_size 8: 2.9794040496732728 seconds (91%)
batch_size 32: 3.631852053649103 seconds (93%)
batch_size 64: 4.388253151012274 seconds (100%)

csrc/bincount.cu Outdated Show resolved Hide resolved
@scv119
Copy link
Contributor Author

scv119 commented Dec 29, 2023

running into some weird torch.sort issues during cuda-graph capture...

@WoosukKwon
Copy link
Collaborator

Hi @scv119, thanks for addressing my comments! I haven't actually completed the review yet. Will add more tonight or tmr morning.

@scv119
Copy link
Contributor Author

scv119 commented Jan 4, 2024

@WoosukKwon just let you know the triton grouped matmul returns different result from torch reference implementation for large matrix multiplication, which is likely caused by triton-lang/triton#1190 (comment) but that's purely my speculation.

we might need to use https://github.com/imoneoi/cutlass_grouped_gemm if it matters.

@WoosukKwon WoosukKwon self-requested a review January 16, 2024 21:47
@scv119
Copy link
Contributor Author

scv119 commented Jan 17, 2024

1089dd8
we delayed the allreduce after the weights are merged by indices, which reduces the communications by half.

Comment on lines +177 to +180
grouped_w1_out = grouped_matmul(expanded_hidden_states,
cum_experts_range, w1s, "silu")
grouped_w3_out = grouped_matmul(expanded_hidden_states,
cum_experts_range, w3s)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge w1s and w3 just like what we do for LlamaMLP? Merging the two weights will be highly efficient given the cost of grouped GEMM.

self,
expanded_hidden_states: torch.
Tensor, # [batch_size * top_k_experts, hidden_size]
reverse_indices, # [batch_size * top_k_experts]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
reverse_indices, # [batch_size * top_k_experts]
reverse_indices: torch.Tensor, # [batch_size * top_k_experts]

Comment on lines +63 to +74
set_weight_attrs(self.w1s, {
"weight_loader": self.weight_loader,
"tp_type": "column"
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
"tp_type": "row"
})
set_weight_attrs(self.w3s, {
"weight_loader": self.weight_loader,
"tp_type": "column"
})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we make this compatible with other parallel linear layers by tagging input_dim and output_dim instead of tp_type?

Comment on lines +63 to +74
set_weight_attrs(self.w1s, {
"weight_loader": self.weight_loader,
"tp_type": "column"
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
"tp_type": "row"
})
set_weight_attrs(self.w3s, {
"weight_loader": self.weight_loader,
"tp_type": "column"
})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we make this more similar to other parallel linear layers by tagging input_dim and output_dim instead of tp_type?

Comment on lines +272 to +277
expert_params_mapping = [
# (param_name, weight_name, expert_id)
(f"{weight_name}s", f"experts.{expert_id}.{weight_name}.weight",
expert_id) for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, do we assume that the expert linear layers don't have bias terms?

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @scv119 for the updates! The PR looks good to me overall. For the grouped GEMM, stuff I think we can investigate the Cutlass implementation later. I actually spent some time understanding it last weekend, but found it a bit difficult to understand. For now, I think the Triton kernel is acceptable, and it is actually needed for AMD GPUs anyway.

@chu-tianxiang
Copy link
Contributor

Any insights into how the quantized model will be managed please? There's a challenge regarding the weights: it may not be possible to concatenate them due to differences in experts. For instance, GPTQ might employ distinct activation order and AWQ might use varying scales. Thank you.

linear_method=None)

self.w1s = nn.Parameter(
torch.empty(self.num_total_experts,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are many experts like deepseekmoe, it is easy to oom in this function. Any ideas to improve memory utilization?

@scv119
Copy link
Contributor Author

scv119 commented Jan 18, 2024

thanks @WoosukKwon. will do another pass; also we noticed some poor performance on h100, probably need tune the kernel parameters a bit.

@pcmoritz
Copy link
Collaborator

On H100s, changing the number of SMs to 256 brought the best improvement in terms of throughtput for me (but still not quite matching current master). It was measured with python benchmarks/benchmark_throughput.py --model=mistralai/Mixtral-8x7B-Instruct-v0.1 --input-len 1000 --output-len 50 -tp 8 --num-prompts 1000:

Current master: 28600 tok/s
Current PR: 24900 tok/s
New hyperparameters below: 27600 tok/s

All numbers have an error of about +/- 200 tok/s.

It is quite possible that by tuning more / autotuning we can get even better results here -- I'd love to learn about it if anybody has better parameters :)

diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py
index 6d37884302..94bb9f0858 100644
--- a/vllm/model_executor/layers/moe.py
+++ b/vllm/model_executor/layers/moe.py
@@ -335,15 +335,13 @@ def grouped_matmul(input: torch.Tensor,
     BLOCK_SIZE_M = 16
     BLOCK_SIZE_N = 64
     BLOCK_SIZE_K = 32
-    num_warps = 2
-    NUM_SM = 128
+    num_warps = 4
+    NUM_SM = 256
     num_stages = 5
     # hand tune the block size for different problem sizes.
     if input.shape[0] >= 8:
-        num_warps = 4
         BLOCK_SIZE_N = 128
     if input.shape[0] >= 32:
-        num_warps = 4
         BLOCK_SIZE_M = 32
         BLOCK_SIZE_N = 128
     # we use a fixed number of CTA, and it's auto-tunable

@pcmoritz pcmoritz mentioned this pull request Jan 22, 2024
@scv119
Copy link
Contributor Author

scv119 commented Jan 30, 2024

i think one overhead of this PR is too many small elementwise operations that are not fused according to my profile.
#2453 should be a better version of this one, thus closing here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants