[FlexAttention] Add initial benchmarks #3578
base: main
Conversation
Add benchmarks to evaluate FlexAttention kernel performance. Add these benchmarks to the CI workflow (this requires installing a specific PyTorch version with XPU FlexAttention support enabled).
@liangan1, this PR adds two FlexAttention-based benchmarks to start monitoring FlexAttention performance. Could you please have a look?
.github/actions/load/action.yml
Outdated
@@ -37,8 +37,12 @@ runs:
    ITEM_PATH="${{ inputs.root }}/${{ inputs.key }}"
    echo "dest=$ITEM_PATH" >> $GITHUB_OUTPUT
    if [[ -d ${{ inputs.path }} ]]; then
      echo "Directory ${{ inputs.path }} exists and will not be restored from cache"
      exit 1
      if [[ ${{ inputs.repository == 'liangan1/pytorch' }} ]]; then
I think we don't need it here. Just add rm -rf pytorch in the workflow.
The workflow has been modified to delete the directory if it already exists (regardless of the repository).
Co-authored-by: Pavel Chekin <[email protected]>
.github/actions/load/action.yml
Outdated
@@ -37,8 +37,8 @@ runs:
    ITEM_PATH="${{ inputs.root }}/${{ inputs.key }}"
    echo "dest=$ITEM_PATH" >> $GITHUB_OUTPUT
    if [[ -d ${{ inputs.path }} ]]; then
      echo "Directory ${{ inputs.path }} exists and will not be restored from cache"
      exit 1
      echo "Directory ${{ inputs.path }} already exists and will not be removed"
Not sure I understand: it says "will not be removed", so why remove it on the next line?
Not sure I understand what you mean. The original code was modified in response to this comment from Pavel.
@pbchekin Is the current change what you expected? Why echo that the directory will not be removed, then remove it right after?
Sorry, my mistake. I didn't read my comment correctly. The comment has been updated.
.github/actions/load/action.yml
Outdated
@@ -37,8 +37,8 @@ runs:
    ITEM_PATH="${{ inputs.root }}/${{ inputs.key }}"
    echo "dest=$ITEM_PATH" >> $GITHUB_OUTPUT
    if [[ -d ${{ inputs.path }} ]]; then
      echo "Directory ${{ inputs.path }} exists and will not be restored from cache"
      exit 1
      echo "Directory ${{ inputs.path }} already exists and will not be removed"
I don't understand the comment. The directory here exists and the next line will remove it.
@@ -45,8 +45,14 @@ runs:
    if: inputs.ref != ''
    shell: bash
    run: |
      echo "PYTORCH_REPO=${{ inputs.repository }}" | tee -a "$GITHUB_ENV"
      echo "PYTORCH_COMMIT_ID=${{ steps.commit-id.outputs.commit_id }}" | tee -a "$GITHUB_ENV"
      if [[ "${{ inputs.repository }}" = "liangan1/pytorch" ]]; then
What is "liangan1"? Why do we need to use a personal repository?
Yes, we need to fetch and install a specific PyTorch version with XPU support for FlexAttention. Currently this code is only available in Liangang's fork (named liangan1).
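For context, a minimal sketch of the kind of sanity check the benchmark setup could run to confirm the installed build actually exposes FlexAttention on XPU; the import path is the public PyTorch one, and the XPU availability check is an assumption about the fork's build, not code from this PR:

    # Hedged sanity check: verify the installed PyTorch build exposes
    # FlexAttention and an XPU device before running the benchmarks.
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    assert hasattr(torch, "xpu") and torch.xpu.is_available(), \
        "These benchmarks need a PyTorch build with XPU support enabled"
    print("flex_attention available:", callable(flex_attention))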
x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
x_vals=[[z, h, 16384 // z, dhead, causal, mode]
        for z in [1, 2, 4, 8, 16, 32]
        for (h, dhead) in [(16, 128), (32, 64)]
I suggest aligning with the requirements in https://jira.devtools.intel.com/browse/TRITONXPU-172, e.g., GQA/MHA and paged KV cache, plus more head-dim and sequence-length coverage.
+ [[z, h, 1024, dhead, True, mode]
   for z in [1, 2, 4, 8, 16, 32, 64]
   for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
   for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]
+ [[z, h, 1024 + 64, dhead, True, mode]
   for z in [1, 2, 4, 8, 16, 32]
   for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
   for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]
+ [[z, h, 1024 + 128 + 512, dhead, True, mode]
   for z in [1, 2, 4, 8, 16, 32]
   for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
   for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
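For illustration, here is a minimal sketch of what one point of this shape grid feeds into, assuming the public torch.nn.attention.flex_attention API; the "xpu" device string and float16 dtype are assumptions for this sketch, not taken from the PR:

    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask

    def causal(b, h, q_idx, kv_idx):
        # Causal mask: a query position may only attend to itself and earlier KV positions.
        return q_idx >= kv_idx

    z, h, n_ctx, d_head = 8, 16, 16384 // 8, 128  # one (Z, H, N_CTX, D_HEAD) point from the grid
    q, k, v = (torch.randn(z, h, n_ctx, d_head, device="xpu", dtype=torch.float16)
               for _ in range(3))
    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=n_ctx, KV_LEN=n_ctx, device="xpu")
    out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)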
Some of the largest classical LLM shapes specified in https://jira.devtools.intel.com/browse/TRITONXPU-172 cannot be evaluated due to resource limitations on PVC.
There are two kernels for FlexAttention: flex-attention for prefill and flex-decoding for the decoding stage. Only a sequence length of 16384 // z (for query, key, and value) is covered in this benchmark, and only for the prefill stage. In real workloads there are three stages: prefill (len(q) = len(k) = len(v)), decoding (len(q) = 1, much smaller than len(k) = len(v)), and extend (e.g., multi-round chat, where 1 < len(q) < len(k) = len(v)). See the illustrative shapes below.
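To make the three stages concrete, here are illustrative query/KV lengths for each; the numbers are hypothetical examples, and batch and head dimensions are omitted:

    # Illustrative shapes only; values are made-up examples.
    prefill = dict(q_len=1024, kv_len=1024)  # len(q) == len(k) == len(v)
    decode  = dict(q_len=1,    kv_len=1024)  # len(q) == 1, much smaller than len(kv)
    extend  = dict(q_len=64,   kv_len=1024)  # 1 < len(q) < len(kv), e.g. multi-round chat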
Thanks for these additional explanations. Since enhancing the benchmarks to evaluate the other stages (decode and append), GQA, and paged KV cache requires significant work, I would prefer this PR to focus on adding an initial benchmark for the FlexAttention prefill stage only (similar to our current FA benchmark) and to address the remaining limitations in separate PRs. I have created #3615, #3616, #3617 for this purpose.
Makes sense. This PR is a good starting point.
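As a pointer for the deferred GQA follow-up, the public flex_attention API already exposes an enable_gqa flag; a hedged sketch, with made-up head counts and the same assumed "xpu" device as above:

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # GQA: more query heads than KV heads; with enable_gqa=True, groups of
    # query heads share KV heads (query heads must be a multiple of KV heads).
    q = torch.randn(2, 32, 1024, 128, device="xpu", dtype=torch.float16)
    k = torch.randn(2, 8, 1024, 128, device="xpu", dtype=torch.float16)
    v = torch.randn_like(k)
    out = flex_attention(q, k, v, enable_gqa=True)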