[FlexAttention] Add initial benchmarks #3578

Open
wants to merge 14 commits into main
Conversation

mfrancepillois
Contributor

Add benchmarks to evaluate FlexAttention kernel performance.
Add these benchmarks to the CI workflow (this requires installing a specific PyTorch version with XPU FlexAttention support enabled).

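For context, these benchmarks exercise PyTorch's FlexAttention API. A minimal usage sketch (the shapes are illustrative, and the "xpu" device assumes the XPU-enabled PyTorch build discussed below):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Causal masking expressed as a mask_mod, the core FlexAttention abstraction.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 2, 8, 1024, 64  # illustrative shapes, not the benchmark's grid
device = "xpu"  # assumption: requires a PyTorch build with XPU FlexAttention support

q, k, v = (torch.randn(B, H, S, D, device=device, dtype=torch.float16)
           for _ in range(3))
block_mask = create_block_mask(causal, B, H, S, S, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)
```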
@mfrancepillois mfrancepillois marked this pull request as ready for review March 3, 2025 12:59
@mfrancepillois mfrancepillois linked an issue Mar 3, 2025 that may be closed by this pull request
@mfrancepillois
Contributor Author

@liangan1, this PR adds two FlexAttention-based benchmarks to start monitoring FlexAttention performance. Could you please have a look?

@@ -37,8 +37,12 @@ runs:
   ITEM_PATH="${{ inputs.root }}/${{ inputs.key }}"
   echo "dest=$ITEM_PATH" >> $GITHUB_OUTPUT
   if [[ -d ${{ inputs.path }} ]]; then
-    echo "Directory ${{ inputs.path }} exists and will not be restored from cache"
-    exit 1
+    if [[ ${{ inputs.repository == 'liangan1/pytorch' }} ]]; then
Contributor

I don't think we need it here. Just add `rm -rf pytorch` in the workflow.

Contributor Author

The workflow has been modified to delete the directory if it already exists (regardless of the repository).
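For reference, a minimal sketch of what the updated step might look like (the exact message and cleanup logic in the merged workflow may differ):

```bash
if [[ -d "${{ inputs.path }}" ]]; then
  echo "Directory ${{ inputs.path }} already exists and will be removed"
  rm -rf "${{ inputs.path }}"
fi
```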

@@ -37,8 +37,8 @@ runs:
   ITEM_PATH="${{ inputs.root }}/${{ inputs.key }}"
   echo "dest=$ITEM_PATH" >> $GITHUB_OUTPUT
   if [[ -d ${{ inputs.path }} ]]; then
-    echo "Directory ${{ inputs.path }} exists and will not be restored from cache"
-    exit 1
+    echo "Directory ${{ inputs.path }} already exists and will not be removed"
Contributor

Not sure I understand: it says "will not be removed", so why is it removed on the next line?

Contributor Author

I'm not sure I understand what you mean. The original code was modified in response to this comment from Pavel.

Contributor

@pbchekin Is the current change what you expected? Why echo that the directory will not be removed, and then remove it right after?

Contributor Author

Sorry, my mistake. I didn't re-read my comment carefully. The comment has been updated.

@@ -37,8 +37,8 @@ runs:
   ITEM_PATH="${{ inputs.root }}/${{ inputs.key }}"
   echo "dest=$ITEM_PATH" >> $GITHUB_OUTPUT
   if [[ -d ${{ inputs.path }} ]]; then
-    echo "Directory ${{ inputs.path }} exists and will not be restored from cache"
-    exit 1
+    echo "Directory ${{ inputs.path }} already exists and will not be removed"
Contributor

I don't understand the comment. The directory here exists and the next line will remove it.

@@ -45,8 +45,14 @@
   if: inputs.ref != ''
   shell: bash
   run: |
     echo "PYTORCH_REPO=${{ inputs.repository }}" | tee -a "$GITHUB_ENV"
     echo "PYTORCH_COMMIT_ID=${{ steps.commit-id.outputs.commit_id }}" | tee -a "$GITHUB_ENV"
+    if [[ "${{ inputs.repository }}" = "liangan1/pytorch" ]]; then
Contributor

What is "liangan1" ? Why do we need to use a personal directory ?

Contributor Author

Yes, we need to fetch and install a specific PyTorch version with XPU support for FlexAttention. Currently this code is only available in Liangang's fork (under the liangan1 account).
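For illustration, fetching and building the fork might look like the following sketch (the clone target and build commands are assumptions, not taken from the workflow):

```bash
# Hypothetical steps; the actual workflow pins a specific commit via PYTORCH_COMMIT_ID.
git clone https://github.com/liangan1/pytorch.git pytorch
cd pytorch
pip install -r requirements.txt
python setup.py develop  # builds PyTorch from source with XPU support
```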

x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
x_vals=[[z, h, 16384 // z, dhead, causal, mode]
for z in [1, 2, 4, 8, 16, 32]
for (h, dhead) in [(16, 128), (32, 64)]

Suggest aligning with the requirements in https://jira.devtools.intel.com/browse/TRITONXPU-172, e.g., GQA/MHA and paged KV cache, plus broader head-dim and sequence-length coverage.
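As a hedged sketch of what GQA coverage could look like in this benchmark's parameter grid (the H_q/H_kv names and shape list are assumptions, not the PR's code):

```python
import os

# Separate query-head and kv-head counts let one grid cover both MHA and GQA.
x_names = ['Z', 'H_q', 'H_kv', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE']
x_vals = [
    [z, h_q, h_kv, 16384 // z, dhead, causal, mode]
    for z in [1, 2, 4, 8, 16, 32]
    # MHA uses h_q == h_kv; GQA shares each kv head across a group of query heads.
    for (h_q, h_kv, dhead) in [(16, 16, 128), (32, 8, 64)]
    for causal in [False, True]
    for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]
]
```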

Comment on lines 40 to 51
+ [[z, h, 1024, dhead, True, mode]
for z in [1, 2, 4, 8, 16, 32, 64]
for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+ [[z, h, 1024 + 64, dhead, True, mode]
for z in [1, 2, 4, 8, 16, 32]
for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+ [[z, h, 1024 + 128 + 512, dhead, True, mode]
for z in [1, 2, 4, 8, 16, 32]
for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]], #
Contributor Author

Some of the largest classical LLM shapes specified in https://jira.devtools.intel.com/browse/TRITONXPU-172 cannot be evaluated due to resource limitations on PVC.
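For a sense of scale, a rough arithmetic sketch of why the largest shapes are out of reach (the shape below is illustrative, not taken from the ticket):

```python
# Memory for the q, k, v inputs alone, before intermediate buffers
# or backward-pass state, which multiply this figure further.
Z, H, N_CTX, D_HEAD = 32, 32, 16384, 128  # illustrative large LLM shape
bytes_per_elem = 2  # fp16
per_tensor = Z * H * N_CTX * D_HEAD * bytes_per_elem
print(f"q+k+v: {3 * per_tensor / 2**30:.0f} GiB")  # ~12 GiB
```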


There are two kernels for FlexAttention: flex-attention for the prefill stage and flex-decoding for the decoding stage. Only a sequence length of 16384/z (for query, key, and value) is covered in this benchmark, and only for the prefill stage. In real use there are three stages: prefill (len(q) = len(k) = len(v)), decoding (len(q) = 1, much smaller than len(k) = len(v)), and extend (e.g., multi-round chat: len(q) > 1 and len(q) < len(k) = len(v)).
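A small sketch of the tensor shapes involved in each stage (sizes are illustrative):

```python
import torch

B, H, S, D = 2, 8, 1024, 64  # batch, heads, kv sequence length, head dim

# Prefill: len(q) == len(k) == len(v)
q_prefill = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Decoding: len(q) == 1, much smaller than len(k) == len(v)
q_decode = torch.randn(B, H, 1, D)

# Extend (e.g., multi-round chat): 1 < len(q) < len(k) == len(v)
q_extend = torch.randn(B, H, 128, D)
```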

Contributor Author

Thanks for these additional explanations. Since enhancing the benchmarks to evaluate the other stages (Decode and Append), GQA, and paged KV cache requires significant work, I would prefer this PR to focus on adding an initial benchmark for the FlexAttention prefill stage only (similar to our current FA benchmark) and to address the remaining limitations in separate PRs. I have created #3615, #3616, and #3617 for this purpose.


Makes sense. This PR is a good starting point.

@mfrancepillois mfrancepillois changed the title from "[FlexAttention] Add benchmarks" to "[FlexAttention] Add initial benchmarks" on Mar 5, 2025
Successfully merging this pull request may close these issues.

Add FlexAttention to benchmarks/triton_kernels_benchmark