[RoBERTa-based] Add support for sdpa #30510

Open · wants to merge 1 commit into main
7 changes: 5 additions & 2 deletions docs/source/en/perf_infer_gpu_one.md
@@ -192,6 +192,7 @@
For now, Transformers supports SDPA inference and training for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
@@ -217,8 +218,10 @@
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
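
As a usage reference (not part of the diff), here is a minimal sketch of opting in to SDPA for one of the newly listed models; `camembert-base` is used purely as an example checkpoint:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Request the SDPA implementation explicitly; on a recent PyTorch, Transformers
# also selects it by default for supported architectures.
model = AutoModel.from_pretrained("camembert-base", attn_implementation="sdpa")
tokenizer = AutoTokenizer.from_pretrained("camembert-base")

inputs = tokenizer("J'aime le camembert !", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 768)
```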

<Tip>

@@ -1066,7 +1066,7 @@ class PreTrainedModel
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

# Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
173 changes: 157 additions & 16 deletions src/transformers/models/camembert/modeling_camembert.py
@@ -20,10 +20,15 @@

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, gelu
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -40,6 +45,7 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
logging,
replace_return_docstrings,
)
@@ -297,6 +303,104 @@ def forward(
return outputs


# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert
class CamembertSdpaSelfAttention(CamembertSelfAttention):
def __init__(self, config, position_embedding_type=None):
super().__init__(config, position_embedding_type=position_embedding_type)
self.dropout_prob = config.attention_probs_dropout_prob
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
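        # torch < 2.2.0 has a broken memory-efficient SDPA backend for non-contiguous
        # q/k/v combined with a custom attn_mask on CUDA; forward() works around it by
        # calling .contiguous() whenever this flag is set.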

# Adapted from CamembertSelfAttention
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"CamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
"the manual attention implementation, but specifying the manual implementation will be required from "
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
# mask in case tgt_len == 1.
is_causal = self.is_decoder and attention_mask is None and tgt_len > 1
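        # PyTorch treats `attn_mask` and `is_causal=True` as mutually exclusive; since
        # `is_causal` is only True when no mask was provided, the call below never
        # passes both.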

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
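
As a sanity check on the call above (not part of the diff), a self-contained sketch showing that `torch.nn.functional.scaled_dot_product_attention` with an additive float mask matches the eager softmax(QK^T / sqrt(d) + mask)V computation that `CamembertSelfAttention` performs:

```python
import torch
import torch.nn.functional as F

# Shapes mirror the module above: [batch, heads, seq_len, head_dim].
q = torch.randn(2, 12, 5, 64)
k = torch.randn(2, 12, 5, 64)
v = torch.randn(2, 12, 5, 64)

# Additive float mask, broadcast over heads: 0.0 = attend, -inf = masked.
mask = torch.zeros(2, 1, 5, 5)
mask[:, :, :, -1] = float("-inf")  # mask out the last key position

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

# Eager reference: softmax(Q K^T / sqrt(head_dim) + mask) V
ref = torch.softmax(q @ k.transpose(-2, -1) / 64**0.5 + mask, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))  # True
```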


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->Camembert
class CamembertSelfOutput(nn.Module):
def __init__(self, config):
@@ -314,6 +418,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:

CAMEMBERT_SELF_ATTENTION_CLASSES = {
"eager": CamembertSelfAttention,
"sdpa": CamembertSdpaSelfAttention,
}
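
The mapping above follows the repo-wide dispatch pattern: the attention wrapper instantiates whichever self-attention class matches the string stored on the config. A toy, runnable sketch of the mechanics (class names here are illustrative stand-ins, not the diff's code):

```python
from types import SimpleNamespace

from torch import nn


class EagerSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()


class SdpaSelfAttention(EagerSelfAttention):
    pass


# Mirrors CAMEMBERT_SELF_ATTENTION_CLASSES above.
SELF_ATTENTION_CLASSES = {"eager": EagerSelfAttention, "sdpa": SdpaSelfAttention}


class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # `config._attn_implementation` is set while the model is being loaded.
        self.self = SELF_ATTENTION_CLASSES[config._attn_implementation](config)


cfg = SimpleNamespace(_attn_implementation="sdpa")
print(type(Attention(cfg).self).__name__)  # SdpaSelfAttention
```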


@@ -606,6 +711,7 @@ class CamembertPreTrainedModel(PreTrainedModel):
config_class = CamembertConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_supports_sdpa = True
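    # Advertises SDPA support to PreTrainedModel, allowing the model to be loaded
    # with `attn_implementation="sdpa"` (or to have it selected automatically).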

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
@@ -752,7 +858,7 @@ class CamembertModel(CamembertPreTrainedModel):

_no_split_modules = []

# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Camembert
# Copied from transformers.models.roberta.modeling_roberta.RobertaModel.__init__ with Roberta->Camembert
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
@@ -762,6 +868,9 @@ def __init__(self, config, add_pooling_layer=True):

self.pooler = CamembertPooler(config) if add_pooling_layer else None

self.attn_implementation = config._attn_implementation
self.position_embedding_type = config.position_embedding_type
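        # Cached so forward() can decide whether the SDPA mask-expansion path applies
        # (see `use_sdpa_attention_masks` there).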

# Initialize weights and apply final processing
self.post_init()

@@ -785,7 +894,7 @@ class PreTrainedModel
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
# Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
@@ -849,9 +958,6 @@ def forward(
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

use_sdpa_attention_masks = (
self.attn_implementation == "sdpa"
and self.position_embedding_type == "absolute"
and head_mask is None
and not output_attentions
)

# Expand the attention mask
if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder:
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
embedding_output,
past_key_values_length,
)
else:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
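        # e.g. with attention_mask=[[1, 1, 1, 0]] the SDPA branch yields a [1, 1, 4, 4]
        # float mask (0.0 where attention is allowed, dtype-min at the padded column),
        # or None when nothing is masked so SDPA can pick its fastest kernel.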

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,