
Potential speedup for Mixtral to avoid forward pass of MOE layer if expert is not selected #2459

Closed
JasonZhu1313 opened this issue Jan 17, 2024 · 1 comment

Comments

@JasonZhu1313
Contributor

In Mixtral, we are doing the following in the MixtralMoE layer:


for expert_idx in self.expert_indicies:
    expert_layer = self.experts[expert_idx]
    expert_mask = (selected_experts == expert_idx)
    expert_weights = (routing_weights * expert_mask).sum(dim=-1,
                                                         keepdim=True)

    current_hidden_states = expert_layer(hidden_states).mul_(
        expert_weights)
    if final_hidden_states is None:
        final_hidden_states = current_hidden_states
    else:
        final_hidden_states.add_(current_hidden_states)

There might be some inefficiency here: even if the expert owned by a rank is not selected for a token, that expert still runs a full forward pass on hidden_states, and the output is only zeroed out afterwards because expert_weights is 0 for unselected tokens. With, say, a 10% selection rate for an expert, 90% of its expert_weights are 0, so it might be wise to use a sparse-dense matmul (or to skip the unrouted tokens entirely).
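
Roughly what I have in mind (a sketch only, reusing the names from the snippet above; the gather is data-dependent, so it would hit the same CUDA graph limitation described below) is to gather only the tokens routed to each expert, run the expert on that subset, and scatter-add the results back:

final_hidden_states = torch.zeros_like(hidden_states)
for expert_idx in self.expert_indicies:
    expert_layer = self.experts[expert_idx]
    # (token, top-k slot) pairs routed to this expert
    token_idx, topk_idx = torch.where(selected_experts == expert_idx)
    if token_idx.numel() == 0:
        # No token picked this expert on this rank: skip the forward pass entirely.
        continue
    expert_weights = routing_weights[token_idx, topk_idx].unsqueeze(-1)
    # Run the expert only on its routed tokens and scatter-add the result back.
    expert_out = expert_layer(hidden_states[token_idx]).mul_(expert_weights)
    final_hidden_states.index_add_(0, token_idx, expert_out)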

In the most extreme case, where the expert is not selected by any token in the inference batch, we should simply bypass the expert and return a zero tensor of shape [batch * seq_len, hidden_dim]. I was trying to achieve that with:


if expert_weights.any():
    current_hidden_states = expert_layer(hidden_states).mul_(
        expert_weights)

But I got:

 File "/home/jobuser/vllm/model_executor/models/mixtral.py", line 88, in forward
    if torch.any(expert_weights):
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

ray::RayWorkerVllm.execute_method() (pid=12372, ip=100.100.173.140, actor_id=fcb81e4892ae707c8e6c8eac01000000, repr=<vllm.engine.ray_utils.RayWorkerVllm object at 0x7f7248338400>)
  File "/home/jobuser/vllm/engine/ray_utils.py", line 31, in execute_method
    return executor(*args, **kwargs)
  File "/home/jobuser/vllm/worker/worker.py", line 123, in warm_up_model
    self.model_runner.capture_model(self.gpu_cache)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jobuser/vllm/worker/model_runner.py", line 436, in capture_model
    graph_runner.capture(
  File "/home/jobuser/vllm/worker/model_runner.py", line 482, in capture
    with torch.cuda.graph(self.graph, pool=memory_pool):
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/cuda/graphs.py", line 197, in __exit__
    self.cuda_graph.capture_end()
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/cuda/graphs.py", line 88, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

It looks like CUDA graph capture is not compatible with a data-dependent conditional like this: the Python if has to pull the boolean result of expert_weights.any() back to the host, and that synchronization is not permitted while the stream is capturing.
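
A standalone sketch of the mechanism as I understand it (not vLLM code, just illustrating the restriction):

import torch

weights = torch.rand(16, 1, device="cuda")
flag = weights.any()       # result stays on the GPU as a 0-dim bool tensor
take_branch = bool(flag)   # implicit .item(): a device-to-host sync, which is the
                           # operation that is rejected if it happens inside
                           # torch.cuda.graph(...) capture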

@cadedaniel
Collaborator

cadedaniel commented Jan 17, 2024

Good catch @JasonZhu1313. This PR uses the same idea (as well as sharding the expert weights over a higher TP degree): #2293

For the CUDA graph error, you can enforce eager mode to test it out.
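
For example, with the offline LLM entrypoint (a minimal sketch; the model name and tensor_parallel_size are just examples):

from vllm import LLM, SamplingParams

# enforce_eager=True disables CUDA graph capture, so Python-level checks like
# expert_weights.any() in the model code can be experimented with.
llm = LLM(model="mistralai/Mixtral-8x7B-v0.1",
          tensor_parallel_size=8,
          enforce_eager=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))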

@hmellor closed this as completed on Apr 4, 2024