
[TRTLLM-5835][feat] Optimized Mamba2Mixer prefill #5128


Open
wants to merge 71 commits into base: main

Conversation

tomeras91 (Collaborator) commented on Jun 11, 2025

Description

Currently on main, the mamba2 block forward pass performs several dynamic memory allocations and host-to-device copies that significantly hurt its performance. This PR improves performance by minimizing these memory operations:

  1. Pre-allocate the state_indices tensor on device instead of moving it to device during the forward pass.
  2. Introduce a Mamba2Metadata class holding the cu_seqlens and seq_idx tensors needed for varlen batched prefill of the SSM op. Compute them from attn_metadata once at the start of the model forward pass instead of recomputing them inside the Mamba2Mixer block of every mamba layer, so these tensors are created once rather than once per layer (see the sketch after this list).
  3. Remove the redundant for-loop over request type (prefill/decode) in the Mamba2Mixer forward pass. Replace it with two if statements that check whether the current batch contains prefills and/or decodes.
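To make change 2 concrete, here is a minimal sketch of how such metadata can be built once per forward pass. The class, method, and argument names (`Mamba2MetadataSketch`, `prepare`, `prefill_seq_lens`) are illustrative assumptions, not the PR's actual code:

```python
from typing import Optional

import torch


class Mamba2MetadataSketch:
    """Illustrative stand-in for the Mamba2Metadata class introduced by this PR.

    cu_seqlens and seq_idx are computed once per model forward pass and shared
    by every Mamba2Mixer layer, instead of being rebuilt (with host-to-device
    copies) inside each layer.
    """

    def __init__(self, max_batch_size: int, device: torch.device):
        # Pre-allocated on device so prepare() only fills in values.
        self.cu_seqlens = torch.zeros(max_batch_size + 1,
                                      dtype=torch.int32,
                                      device=device)
        self.seq_idx: Optional[torch.Tensor] = None

    def prepare(self, prefill_seq_lens: torch.Tensor) -> None:
        """prefill_seq_lens: per-request prompt lengths already resident on
        device (e.g. the prefill slice of attn_metadata.seq_lens_cuda)."""
        num_prefills = prefill_seq_lens.numel()
        # Cumulative sequence lengths mark each request's boundaries inside the
        # packed token dimension, as expected by varlen batched kernels.
        self.cu_seqlens[1:num_prefills + 1] = torch.cumsum(prefill_seq_lens, dim=0)
        # seq_idx[t] = index of the request that packed token t belongs to.
        self.seq_idx = torch.repeat_interleave(
            torch.arange(num_prefills, dtype=torch.int32,
                         device=prefill_seq_lens.device),
            prefill_seq_lens.long())
```

In this sketch the model would call prepare() once at the start of its forward pass and hand the same instance to every mamba layer.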

These changes removed many of the GPU bubbles present in the mamba forward block, as seen in the profiles below, and lead to roughly a 40% latency reduction.

[Profiler timeline screenshots: before and after this PR]

Benchmarks

The modifications in this PR improved Nemotron-H max throughput by ~15% and min prefill latency by ~40% without sacrificing quality:

MMLU results

| Version | MMLU Score |
|---|---|
| Main | 69.98 |
| PR | 69.98 |

Performance benchmarks

Benchmark setting:

  • Used genai-perf
  • Done on a single RTX6000
  • ISL / OSL = 223 / 140 (similar to ShareGPT statistics)
  • max_batch_size = 64
  • Max throughput setting: --request-count 500 --concurrency 500 --warmup-request-count 10
  • Min latency setting: --request-count 10 --concurrency 1 --warmup-request-count 1

 
max throughput results:

| Metric | TRTLLM main | PR | Change [%] |
|---|---|---|---|
| Request Latency [ms] | 30,614.38 | 26,482.98 | -13.5% |
| Time To First Token [ms] | 25,107.02 | 21,563.04 | -14.1% |
| Inter Token Latency [ms] | 40.91 | 36.97 | -9.6% |
| Output Token Throughput [tokens/sec] | 1,235.88 | 1,400.71 | +13.3% |
| Output Token Throughput Per User [tokens/sec/user] | 24.84 | 27.68 | +11.4% |
| Request Throughput [req/sec] | 9.14 | 10.43 | +14.1% |

 
min latency results:

| Metric | TRTLLM main | PR | Change [%] |
|---|---|---|---|
| Request Latency [ms] | 3,283.65 | 2,362.60 | -28.0% |
| Time To First Token [ms] | 82.55 | 50.54 | -38.8% |
| Inter Token Latency [ms] | 24.58 | 17.77 | -27.7% |
| Output Token Throughput [tokens/sec] | 39.89 | 55.44 | +39.0% |
| Output Token Throughput Per User [tokens/sec/user] | 40.69 | 56.28 | +38.3% |
| Request Throughput [req/sec] | 0.30 | 0.42 | +40.0% |

tomeras91 added 30 commits May 7, 2025 20:13
…b/causal-conv1d

…ze in RMSNorm

…kernels (similar to the tests in tests/unittest/_torch/thop/test_mamba_conv1d_op.py)

…nels for better numerical stability and support for initial states + varlen batching (AKA continuous batching)

… prefill and decode kernels (similar to the tests in tests/unittest/_torch/thop/test_selective_scan_op.py)

…tiple tensors and torch.cat

… Results in +25% throughput

(1) call convolution and SSM explicitly so no need for special call to get conv states
(2) same dtype for conv and ssm states
(3) remove unused code - causal_conv1d_varlen_states, mamba_split_conv1d_scan_combined

…ause instead of duplicating code

…ng. conv weights are already in correct shape

…RTLLM in-house mamba_conv1d kernel

…. Use standard TRTLLM types and macros when needed

tomeras91 added 10 commits June 5, 2025 16:33
…-LLM into fix-nemotron-h-warmup

…de Nemotron-H forward pass.

This is done instead of preparing cu_seqlens and seq_idx in MambaCacheManager, for better code separation, and also because attn_metadata is not yet updated in MambaCacheManager.prepare_resources() and we need it to create cu_seqlens efficiently. It also mirrors the regular attn_metadata flow: the metadata is created if needed and prepared before the forward pass. The difference is that for regular attention this happens in PyTorchModelEngine, while for mamba we do it inside the model forward, since hybrid models are still a special case and we want to isolate the relevant code.

(they appear first in attn_metadata.seq_lens_cuda)

…Manager as self.cu_seqlens and self.seq_idx don't exist anymore

@tomeras91 tomeras91 requested review from a team as code owners June 11, 2025 12:35
@tomeras91 tomeras91 requested review from HuiGao-NV, symphonylyh and Copilot and removed request for symphonylyh and HuiGao-NV June 11, 2025 12:35
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR optimizes the Mamba2Mixer prefill performance by reducing dynamic memory allocations and host-to-device copies. Key changes include:

  • Pre-allocating state indices on the correct device in the resource manager.
  • Introducing a new Mamba2Metadata class to pre-compute and store metadata for varlen batched prefill.
  • Refactoring the Mamba2Mixer forward pass to separate prefill and decode logic and eliminate a redundant loop.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

| File | Description |
|---|---|
| tensorrt_llm/_torch/pyexecutor/resource_manager.py | Assigns the correct device for state indices to minimize host-to-device transfers. |
| tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py | Refactors the forward pass to split prefill and decode logic using the new metadata and indices. |
| tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py | Introduces a metadata class to compute and hold cu_seqlens and sequence indices. |
| tensorrt_llm/_torch/models/modeling_nemotron_h.py | Integrates mamba metadata into layer forward passes for optimized processing. |

Comments suppressed due to low confidence (2)

tensorrt_llm/_torch/pyexecutor/resource_manager.py:598

  • Consider adding an inline comment explaining that the device is explicitly set using self.ssm_states.device to ensure correct GPU allocation, and verify that self.ssm_states is initialized prior to this call.
self.state_indices = torch.as_tensor(state_indices,
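As a hedged illustration of this suggestion (the snippet below is not the actual resource_manager.py code; ssm_states stands in for self.ssm_states and the values are made up):

```python
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Stand-in for self.ssm_states: the SSM state buffer already lives on the
# target device, so its .device attribute is the natural allocation target.
ssm_states = torch.zeros(8, 16, device=device)

# Cache slots chosen for the current batch (host-side Python ints).
state_indices = [3, 0, 5]

# Allocating the index tensor directly on ssm_states.device keeps the copy in
# the resource manager, instead of paying a host-to-device transfer inside
# every Mamba2Mixer forward pass.
state_indices_t = torch.as_tensor(state_indices,
                                  dtype=torch.int32,
                                  device=ssm_states.device)
```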

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py:178

  • [nitpick] Consider adding a brief comment to clarify that state_indices is split into prefill and decode subsets based on the computed batch_split_size, which will aid future maintainers in understanding the code logic.
state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size)
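A minimal sketch of the pattern this comment refers to, assuming (as noted elsewhere in this PR) that prefill requests precede decode requests in the batch; only the torch.split line mirrors the quoted code, the rest is illustrative:

```python
import torch

num_prefills, num_decodes = 2, 3
state_indices = torch.arange(num_prefills + num_decodes, dtype=torch.int32)

# Prefills come first, so one split along dim 0 separates the two groups.
batch_split_size = [num_prefills, num_decodes]
state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size)

# The per-request loop over request type is replaced by at most two branches.
if num_prefills > 0:
    pass  # varlen batched prefill path, indexing states with state_indices_p
if num_decodes > 0:
    pass  # single-token decode path, indexing states with state_indices_d
```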

… prepare() is called

tomeras91 changed the title from [TRTLLM-4923][feat] Optimized Mamba2Mixer prefill to [TRTLLM-5835][feat] Optimized Mamba2Mixer prefill on Jun 11, 2025
tomeras91 (Collaborator, Author) commented:

/bot run

@tomeras91 tomeras91 requested a review from Naveassaf June 12, 2025 11:00
tomeras91 (Collaborator, Author) commented:

/bot run

tensorrt-cicd (Collaborator) commented:

PR_Github #8670 [ run ] triggered by Bot

tensorrt-cicd (Collaborator) commented:

PR_Github #8670 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #6288 completed with status: 'SUCCESS'
