
[BERT] Add support for sdpa #28802

Merged: 22 commits merged into huggingface:main from hackyon:sdpa-bert on Apr 26, 2024

Conversation

hackyon (Contributor) commented Jan 31, 2024

What does this PR do?

This PR adds support for SDPA (scaled dot product attention) to BERT. More context in #28005.
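For context, this is roughly how a user opts in once the change ships (a minimal sketch; the model name is a placeholder, and the `attn_implementation` argument follows the existing SDPA integration from #28005):

import torch
from transformers import AutoModel, AutoTokenizer

# Loading any BERT checkpoint with attn_implementation="sdpa" routes attention through
# torch.nn.functional.scaled_dot_product_attention instead of the eager implementation.
model = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer("SDPA support for BERT", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)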

Benchmarking Results on A100-80GB, CPUx12, RAM 96.6GB, OS Ubuntu 22.04, using BertLMHeadModel

Training benchmark based on fxmarty's script:

| num_training_steps | batch_size | seq_len | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|
| 1000 | 1 | 256 | 0.022 | 0.018 | 23.905 | 1128.190 | 1065.286 | 5.905 |
| 1000 | 1 | 512 | 0.034 | 0.028 | 20.473 | 1345.791 | 1093.933 | 23.023 |
| 1000 | 2 | 256 | 0.031 | 0.026 | 18.701 | 1175.685 | 1093.933 | 7.473 |
| 1000 | 2 | 512 | 0.057 | 0.047 | 21.315 | 2123.874 | 1370.097 | 55.016 |
| 1000 | 4 | 256 | 0.052 | 0.044 | 16.446 | 1784.135 | 1369.489 | 30.277 |
| 1000 | 4 | 512 | 0.106 | 0.087 | 21.524 | 3706.609 | 2196.791 | 68.728 |

Inference benchmark based on fxmarty's script:

| num_batches | batch_size | seq_len | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|
| 50 | 1 | 64 | 5.906 | 5.420 | 8.962 | 271.610 | 271.407 | 0.075 |
| 50 | 1 | 128 | 5.825 | 5.402 | 7.834 | 279.157 | 279.718 | -0.200 |
| 50 | 2 | 64 | 6.190 | 5.349 | 15.709 | 291.489 | 291.751 | -0.090 |
| 50 | 2 | 128 | 6.168 | 5.360 | 15.066 | 307.514 | 307.776 | -0.085 |
| 50 | 4 | 64 | 6.262 | 5.392 | 16.137 | 332.177 | 332.440 | -0.079 |
| 50 | 4 | 128 | 6.201 | 5.382 | 15.215 | 364.271 | 364.742 | -0.129 |

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @younesbelkada

(cc @fxmarty)

hackyon (Contributor Author) commented Jan 31, 2024

Hey @ArthurZucker @younesbelkada

I was thinking SDPA (#28005) could be a good addition to BERT, so I drafted this change. It doesn't look too hairy so far.

As @ArthurZucker mentioned, BERT doesn't have a lot of params so there might not be much of a speedup, but this didn't look too difficult to implement so I figured whatever little improvement might still be helpful (as an aside, there's been some benchmarking of Flash Attention on training other implementations of BERT, and it still shows decent improvements).

Can you let me know if this is worth pursuing? If so, I'll add the tests and also fix the fix-copies dependencies.

Thanks!

# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a [...]
hackyon (Contributor Author) commented Jan 31, 2024

This is fixed in torch 2.2.0 I think, maybe I should check for it and skip the calls?

Collaborator commented:

I think it is fine to leave. We should probably bump the requirement for SDPA to torch>=2.2 in the future.

hackyon (Contributor Author) commented Feb 15, 2024

This got me thinking, and I ran an additional set of benchmarks (training and inference), given that FA2 is supported and the contiguous-input bug is fixed in torch 2.2.0.

Both training and inference were ~5% faster with torch==2.2.0 (FA2 should be supported). I also tried out gating the .contiguous() requirement and saw an additional ~5-10% gain on top of that.

# Only force contiguous q/k/v on torch < 2.2.0, where the memory-efficient SDPA backend
# is bugged with non-contiguous inputs.
from packaging import version
from transformers.utils import get_torch_version

if version.parse(get_torch_version()) < version.parse("2.2.0"):
    query_layer = query_layer.contiguous()
    key_layer = key_layer.contiguous()
    value_layer = value_layer.contiguous()

I'm leaning towards adding the if-statement to gate the call, so users who upgrade to torch==2.2.0 first can get the benefits right away (before we set the min torch version to 2.2.0). WDYT?

hackyon (Contributor Author) commented:

I added the if-statement for 2.2.0 in there. If you don't think it's a good idea, let me know and I'll remove it.

ArthurZucker (Collaborator) commented:

I think a good way to see if it is worth the shot is to benchmark your code and check if you have speedups in different contexts!

hackyon (Contributor Author) commented Feb 1, 2024

Sounds good, lemme look into that

hackyon (Contributor Author) commented Feb 6, 2024

@ArthurZucker I did some training and inference benchmarking for my change and posted the results in the PR description.

It looks like there are decent improvements across the board (percentage-wise, but I think the improvements would add up if we're doing a lot of training/inferencing). I think it could be a good addition. Thoughts?

ArthurZucker (Collaborator) commented:

Sounds like a good addition then! I'll let @fxmarty review and will be doing the final pass!

pommedeterresautee (Contributor) commented Feb 7, 2024

Just curious, is it similar to #27478? #28713 also seems highly related.

hackyon (Contributor Author) commented Feb 7, 2024

re: @pommedeterresautee

Yes, it's similar. SDPA is built into PyTorch and can use Flash Attention (1) depending on the environment. AFAIK Flash Attention 2 isn't supported in SDPA yet, but there is a possibility for it to be supported down the road (that would be handled inside PyTorch itself, and shouldn't need many changes on our end).
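As an aside, torch exposes which fused kernels SDPA may dispatch to, so you can sanity-check Flash Attention availability in a given environment (a quick sketch, not part of this PR; shapes and dtype are arbitrary):

import torch
import torch.nn.functional as F

# [batch, heads, seq_len, head_dim]; fp16 on a CUDA GPU is required for the flash kernel.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the Flash Attention backend; this errors with an explanation
# if the flash kernel cannot be used in this environment.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)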

pommedeterresautee (Contributor) commented:

Thanks, I think it is supported now: https://pytorch.org/blog/pytorch2-2/
"scaled_dot_product_attention (SDPA) now supports FlashAttention-2, yielding around 2x speedups compared to previous versions."

hackyon (Contributor Author) commented Feb 7, 2024

Oh nice, so I guess we could get FA2 for free eventually (when we upgrade pytorch).

Thanks for the links to similar work. I think they could cause some merge conflicts, so I'll message them and try to resolve it before it goes in.

fxmarty (Collaborator) left a comment

It looks in good shape, thank you. I left a few comments.

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
Collaborator commented:

I would probably move the Copied from to just the __init__ and other methods, but not forward. For forward, you can probably just add a comment that it is adapted from bert/roberta, and once bridge_tower supports sdpa we can put the Copied from back.

WDYT @ArthurZucker @amyeroberts

hackyon (Contributor Author) commented Feb 8, 2024

There seem to be 8 methods that copy from BertModel#forward() exactly and include this section of the change.

I wouldn't mind adding SDPA to them as well once this goes in and reinstating the copy-from. It shouldn't be that difficult (famous last words).
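For readers unfamiliar with the fix-copies tooling being discussed, the markers look roughly like this (an illustrative sketch; the class and target path are placeholders, not the exact lines in the PR):

import torch.nn as nn

# `make fix-copies` keeps this class in sync with the BERT original, applying the rename:
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Camembert
class CamembertSelfAttention(nn.Module):
    pass

# Where an exact copy can no longer be maintained (e.g. forward() gains SDPA mask logic that
# the other model does not support yet), the marker is relaxed to a note such as:
# Adapted from transformers.models.bert.modeling_bert.BertModel.forward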

hackyon (Contributor Author) commented:

I've removed the fix-copies from the instances, and so the logic for sdpa attention masks should only be in BertModel now.

Resolved review threads:
  • src/transformers/models/camembert/modeling_camembert.py
  • src/transformers/models/clap/modeling_clap.py
  • src/transformers/models/data2vec/modeling_data2vec_text.py
  • src/transformers/models/roberta/modeling_roberta.py
  • src/transformers/models/bert/modeling_bert.py
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if self.config.is_decoder:
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
Collaborator commented:

@ArthurZucker there are create_extended_attention_mask_for_decoder, invert_attention_mask, and get_extended_attention_mask methods in modeling_utils.py that should probably be deprecated / redirected to modeling_attn_mask_utils.py.

hackyon (Contributor Author) commented:

Yea, I agree.

It'd be great if we could mark those old methods as deprecated, and slowly update them once we verify that the old methods and the new methods are always returning the same results.

Collaborator commented:

For the updated_attention_mask for sdpa, why can't we keep the previous logic and just do:

                # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                # Details: https://github.com/pytorch/pytorch/issues/110213
                causal_mask = causal_mask.mul(~torch.all(causal_mask == torch.finfo(embedding_output.dtype).min, dim=-1, keepdim=True)).to(
                    dtype
                )

(from Llama)?
I'm not a super fan of the complexity of _prepare_4d_causal_attention_mask_for_sdpa, and we should not add it to our new code IMO.
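For reference, on the non-causal (encoder) path the 4D expansion that the helper produces boils down to something like the following (a simplified sketch of the underlying operation, not the library helper itself):

import torch

def expand_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # [bsz, src_len] padding mask (1 = keep, 0 = pad) -> additive [bsz, 1, tgt_len, src_len]
    # bias that SDPA can consume via its attn_mask argument.
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # 0.0 where tokens may be attended, a large negative value where they are masked out.
    return (1.0 - expanded) * torch.finfo(dtype).min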

@@ -451,12 +451,10 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
is_tracing = torch.jit.is_tracing()
hackyon (Contributor Author) commented Feb 8, 2024

This code was changed to pass the fx tracing test (in common tests).

It would be good if you could help double-check the logic here. I think the idea is that we'll still have to use our own attention mask (rather than None) when tracing is active. The previous "pass" would cause the function to end without any return statements, which would have defaulted to returning None.

Collaborator commented:

It looks OK to me, cc @fxmarty to confirm.

AFAICT, the difference here is coming from the additional isinstance(mask, torch.fx.Proxy) in the is_tracing check. I don't believe the reworking to remove pass should affect anything - the new code is equivalent.

Collaborator commented:

Yes, it is fine, see:

is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)

@@ -692,6 +807,10 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
hackyon (Contributor Author) commented:

This fix was added due to a test failure that uncovered an existing bug.

The head was initialized but the weights weren't retied as necessary. This was causing self.decoder.bias to be different from self.bias. When loading the pretrained model with low_cpu_mem_usage=True, self.decoder.bias had uninitialized params (with device=meta) whereas self.bias was set properly (with device=cpu).

I'm slightly concerned this will affect the output some users see when using this model. Please let me know what you think about this.
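For reference, the retie amounts to re-pointing the decoder's bias at the head's bias parameter whenever weights are (re)tied. A minimal self-contained sketch of the idea (the real fix lives in BertLMPredictionHead and landed separately in #28948, so details may differ):

import torch
import torch.nn as nn

class TinyLMPredictionHead(nn.Module):
    # Minimal stand-in for the real prediction head, just to show the bias retie.
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # After loading (e.g. with low_cpu_mem_usage=True), re-point the decoder's bias at
        # self.bias so it is never left as an uninitialized tensor on the meta device.
        self.decoder.bias = self.bias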

hackyon (Contributor Author) commented:

I pulled this out to its own PR here:
#28948

This issue is unrelated to SDPA; it was just uncovered by an SDPA test, so I pulled it out into its own PR.

Collaborator commented:

Addition looks OK to me - thanks for digging into this.

> I'm slightly concerned this will affect the output some users see when using this model. Please let me know what you think about this.

Could you expand on what you think might be an issue?

hackyon (Contributor Author) commented Feb 14, 2024

I was initially concerned that users were loading and using the model with a wrong bias (i.e. device=meta), and that this fix to use the correct bias would cause the results to change between versions.

However, that seems unlikely after playing around with this a bit more - turns out it was quite difficult to run the model when the bias had device=meta, so I doubt anyone was actually running the model in this particular configuration before the fix.

@@ -3560,8 +3564,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
hackyon (Contributor Author) commented:

The self._prepare_for_class call is necessary to support the BertForMultipleChoice model.

hackyon marked this pull request as ready for review February 8, 2024 15:21

hackyon (Contributor Author) commented Feb 8, 2024

I've rebased off of HEAD and marked this as ready for review. I had to dig through a couple of issues to get the tests to pass; let me know if you want to chat about any of them.

Thanks!

amyeroberts (Collaborator) commented:

@fxmarty @hackyon There are still several tests failing related to this PR. Once these are resolved, you can ping me again for a final review.

hackyon (Contributor Author) commented Feb 8, 2024

The tests are passing now. I also verified that test_modeling_bert passes with RUN_SLOW=1 (which contains the tests that ensure model output is the same for eager and SDPA attention).

Please take another look when you get a chance. Thanks!
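For reference, a standalone sketch of the eager-vs-SDPA parity check that the slow test performs (tolerances and the model name here are illustrative, not the values used in the test suite):

import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tok("Checking eager vs SDPA parity", return_tensors="pt")

# Load the same checkpoint twice, once per attention implementation.
model_eager = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="eager").eval()
model_sdpa = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = model_eager(**inputs).last_hidden_state
    out_sdpa = model_sdpa(**inputs).last_hidden_state

torch.testing.assert_close(out_eager, out_sdpa, atol=1e-4, rtol=1e-4)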

amyeroberts (Collaborator) left a comment

Thanks for all the work adding this, @hackyon, as well as the additional work to dig into weird errors and find solutions. Great work!

Some general comments were left as review threads on:

  • src/transformers/models/altclip/modeling_altclip.py (resolved)
  • src/transformers/models/bert/modeling_bert.py (resolved)
  • src/transformers/models/xmod/modeling_xmod.py (resolved)
fxmarty (Collaborator) commented Mar 19, 2024

@hackyon could you merge/rebase on main?

hackyon (Contributor Author) commented Mar 19, 2024

Sure, I just merged with main/HEAD. @amyeroberts @ArthurZucker do you mind taking a look?

I'm having trouble starting my cloud server right now due to high demand, but I'll run it through the slow tests later on when it works again.

ArthurZucker (Collaborator) commented:

Hey! Sure, I was off for a bit but will have a look.

ArthurZucker self-requested a review March 25, 2024 08:24

ArthurZucker (Collaborator) commented:

Oh wow

ArthurZucker (Collaborator) left a comment

LGTM, let's rebase on main!

hackyon (Contributor Author) commented Apr 6, 2024

Thanks!

I merged with main/HEAD and re-ran the RUN_SLOW tests for both bert and test_eager_matches_sdpa_inference, and they work as expected. There were existing failures for test_eager_matches_sdpa_inference with RUN_SLOW on main/HEAD, but nothing new introduced by this change.

I'm not sure about this test_pipelines_tf failure. I haven't touched any code with tf, and I was able to get the failing test test_stop_sequence_stopping_criteria to pass locally, so I'm thinking it's a flake or unrelated to this change.

amyeroberts (Collaborator) commented:

Hi @hackyon - great to see this ready to merge!

The generation tests aren't related to this diff and are failing on other PRs. We're working to push a fix to main - I'll let you know when it's resolved; then you can rebase and hopefully we'll have full 🟢 for merging 🤗

hackyon (Contributor Author) commented Apr 11, 2024

Thanks @amyeroberts @ArthurZucker

Just remerged with main/HEAD, and the unrelated failing TF pipeline test now passes. I checked the bert tests again with RUN_SLOW for good measure, and they continue to pass.

Let me know if there's anything else I could do here. Thanks!

hackyon (Contributor Author) commented Apr 15, 2024

@ArthurZucker Please let me know if there's anything else you'd like me to do for this PR. Thanks!

hackyon (Contributor Author) commented Apr 22, 2024

Remerged with the latest main, and fixed a test.

@ArthurZucker @amyeroberts @fxmarty Please let me know if there's anything I can do here.

amyeroberts (Collaborator) commented:

@hackyon Everything's green and we have two approvals, so we're good to merge. Thanks for all the effort in adding this and iterating with us. It's great to have this added to one of the most popular models ❤️

amyeroberts merged commit dfa7b58 into huggingface:main on Apr 26, 2024 (19 checks passed)

hackyon (Contributor Author) commented Apr 26, 2024

Thanks @amyeroberts for the merge! 🎉 I appreciate all the help from @fxmarty, @ArthurZucker, and you in getting this PR merged 🙏

I see you've submitted #30506 as a follow-up, and thank you for covering that. Please let me know if there's any other follow-up work, and I'd be happy to look into it.

hackyon deleted the sdpa-bert branch April 26, 2024 19:01
hackyon (Contributor Author) commented Apr 26, 2024

As I mentioned previously, I've also drafted a PR for adding SDPA support to RoBERTa-based models at #30510. Almost all of the changes are "Copied from" BERT, and so there is a little less room for error.

itazap pushed a commit that referenced this pull request May 14, 2024
* Adding SDPA support for BERT

* Using the proper input name for testing model input in inference()

* Adding documentation for SDPA in BERT model page

* Use the stable link for the documentation

* Adding a gate to only call .contiguous() for torch < 2.2.0

* Additions and fixes to the documentation

* Minor updates to documentation

* Adding extra requirements needed for the contiguous() bug

* Adding "Adapted from" in plcae of the "Copied from"

* Add benchmark speedup tables to the documentation

* Minor fixes to the documentation

* Use ClapText as a replacement for Bert in the Copied-From

* Some more fixes for the fix-copies references

* Overriding the test_eager_matches_sdpa_generate in bert tests to not load with low_cpu_mem_usage

[test all]

* Undo changes to separate test

* Refactored SDPA self attention code for KV projections

* Change use_sdpa to attn_implementation

* Fix test_sdpa_can_dispatch_on_flash by preparing input (required for MultipleChoice models)