Flash Attention and Open Delta LoRA #50
Sorry, we can't find the library where
My apologies. I thought I had linked it previously. Here is the link to the Hugging Face wrappers utilizing Flash Attention: https://github.com/kyleliang919/Long-context-transformers/blob/main/flash_attn_wrappers.py The wrappers have changed a bit, so now it would be the NeoX one:

```python
import torch

# Imports added for completeness; they assume flash-attn 1.x
# (flash_attn.modules.mha) and a transformers GPT-NeoX version whose
# apply_rotary_pos_emb still accepts an `offset` argument.
from flash_attn.modules.mha import FlashSelfAttention
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb


class FlashAttentionWrapperWithRotary(torch.nn.Module):
    def __init__(self, attention, max_seqlen=8192):
        super().__init__()
        self.attention = attention
        self.max_seqlen = max_seqlen
        self.flash_self_attention = FlashSelfAttention(
            causal=True, softmax_scale=1 / self.attention.norm_factor
        )
        self.dropout_p = 0.0

    def forward(self,
                hidden_states,
                attention_mask,
                head_mask=None,
                layer_past=None,
                use_cache=False,
                output_attentions=False):
        has_layer_past = layer_past is not None

        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (np * 3 * head_size)]
        qkv = self.attention.query_key_value(hidden_states)

        # [batch, seq_len, (num_heads * 3 * head_size)]
        #   --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.attention.num_attention_heads, 3 * self.attention.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # [batch, seq_len, num_attention_heads, 3 * head_size]
        #   --> 3 x [batch, num_attention_heads, seq_len, head_size]
        query = qkv[..., : self.attention.head_size].permute(0, 2, 1, 3)
        key = qkv[..., self.attention.head_size : 2 * self.attention.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.attention.head_size :].permute(0, 2, 1, 3)

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.attention.rotary_ndims]
        query_pass = query[..., self.attention.rotary_ndims :]
        key_rot = key[..., : self.attention.rotary_ndims]
        key_pass = key[..., self.attention.rotary_ndims :]

        # Compute token offset for rotary embeddings (when decoding)
        seq_len = key.shape[-2]
        offset = 0
        if has_layer_past:
            offset = layer_past[0].shape[-2]
            seq_len += offset
        cos, sin = self.attention.rotary_emb(value, seq_len=seq_len)
        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
        query = torch.cat((query, query_pass), dim=-1)
        key = torch.cat((key, key_pass), dim=-1)

        # Cache QKV values
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention with Flash Attention instead of the stock _attn
        # attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        # Pack into [batch, seq_len, 3, num_heads, head_size] as FlashSelfAttention expects
        qkv = torch.concat(
            [query.unsqueeze(2), key.unsqueeze(2), value.unsqueeze(2)], dim=2
        ).permute(0, 3, 2, 1, 4).half()
        attn_output = self.flash_self_attention(qkv)
        attn_weights = None

        # Reshape outputs back to [batch, seq_len, hidden_size]
        attn_output = attn_output.view(
            attn_output.size(0), attn_output.size(1),
            self.attention.num_attention_heads * self.attention.head_size
        )
        attn_output = self.attention.dense(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
```

Thank you,
Enrico
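For reference, here is a minimal sketch of how such a wrapper can be applied to a Hugging Face GPT-NeoX checkpoint; the patching loop and the model name are illustrative assumptions, not code taken from the linked file:

```python
import torch
from transformers import GPTNeoXForCausalLM

# Hypothetical usage: swap every layer's attention module for the wrapper
# above. The module path (model.gpt_neox.layers[i].attention) follows
# transformers' GPTNeoXForCausalLM; the checkpoint name is only an example.
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-1.4b")
model = model.half().cuda()  # Flash Attention kernels require fp16/bf16 on GPU
for layer in model.gpt_neox.layers:
    layer.attention = FlashAttentionWrapperWithRotary(layer.attention, max_seqlen=8192)
```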
Sorry for not replying in time. Flash Attention is not easily compatible with OpenDelta LoRA. LoRA essentially splits a parameter of the model (such as the query projection) into two low-rank matrices A and B. Flash Attention needs to use the product AB as a whole, but OpenDelta LoRA keeps A and B as two separate parts. Therefore, the current OpenDelta LoRA cannot directly adapt to Flash Attention.
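To make the point concrete, here is a minimal LoRA sketch (illustrative only, not OpenDelta's actual implementation): the adapted projection computes W x + B (A x) with A and B kept as two separate low-rank factors, while a fused path would only see a single merged weight W + B A.

```python
import torch

# Illustrative LoRA sketch (not OpenDelta's code): the adapted projection is
# y = W x + B (A x), with A and B kept as two separate low-rank factors.
d, r = 1024, 8
W = torch.randn(d, d)            # frozen pretrained weight
A = torch.randn(r, d) * 0.01     # LoRA down-projection
B = torch.zeros(d, r)            # LoRA up-projection (zero-initialised)

x = torch.randn(d)
y = W @ x + B @ (A @ x)          # two-part computation kept separate during training

# A fused path that expects a single projection matrix would need the merged
# weight W + B @ A instead; keeping A and B separate is what blocks a
# drop-in combination with the Flash Attention wrapper.
W_merged = W + B @ A
y_fused = W_merged @ x           # matches y up to floating-point error
```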
Hello @ShengdingHu,
Are you able to confirm whether Flash Attention will be compatible with Open Delta LoRA?
For example:
Thank you for your great work,
Enrico