[RoBERTa-based] Add support for sdpa #30510

Open · wants to merge 1 commit into main from sdpa-roberta
Conversation

@hackyon (Contributor) commented Apr 26, 2024

What does this PR do?

Adding support for SDPA (scaled dot product attention) for RoBERTa-based models. More context in #28005 and #28802.
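For context, a minimal usage sketch of how SDPA would be selected once this is merged, assuming the same `attn_implementation` switch already used by the other SDPA-enabled models in the library (this is illustrative, not part of the diff):

```python
# Minimal usage sketch: pick the SDPA attention implementation explicitly.
# Assumes a transformers version that includes this PR; otherwise
# from_pretrained() raises an error for attn_implementation="sdpa".
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained(
    "roberta-base",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda").eval()

inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```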

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@fxmarty @ArthurZucker @amyeroberts

@hackyon (Contributor, Author) commented Apr 26, 2024

I ran slow tests for the affected models and verified that they all pass except XLMRobertaXLModelTest::test_eager_matches_sdpa_generate(). I suspect it's just a numerical precision issue, but I'll take a quick look to see if I can find anything.

I'll also try to run some perf benchmarks on RoBERTa over the weekend to see how they behave.

@hackyon (Contributor, Author) commented Apr 27, 2024

Preliminary perf numbers for RoBERTa (using "roberta-base" with AutoModel/AutoTokenizer).

Training

| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager, s) | Time per batch (SDPA, s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|
| 1000 | 1 | 256 | True | 0.018 | 0.015 | 24.411 | 731.752 | 736.471 | -0.641 |
| 1000 | 1 | 512 | True | 0.019 | 0.016 | 17.819 | 823.792 | 757.096 | 8.809 |
| 1000 | 2 | 256 | True | 0.020 | 0.016 | 29.890 | 760.504 | 757.096 | 0.450 |
| 1000 | 2 | 512 | True | 0.020 | 0.016 | 25.317 | 1283.793 | 907.688 | 41.435 |
| 1000 | 4 | 256 | True | 0.020 | 0.016 | 28.907 | 1094.001 | 907.289 | 20.579 |
| 1000 | 4 | 512 | True | 0.025 | 0.021 | 19.153 | 2205.299 | 1446.666 | 52.440 |

Inference

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 2 | 64 | True | True | True | 5.357 | 5.067 | 5.716 | 333.956 | 333.956 | 0 |
| 50 | 2 | 128 | True | True | True | 5.534 | 5.181 | 6.812 | 360.089 | 360.089 | 0 |
| 50 | 2 | 256 | True | True | True | 5.823 | 5.516 | 5.577 | 412.355 | 412.355 | 0 |
| 50 | 4 | 64 | True | True | True | 5.632 | 5.344 | 5.381 | 385.611 | 385.611 | 0 |
| 50 | 4 | 128 | True | True | True | 6.101 | 5.849 | 4.304 | 437.895 | 437.877 | 0.004 |
| 50 | 4 | 256 | True | True | True | 6.91 | 6.529 | 5.824 | 542.598 | 542.598 | 0 |
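For reference, a rough sketch of how the eager vs. SDPA forward latency can be compared; this is not the exact benchmark script used for the numbers above, and the batch/sequence settings are illustrative:

```python
# Rough timing sketch: compare eager vs SDPA forward latency for roberta-base
# on GPU in fp16, using CUDA events for accurate timing.
import torch
from transformers import AutoModel, AutoTokenizer

def time_forward(attn_implementation, batch_size=4, seq_len=256, steps=50):
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = AutoModel.from_pretrained(
        "roberta-base",
        attn_implementation=attn_implementation,
        torch_dtype=torch.float16,
    ).to("cuda").eval()

    batch = tokenizer(["hello world"] * batch_size, padding="max_length",
                      max_length=seq_len, return_tensors="pt").to("cuda")

    with torch.no_grad():
        # Warmup runs, then timed runs with explicit CUDA synchronization.
        for _ in range(10):
            model(**batch)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(steps):
            model(**batch)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / steps  # ms per batch

eager_ms = time_forward("eager")
sdpa_ms = time_forward("sdpa")
print(f"eager: {eager_ms:.2f} ms/batch, sdpa: {sdpa_ms:.2f} ms/batch")
```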

@hackyon (Contributor, Author) commented Apr 27, 2024

It turns out XLMRobertaXLModelTest::test_eager_matches_sdpa_generate() doesn't always fail; it's flaky and depends on the random number generator. I think the failures come down to numerical precision: the two attention implementations can produce slightly different results.

EDIT: I added a set_seed(0) to XLMRobertaXLModelTest::test_eager_matches_sdpa_generate(), and the flake seems to have gone away.
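For anyone curious, the kind of check involved looks roughly like the following; this is an illustrative sketch rather than the actual test code, and the tolerances are placeholders:

```python
# Illustrative sketch: fix the seed, then compare eager vs SDPA outputs within
# a tolerance, since floating-point accumulation order differs between kernels.
# Requires a transformers version where the model supports attn_implementation="sdpa".
import torch
from transformers import AutoModel, set_seed

set_seed(0)  # pin the RNG so inputs and any sampling are reproducible

inputs = {
    "input_ids": torch.randint(0, 1000, (2, 16)),
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
}

model_eager = AutoModel.from_pretrained("roberta-base", attn_implementation="eager").eval()
model_sdpa = AutoModel.from_pretrained("roberta-base", attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = model_eager(**inputs).last_hidden_state
    out_sdpa = model_sdpa(**inputs).last_hidden_state

# Small element-wise differences are expected; only larger gaps indicate a bug.
assert torch.allclose(out_eager, out_sdpa, atol=1e-3, rtol=1e-3)
```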

@hackyon force-pushed the sdpa-roberta branch 2 times, most recently from c39f457 to 41537e3 on April 29, 2024 17:42
@hackyon (Contributor, Author) commented Apr 29, 2024

@fxmarty @ArthurZucker @amyeroberts

This is ready for review! Apart from the changes to the test and check_support_list.py, all of the changes come from "Copied From". Please let me know if you have any questions!

@hackyon marked this pull request as ready for review April 29, 2024 17:50
@hackyon mentioned this pull request May 8, 2024
@michaelshekasta commented
@hackyon, I'm curious whether implementing flash attention (flash_attn) is essential when adding SDPA support. I've come across claims that flash attention can offer roughly a 4x efficiency boost compared to native PyTorch, yet the numbers in #30510 show less than a 50% improvement. Could you help shed some light on this apparent difference?
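(Background for this question: PyTorch's SDPA is itself a dispatcher that can select a flash-attention kernel when the input shapes and dtypes allow it, so the comparison isn't strictly SDPA vs. flash attention. A small sketch of constraining the backend; the context manager shown is the pre-2.3 PyTorch API, and newer releases expose torch.nn.attention.sdpa_kernel instead.)

```python
# Sketch: restrict F.scaled_dot_product_attention to the flash backend only.
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) in fp16 on GPU, as the flash kernel expects.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)  # errors if the flash kernel can't be used
print(out.shape)
```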
