Support Context Parallel for Multi Latent Attention (MLA) #1729
Conversation
Force-pushed from 091edf3 to f290c61
Signed-off-by: Yuzhong Wang <[email protected]>
for more information, see https://pre-commit.ci
[Resolved review threads on transformer_engine/pytorch/attention/dot_product_attention/utils.py and transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py]
Signed-off-by: Yuzhong Wang <[email protected]>
for more information, see https://pre-commit.ci
Hi @cyanguwa, could you help review this PR? We aim to get CP support in MCore v0.13 (code freeze by mid-June).
Does this PR also cover the A100?
).squeeze(0)
# read one half of v_part (THD layout) along the sequence dim, per cu_seqlens_kv_padded
v_part = tex.thd_read_half_tensor(
    v_part.unsqueeze(0), cu_seqlens_kv_padded, 0
).squeeze(0)
could you please answer this question?
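For context on the snippet above: tex.thd_read_half_tensor is a fused Transformer Engine helper. The reference below is only an assumed pure-PyTorch sketch of what it appears to do here (gather the first or second half of every sequence of a THD-layout tensor, given cumulative sequence lengths); the function name thd_read_half_tensor_ref and the exact half-selection convention are assumptions for illustration, not TE's documented contract.

import torch

def thd_read_half_tensor_ref(x: torch.Tensor, cu_seqlens: torch.Tensor, half_idx: int) -> torch.Tensor:
    """Assumed reference semantics of tex.thd_read_half_tensor.

    x          -- THD-layout tensor, shape [total_tokens, ...]
    cu_seqlens -- cumulative (padded) sequence lengths, shape [batch + 1]
    half_idx   -- 0 selects the first half of each sequence, 1 the second
    """
    halves = []
    for i in range(cu_seqlens.numel() - 1):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        mid = start + (end - start) // 2
        halves.append(x[start:mid] if half_idx == 0 else x[mid:end])
    return torch.cat(halves, dim=0)

# Tiny usage example with two sequences of lengths 4 and 6.
cu_seqlens = torch.tensor([0, 4, 10])
x = torch.arange(10).unsqueeze(-1).float()                     # [total_tokens, hidden=1]
print(thd_read_half_tensor_ref(x, cu_seqlens, 0).squeeze(-1))  # tokens 0,1 and 4,5,6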
Signed-off-by: Yuzhong Wang <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch L1
Approving because all CP tests passed.
LGTM. Just re-running the B200 test - will merge after it passes. Thanks!
The CI pipeline had some issues with the B200 test, but I ran it locally and it seems fine. Merging!
Description
#1561 already fixed the issue that the function AttnFuncWithCPAndKVP2P does not support MLA (Multi-Latent Attention). Specifically, #1561 pads tensor v to head_dim_qk and converts MLA into normal attention. This PR improves on #1561 by removing the padding and using MLA kernels directly, which reduces the communication and computation overhead.

Many thanks to SuperCB from Xiaohongshu and RandMist from the WeChat team for their contributions.
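To make the contrast concrete, the sketch below uses plain PyTorch SDPA rather than the TE CP code path, and the shapes are illustrative assumptions only: it shows the #1561-style workaround of padding v up to head_dim_qk versus keeping v at its native head_dim_v, which is what allows the smaller v tensor to be computed on and exchanged.

import torch
import torch.nn.functional as F

# Illustrative MLA-style shapes: q/k head dim larger than v head dim.
b, h, s = 2, 4, 128
head_dim_qk, head_dim_v = 192, 128

q = torch.randn(b, h, s, head_dim_qk)
k = torch.randn(b, h, s, head_dim_qk)
v = torch.randn(b, h, s, head_dim_v)

# #1561-style workaround: zero-pad v to head_dim_qk so the attention path
# sees equal head dims, then slice the padded columns off the output.
v_padded = F.pad(v, (0, head_dim_qk - head_dim_v))
out_padded = F.scaled_dot_product_attention(q, k, v_padded)[..., :head_dim_v]

# This PR (conceptually): keep v at head_dim_v and let the kernel accept
# mismatched q/k and v head dims, avoiding the extra padded bytes that
# would otherwise be computed on and communicated.
out_native = F.scaled_dot_product_attention(q, k, v)

torch.testing.assert_close(out_padded, out_native)  # padding only adds zero columns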
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: