
[https://nvbugspro.nvidia.com/bug/5329655] [feat] Pytorch path add spec dec param to attention op #5146


Open

jhaotingc wants to merge 2 commits into main from pytorch_add_spec_dec_param_to_attention
Conversation

@jhaotingc (Collaborator) commented Jun 12, 2025

Description

Llama3-8B + Eagle3, before this PR:

H200:
[0] Prompt: 'Hello, my name is', Generated text: ' Emily and!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
[1] Prompt: 'The president of the United States is', Generated text: ' the!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
[2] Prompt: 'The capital of France is', Generated text: ' a!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
[3] Prompt: 'The future of AI is', Generated text: ' bright!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'

After this PR:

[0] Prompt: 'Hello, my name is', Generated text: ' Emily and I am a 25-year-old freelance writer and editor. I have a passion for storytelling and a knack for crafting compelling narratives. I have been writing for over 10 years, and I have honed my skills through various writing projects, including articles, blog posts, and short stories.\nI'
[1] Prompt: 'The president of the United States is', Generated text: ' the head of state and head of government of the United States. The president serves a four-year term and is limited to two terms. The president is elected through the Electoral College system, where each state is allocated a certain number of electoral votes based on its population. The candidate who receives the majority of the'
[2] Prompt: 'The capital of France is', Generated text: ' a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided'
[3] Prompt: 'The future of AI is', Generated text: ' bright, but it also raises concerns about bias, accountability, and the impact on jobs. Here are some of the key challenges and opportunities that AI will face in the coming years.\nThe future of AI is bright, but it also raises concerns about bias, accountability, and the impact on jobs. Here are some'

trtllm-bench for Llama3-8B + Eagle3 is not runnable either before or after this PR.

trtllm-bench for Llama3-70B + Eagle3 (acceptance rate is fixed; conc=4, aa dataset):

H100
===========================================================
= DECODING STATISTICS (Eagle)
===========================================================

-- Acceptance Rate Details --------------------------------


[AR] MINIMUM: 1.47
[AR] MAXIMUM: 2.06
[AR] AVERAGE: 1.82
[AR] P50    : 1.87
[AR] P90    : 2.03
[AR] P95    : 2.06
[AR] P99    : 2.06

B200
===========================================================
= DECODING STATISTICS (Eagle)
===========================================================

-- Acceptance Rate Details --------------------------------


[AR] MINIMUM: 1.49
[AR] MAXIMUM: 2.08
[AR] AVERAGE: 1.80
[AR] P50    : 1.86
[AR] P90    : 2.04
[AR] P95    : 2.08
[AR] P99    : 2.08
===========================================================

Llama4 + Eagle3:
(bs=1)

H200
-- Acceptance Rate Details --------------------------------

[AR] MINIMUM: 2.64
[AR] MAXIMUM: 3.15
[AR] AVERAGE: 2.94
[AR] P50    : 3.04
[AR] P90    : 3.15
[AR] P95    : 3.15
[AR] P99    : 3.15


B200
-- Acceptance Rate Details --------------------------------


[AR] MINIMUM: 2.31
[AR] MAXIMUM: 3.04
[AR] AVERAGE: 2.59
[AR] P50    : 2.55
[AR] P90    : 3.04
[AR] P95    : 3.04
[AR] P99    : 3.04

(conc=4)

H200
-- Acceptance Rate Details --------------------------------


[AR] MINIMUM: 1.74
[AR] MAXIMUM: 3.87
[AR] AVERAGE: 2.14
[AR] P50    : 1.89
[AR] P90    : 3.34
[AR] P95    : 3.87
[AR] P99    : 3.87

B200
-- Acceptance Rate Details --------------------------------


[AR] MINIMUM: 2.32
[AR] MAXIMUM: 3.98
[AR] AVERAGE: 3.25
[AR] P50    : 3.47
[AR] P90    : 3.98
[AR] P95    : 3.98
[AR] P99    : 3.98

I'm still not sure whether Llama4 + Eagle3 on Hopper is supposed to show this slight drop in acceptance length; it could be a precision error, or the XQA kernel may handle the NoPE layers differently.
There is also a potential perf drop when copying/computing the spec-dec mask, due to the CPU-to-GPU copy.
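
If that host-to-device copy shows up in profiles, one generic mitigation (a PyTorch sketch under my own assumptions, not something this PR implements) is to stage the mask in pinned host memory and issue the copy asynchronously:

import torch

# Build the packed spec-dec mask on the host in pinned memory, then copy it
# with non_blocking=True so the transfer can overlap with other GPU work.
packed_mask_host = torch.tensor([1, 3, 7, 15], dtype=torch.int32).pin_memory()
packed_mask_gpu = torch.empty_like(packed_mask_host, device='cuda')
packed_mask_gpu.copy_(packed_mask_host, non_blocking=True)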

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]

Launch build/test pipelines. All previously running jobs will be killed.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".
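
For example, a typical (hypothetical) invocation combining the documented flags:

/bot run --disable-fail-fast --stage-list "A10-1"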

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@jhaotingc jhaotingc requested a review from a team as a code owner June 12, 2025 05:12
enqueue_params.spec_decoding_generation_lengths
= spec_decoding_generation_lengths.value().data_ptr<int32_t>();
enqueue_params.spec_decoding_is_generation_length_variable = true;
enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1;
jhaotingc (Collaborator, Author) commented:

Hi @lowsfer, I have a question about the computation of these masks:

Take draft_len = 4 as an example; the parameters for the spec-dec generation step would be:

spec_decoding_max_generation_length: 5
generation_input_length: 4
spec_decoding_packed_mask (shape: bs * generation_input_length * ceil(generation_input_length/32)): [1,3,7,15]
spec_decoding_position_offset (shape: bs * generation_input_length): [0,1,2,3]

I took this from the previous Eagle implementation of the spec-dec masks. There was a series of kernels computing the mask for the tree-based path, but the PyTorch flow only supports a linear tree.

I adopted this simpler mask computation first; if tree-based spec-dec becomes supported, we can add the spec-dec mask logic to the torch flow. What's your opinion on this? Thanks!
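
For reference, a minimal sketch of how these linear-tree parameters could be built on the host (plain PyTorch; the helper name and exact layout are illustrative assumptions, not the PR's actual code):

import torch

def build_linear_spec_dec_params(batch_size: int, draft_len: int):
    # Sketch only: linear (chain) draft tree, assuming draft_len <= 32 so each
    # packed-mask row fits in a single int32 word.
    assert draft_len <= 32
    gen_len = draft_len  # generation_input_length in the example above

    # Row i lets token i attend to tokens 0..i, i.e. the low (i + 1) bits are set:
    # [1, 3, 7, 15] for draft_len = 4.
    rows = torch.tensor([(1 << (i + 1)) - 1 for i in range(gen_len)], dtype=torch.int32)
    packed_mask = rows.view(1, gen_len, 1).expand(batch_size, gen_len, 1).contiguous()

    # Position offsets [0, 1, 2, 3] broadcast over the batch.
    position_offsets = torch.arange(gen_len, dtype=torch.int32).expand(batch_size, gen_len).contiguous()

    # Constant per-request generation length for a linear tree; the kernel-side
    # maximum would be draft_len + 1 (spec_decoding_max_generation_length = 5 above).
    generation_lengths = torch.full((batch_size,), gen_len, dtype=torch.int32)

    return packed_mask, position_offsets, generation_lengths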

@@ -486,6 +514,10 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:

op->mAttentionChunkSize = attention_chunk_size;

op->mUseSpecDecoding = use_spec_dec; // true
op->mIsSpecDecodingEnabled = spec_decoding_position_offsets.has_value();
Collaborator commented:

What is the difference between mUseSpecDecoding and mIsSpecDecodingEnabled? In which case will they be different?

", bool use_spec_dec"
", Tensor? spec_decoding_position_offsets"
", Tensor? spec_decoding_packed_mask"
", Tensor? spec_decoding_generation_lengths"
Collaborator commented:

We will have 63 arguments after this PR, which is very close to 64. Can we also pack these three spec_decoding_xxx tensors into a List[Tensor]? spec_decoding_tensors argument, so that other people will not hit the limit immediately?
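
A rough sketch of the suggested packing on the Python side (the helper names are hypothetical; per the later comment, this would live behind the attention op in torch_custom_ops.py):

from typing import List, Optional
import torch

def pack_spec_decoding_tensors(
        position_offsets: Optional[torch.Tensor],
        packed_mask: Optional[torch.Tensor],
        generation_lengths: Optional[torch.Tensor]) -> Optional[List[torch.Tensor]]:
    # Collapse the three optional spec-dec tensors into one optional list
    # argument, keeping the op schema below the 64-argument limit mentioned above.
    if position_offsets is None:
        return None
    return [position_offsets, packed_mask, generation_lengths]

def unpack_spec_decoding_tensors(spec_decoding_tensors: Optional[List[torch.Tensor]]):
    # Inverse helper on the op side, before dispatching to the C++ attention op.
    if spec_decoding_tensors is None:
        return None, None, None
    position_offsets, packed_mask, generation_lengths = spec_decoding_tensors
    return position_offsets, packed_mask, generation_lengths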

self.rotary_embedding_long_m_scale
],
[
self.rotary_embedding_max_positions,
Collaborator commented:

Please don't pack arguments in the trtllm backend; let's hide it inside the attention op defined in torch_custom_ops.py. As I commented in the discussion, we can keep the attention op API unchanged when we switch to the packed solution.

attention_chunk_size: Optional[int], use_spec_dec: bool,
spec_decoding_position_offsets: Optional[torch.Tensor],
spec_decoding_packed_mask: Optional[torch.Tensor],
spec_decoding_generation_lengths: Optional[torch.Tensor]
Collaborator commented:

Please add a trailing , after the last argument so that the formatting style stays unchanged.

attention_chunk_size: Optional[int], use_spec_dec: bool,
spec_decoding_position_offsets: Optional[torch.Tensor],
spec_decoding_packed_mask: Optional[torch.Tensor],
spec_decoding_generation_lengths: Optional[torch.Tensor]
Collaborator commented:

Please add a trailing , after the last argument so that the formatting style stays unchanged.

@@ -687,6 +688,7 @@ def forward(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
Collaborator commented:

Can we make spec_metadata part of attn_metadata so that the modeling file can stay unchanged, given that spec_metadata is only used in the attention module and all layers share the same spec_metadata?
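
A minimal sketch of that shape (class and field names are illustrative placeholders, not the project's actual definitions):

from dataclasses import dataclass
from typing import Optional

@dataclass
class SpecMetadataSketch:
    # Placeholder for the real spec-dec metadata (use_spec_dec flag, packed mask, etc.).
    use_spec_dec: bool = False

@dataclass
class AttentionMetadataSketch:
    # Carrying spec-dec metadata inside the attention metadata means modeling
    # code only passes attn_metadata; backends that do not support spec-dec
    # simply ignore this field.
    spec_metadata: Optional[SpecMetadataSketch] = None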

use_spec_dec = (
spec_metadata.use_spec_dec
and spec_metadata.spec_dec_mode.require_multi_query_attn_kernel(
get_attention_backend(
Collaborator commented:

If spec_metadata is part of attn_metadata, we don't need to set use_spec_dec according to the attention backend. Only the trtllm backend will handle spec_metadata; all other backends can ignore it.

@jhaotingc jhaotingc force-pushed the pytorch_add_spec_dec_param_to_attention branch 3 times, most recently from 868c22e to f279591 Compare June 13, 2025 03:12
@jhaotingc jhaotingc force-pushed the pytorch_add_spec_dec_param_to_attention branch from 6177d22 to eef3be2 Compare June 13, 2025 03:18
@@ -182,6 +186,25 @@ def __post_init__(self):
dtype=torch.int,
device='cuda',
)
self.spec_decoding_position_offsets = torch.empty(
Collaborator commented:

Can we skip preparing these tensors if we don't need them, e.g., on Blackwell GPUs?
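
One possible shape for that, as a standalone sketch (the flag, names, and shapes here are assumptions, not the PR's implementation):

import torch

def maybe_allocate_spec_dec_buffers(max_num_requests: int, max_gen_len: int,
                                    backend_needs_spec_dec: bool):
    # Skip allocating the spec-dec buffers entirely when the selected attention
    # backend does not consume them (e.g., a backend with native spec-dec support).
    if not backend_needs_spec_dec:
        return None, None, None
    position_offsets = torch.empty((max_num_requests, max_gen_len),
                                   dtype=torch.int, device='cuda')
    packed_mask = torch.empty((max_num_requests, max_gen_len, (max_gen_len + 31) // 32),
                              dtype=torch.int, device='cuda')
    generation_lengths = torch.empty((max_num_requests,),
                                     dtype=torch.int, device='cuda')
    return position_offsets, packed_mask, generation_lengths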
