[TRTLLM-5835][feat] Optimized Mamba2Mixer prefill #5128
base: main
Conversation
Commits
- …b/causal-conv1d
- …ze in RMSNorm
- …kernels (similar to the tests in tests/unittest/_torch/thop/test_mamba_conv1d_op.py)
- …nels for better numerical stability and support for initial states + varlen batching (AKA continuous batching)
- … prefill and decode kernels (similar to the tests in tests/unittest/_torch/thop/test_selective_scan_op.py)
- …round)
- …tiple tensors and torch.cat
- … Results in +25% throughput: (1) call convolution and SSM explicitly so no need for a special call to get conv states, (2) same dtype for conv and ssm states, (3) remove unused code - causal_conv1d_varlen_states, mamba_split_conv1d_scan_combined
- …ause instead of duplicating code
- …g forward pass
- …ng. conv weights are already in correct shape
- …RTLLM in-house mamba_conv1d kernel
- …. Use standard TRTLLM types and macros when needed
- …-LLM into fix-nemotron-h-warmup
- …de Nemotron-H forward pass. This is done instead of preparing cu_seqlens and seq_idx in MambaCacheManager, for better code separation, and also because in MambaCacheManager.prepare_resources() attn_metadata is not updated yet and we need it to create cu_seqlens efficiently. This also makes the flow similar to the regular attn_metadata flow: create it if needed and prepare it before the forward pass. The difference is that for regular attention this is done in PyTorchModelEngine, while for mamba we do it inside the model forward, since hybrid models are still a special case and we want to isolate the relevant code.
- (they appear first in attn_metadata.seq_lens_cuda)
- …Manager, as self.cu_seqlens and self.seq_idx don't exist anymore
Pull Request Overview
This PR optimizes the Mamba2Mixer prefill performance by reducing dynamic memory allocations and host-to-device copies. Key changes include:
- Pre-allocating state indices on the correct device in the resource manager.
- Introducing a new Mamba2Metadata class to pre-compute and store metadata for varlen batched prefill.
- Refactoring the Mamba2Mixer forward pass to separate prefill and decode logic and eliminate a redundant loop.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tensorrt_llm/_torch/pyexecutor/resource_manager.py | Assigns the correct device for state indices to minimize host-to-device transfers. |
| tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py | Refactors forward pass to split prefill and decode logic using new metadata and indices. |
| tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py | Introduces a metadata class to compute and hold cu_seqlens and sequence indices. |
| tensorrt_llm/_torch/models/modeling_nemotron_h.py | Integrates mamba metadata into layer forward passes for optimized processing. |
Comments suppressed due to low confidence (2)
tensorrt_llm/_torch/pyexecutor/resource_manager.py:598
- Consider adding an inline comment explaining that the device is explicitly set using self.ssm_states.device to ensure correct GPU allocation, and verify that self.ssm_states is initialized prior to this call.
self.state_indices = torch.as_tensor(state_indices,
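A minimal sketch of what the suggested comment would document, assuming the surrounding MambaCacheManager code; the helper name and dtype here are illustrative, not the PR's exact code:

```python
import torch

def build_state_indices(state_indices: list, ssm_states: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for the reviewed line: the index tensor is created
    # directly on the device that already holds the SSM state cache, so the
    # forward pass does not pay for an extra host-to-device copy later.
    # ssm_states is assumed to be allocated on the target GPU beforehand.
    return torch.as_tensor(state_indices, dtype=torch.int32, device=ssm_states.device)
```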
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py:178
- [nitpick] Consider adding a brief comment to clarify that state_indices is split into prefill and decode subsets based on the computed batch_split_size, which will aid future maintainers in understanding the code logic.
state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size)
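A brief sketch of the logic such a comment would describe, under the assumption that prefill requests are ordered before decode requests in the batch (as the commit about attn_metadata.seq_lens_cuda notes); num_prefills and num_decodes are illustrative names, not the PR's exact variables:

```python
import torch

def split_state_indices(state_indices: torch.Tensor, num_prefills: int, num_decodes: int):
    # state_indices holds one cache-slot index per request, with prefills first.
    # Splitting at the prefill/decode boundary yields the indices used by the
    # varlen prefill kernels and the single-token decode kernels respectively.
    batch_split_size = [num_prefills, num_decodes]
    state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size)
    return state_indices_p, state_indices_d
```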
… prepare() is called (Co-authored-by: Copilot)
/bot run
/bot run
PR_Github #8670 [ run ] triggered by Bot
PR_Github #8670 [ run ] completed with state
Description
Currently on main, the mamba2 block forward pass has some dynamic memory allocations and host-to-device copies, greatly hurting its performance. This PR improves performance by minimizing these memory operations:
- Pre-allocate the `state_indices` tensors on device instead of moving them to device during the forward pass.
- New `Mamba2Metadata` class holding the `cu_seqlens` and `seq_idx` tensors needed for varlen batched prefill of the SSM op. Compute them from `attn_metadata` at the start of the model forward pass instead of doing it inside the `Mamba2Mixer` block of each mamba layer. This also means we create these tensors once and not multiple times in each layer (a sketch of this flow is shown below).
- Remove the redundant loop in the `Mamba2Mixer` forward pass. Replace it with 2 if statements, checking whether we have prefills / decodes in the current batch.

These changes removed many of the GPU bubbles present in the mamba forward block, as seen in these profiles. They lead to a 40% latency reduction.
Benchmarks
The modifications in this PR improved Nemotron-H max throughput by ~15% and min prefill latency by ~40% without sacrificing quality:
MMLU results
Performance benchmarks
Benchmark setting:
max throughput results:
min latency results: