[https://nvbugspro.nvidia.com/bug/5329655] [feat] Pytorch path add spec dec param to attention op #5146
Conversation
Signed-off-by: Jhao-Ting Chen <[email protected]>
enqueue_params.spec_decoding_generation_lengths
    = spec_decoding_generation_lengths.value().data_ptr<int32_t>();
enqueue_params.spec_decoding_is_generation_length_variable = true;
enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1;
Hi @lowsfer, I have a question about the computation of these masks.
Take draft_len = 4 as an example; the parameters for the spec-dec generation step would be:
spec_decoding_max_generation_length: 5
generation_input_length: 4
spec_decoding_packed_mask (shape: bs * generation_input_length * ceil(generation_input_length/32)): [1, 3, 7, 15]
spec_decoding_position_offsets (shape: bs * generation_input_length): [0, 1, 2, 3]
I took these values from the previous Eagle implementation of the spec-dec masks. There was a series of kernels computing the mask for the tree-based path, but the PyTorch flow only supports a linear tree.
I adopted this simpler mask computation first; if tree-based spec-dec is supported later, we can add the spec-dec mask logic to the torch flow. What's your opinion on this? Thanks!
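For reference, here is a minimal sketch of how these values can be produced for a linear draft chain. It is not code from this PR: the helper name is illustrative, and it assumes at most 31 draft tokens so each mask row fits in a single packed int32 word.

```python
import torch


def linear_tree_spec_dec_params(batch_size: int, draft_len: int, device: str = "cuda"):
    """Illustrative helper (not from the PR): spec-dec params for a linear draft chain."""
    assert 0 < draft_len <= 31, "sketch assumes a single packed int32 word per row"
    # Row i lets draft token i attend to draft tokens 0..i, i.e. a lower-triangular
    # causal mask packed into bits: [0b1, 0b11, 0b111, 0b1111] == [1, 3, 7, 15].
    packed_rows = torch.tensor([(1 << (i + 1)) - 1 for i in range(draft_len)],
                               dtype=torch.int32)
    packed_mask = packed_rows.view(1, draft_len, 1).expand(batch_size, -1, -1).contiguous().to(device)
    # Position offsets are simply 0..draft_len-1 for a linear chain.
    position_offsets = torch.arange(draft_len, dtype=torch.int32).view(1, -1) \
        .expand(batch_size, -1).contiguous().to(device)
    # Every request in the batch proposes the same number of draft tokens.
    generation_lengths = torch.full((batch_size,), draft_len, dtype=torch.int32, device=device)
    return packed_mask, position_offsets, generation_lengths
```

For draft_len = 4 this reproduces the values quoted above: packed-mask rows [1, 3, 7, 15] and position offsets [0, 1, 2, 3].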
@@ -486,6 +514,10 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
op->mAttentionChunkSize = attention_chunk_size;
op->mUseSpecDecoding = use_spec_dec; // true
op->mIsSpecDecodingEnabled = spec_decoding_position_offsets.has_value();
What is the difference between mUseSpecDecoding and mIsSpecDecodingEnabled? In which case will they be different?
", bool use_spec_dec"
", Tensor? spec_decoding_position_offsets"
", Tensor? spec_decoding_packed_mask"
", Tensor? spec_decoding_generation_lengths"
We will have 63 arguments after this PR, which is very close to 64. Can we also pack these three spec_decoding_xxx tensors into List[Tensor]? spec_decoding_tensors, so that other people will not reach the limit immediately?
    self.rotary_embedding_long_m_scale
],
[
    self.rotary_embedding_max_positions,
Please don't pack arguments in the trtllm backend; let's hide it inside the attention op defined in torch_custom_ops.py. As I commented in the discussion, we can keep the attention op API unchanged when we switch to the packed solution.
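For illustration, a sketch of that approach (the function and helper names below are assumptions, not the actual torch_custom_ops.py code): the Python-level attention wrapper keeps the three optional spec-dec tensors in its public signature and packs them into a single list only when calling the underlying custom op.

```python
from typing import List, Optional

import torch


def _packed_attention_op(q: torch.Tensor,
                         spec_decoding_tensors: Optional[List[torch.Tensor]]) -> torch.Tensor:
    # Placeholder standing in for the packed custom op; the real op would unpack
    # the list (position offsets, packed mask, generation lengths) on the C++ side.
    return q


def attention(q: torch.Tensor,
              *,
              spec_decoding_position_offsets: Optional[torch.Tensor] = None,
              spec_decoding_packed_mask: Optional[torch.Tensor] = None,
              spec_decoding_generation_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Public wrapper API stays unchanged; the packing is hidden inside the wrapper,
    # so callers never see the list-based custom-op signature.
    spec_decoding_tensors: Optional[List[torch.Tensor]] = None
    if spec_decoding_position_offsets is not None:
        spec_decoding_tensors = [
            spec_decoding_position_offsets,
            spec_decoding_packed_mask,
            spec_decoding_generation_lengths,
        ]
    return _packed_attention_op(q, spec_decoding_tensors)
```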
attention_chunk_size: Optional[int], use_spec_dec: bool,
spec_decoding_position_offsets: Optional[torch.Tensor],
spec_decoding_packed_mask: Optional[torch.Tensor],
spec_decoding_generation_lengths: Optional[torch.Tensor]
Please add a , after the last argument so that the formatting stays unchanged.
attention_chunk_size: Optional[int], use_spec_dec: bool,
spec_decoding_position_offsets: Optional[torch.Tensor],
spec_decoding_packed_mask: Optional[torch.Tensor],
spec_decoding_generation_lengths: Optional[torch.Tensor]
Please add a , after the last argument so that the formatting stays unchanged.
@@ -687,6 +688,7 @@ def forward(
    position_ids=position_ids,
    hidden_states=hidden_states,
    attn_metadata=attn_metadata,
    spec_metadata=spec_metadata,
Can we make spec_metadata part of attn_metadata, so that the modeling file can stay unchanged, given that spec_metadata is only used in the attention module and all layers share the same spec_metadata?
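Purely as an illustration (class and field names here are assumptions, not the repository's actual definitions), that could look like an optional spec-dec field carried on the attention metadata:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SpecMetadata:
    # Illustrative spec-dec inputs shared by all layers.
    use_spec_dec: bool = False
    spec_decoding_position_offsets: Optional[torch.Tensor] = None
    spec_decoding_packed_mask: Optional[torch.Tensor] = None
    spec_decoding_generation_lengths: Optional[torch.Tensor] = None


@dataclass
class AttentionMetadata:
    # ...existing attention metadata fields would live here...
    # Optional spec-dec metadata; None for backends that never use it.
    spec_metadata: Optional[SpecMetadata] = None
```

With this layout, the model-level forward() signatures would not need the extra spec_metadata= argument; only the attention module would read attn_metadata.spec_metadata.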
use_spec_dec = (
    spec_metadata.use_spec_dec
    and spec_metadata.spec_dec_mode.require_multi_query_attn_kernel(
        get_attention_backend(
If spec_metadata is part of attn_metadata, we don't need to set use_spec_dec according to the attention backend. Only the trtllm backend will handle spec_metadata; all other backends can ignore it.
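A sketch of how the call site could then simplify (attribute names follow the illustration above and are assumptions):

```python
def resolve_use_spec_dec(attn_metadata) -> bool:
    # Non-trtllm backends never populate spec_metadata, so this falls through
    # to False without any backend-specific branching at the call site.
    spec = getattr(attn_metadata, "spec_metadata", None)
    return spec is not None and bool(getattr(spec, "use_spec_dec", False))
```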
@@ -182,6 +186,25 @@ def __post_init__(self):
        dtype=torch.int,
        device='cuda',
    )
    self.spec_decoding_position_offsets = torch.empty(
Can we skip preparing these tensors if we don't need them, e.g., on Blackwell GPUs?
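One possible shape for that, as a sketch only (the flag name and buffer shapes are assumptions): allocate the spec-dec buffers lazily, guarded by whether the selected kernel path actually consumes them.

```python
import torch


class AttentionMetadataSketch:
    """Illustrative subset of the __post_init__-style buffer setup."""

    def __init__(self, max_num_requests: int, max_draft_len: int,
                 needs_spec_dec_buffers: bool):
        self.spec_decoding_position_offsets = None
        self.spec_decoding_packed_mask = None
        self.spec_decoding_generation_lengths = None
        if needs_spec_dec_buffers:
            # Only pay the allocations (and later the host-to-device copies)
            # when the multi-query spec-dec kernel path is actually used.
            self.spec_decoding_position_offsets = torch.empty(
                (max_num_requests, max_draft_len + 1), dtype=torch.int, device='cuda')
            self.spec_decoding_packed_mask = torch.empty(
                (max_num_requests, max_draft_len + 1, 1), dtype=torch.int, device='cuda')
            self.spec_decoding_generation_lengths = torch.empty(
                (max_num_requests,), dtype=torch.int, device='cuda')
```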
Description
Llama3 8B + Eagle3, before this PR:
After this PR:
trtllm-bench for Llama3 8B + Eagle3 is not runnable both before and after this PR.
trtllm-bench for Llama3 70B + Eagle3 (acceptance rate is fixed):
(conc=4, aa dataset)
Llama4 + Eagle3:
(bs1)
(conc4)
I'm still not sure whether Llama4 + Eagle3 on Hopper is supposed to show this slight drop in acceptance length; it may be precision error, or the XQA kernel may handle the NoPE layers differently.
There is a potential perf drop when copying/computing the spec-dec mask, due to the CPU-to-GPU copy.
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]
Launch build/test pipelines. All previously running jobs will be killed.
--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.
--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

kill
Kill all running builds associated with pull request.

skip --comment COMMENT
Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.