Flash Attention and Open Delta LoRA #50

Open
conceptofmind opened this issue Feb 27, 2023 · 3 comments

Comments

@conceptofmind

conceptofmind commented Feb 27, 2023

Hello @ShengdingHu,

Are you able to confirm whether Flash Attention will be compatible with Open Delta LoRA?

For example:

import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from opendelta import LoraModel, Visualization
# RotaryEmbedding and FlashAttentionWrapper are assumed to come from the
# flash-attention wrappers linked later in this thread (flash_attn_wrappers.py);
# adjust these imports to wherever they live in your setup.

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
tokenizer.pad_token = tokenizer.mask_token

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-1.4b")

# Extend the context window and swap each attention module for the Flash
# Attention wrapper. model_args comes from the surrounding training script.
max_positions = model_args.max_positions
tokenizer.model_max_length = max_positions
for layer in model.gpt_neox.layers:
    original_emb = layer.attention.rotary_emb
    layer.attention.rotary_emb = RotaryEmbedding(layer.attention.rotary_ndims, max_positions, 10000)
    layer.attention.bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
        1, 1, max_positions, max_positions
    )
    layer.attention = FlashAttentionWrapper(layer.attention, max_seqlen=max_positions)

# patching for the random contiguous tensors bug: rewrite each parameter's
# storage in place so it is contiguous
for p in model.parameters():
    p.data = p.data.contiguous()

Visualization(model).structure_graph()

delta_model1 = LoraModel(
    backbone_model=model, 
    modified_modules=[
        'attention.attention.query_key_value',
        'mlp.dense_h_to_4h',
    ]
)
delta_model1.freeze_module()
delta_model1.log(delta_ratio=True, trainable_ratio=True, visualization=True)

[Screenshot from 2023-02-26 19-32-48]

Thank you for your great work,

Enrico

@telxt
Collaborator

telxt commented Mar 13, 2023

Sorry, we can't find the library that FlashAttentionWrapper comes from. Could you please tell us which library it is?

@conceptofmind
Author

conceptofmind commented Mar 13, 2023

Sorry, we can't find the library that FlashAttentionWrapper comes from. Could you please tell us which library it is?

My apologies. I thought I had linked it previously.

Here is the link to the Huggingface wrappers utilizing Flash Attention: https://github.com/kyleliang919/Long-context-transformers/blob/main/flash_attn_wrappers.py

The wrappers have changed a bit since then, so the relevant one is now the NeoX version:

import torch
from flash_attn.modules.mha import FlashSelfAttention
# apply_rotary_pos_emb is assumed to be the GPT-NeoX helper from
# transformers.models.gpt_neox.modeling_gpt_neox (the variant that accepts
# an `offset` argument).

class FlashAttentionWrapperWithRotary(torch.nn.Module):
    def __init__(self, attention, max_seqlen = 8192):
        super().__init__()
        self.attention = attention
        self.max_seqlen = max_seqlen
        self.flash_self_attention = FlashSelfAttention(causal = True, softmax_scale = 1/self.attention.norm_factor)
        self.dropout_p = 0.0

    def forward(self,
        hidden_states,
        attention_mask,
        head_mask=None,
        layer_past=None,
        use_cache=False,
        output_attentions=False):
        has_layer_past = layer_past is not None

        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (np * 3 * head_size)]
        qkv = self.attention.query_key_value(hidden_states)

        # [batch, seq_len, (num_heads * 3 * head_size)]
        #   --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.attention.num_attention_heads, 3 * self.attention.head_size)
        qkv = qkv.view(*new_qkv_shape)
        
        # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
        query = qkv[..., : self.attention.head_size].permute(0, 2, 1, 3)
        key = qkv[..., self.attention.head_size : 2 * self.attention.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.attention.head_size :].permute(0, 2, 1, 3)

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.attention.rotary_ndims]
        query_pass = query[..., self.attention.rotary_ndims :]
        key_rot = key[..., : self.attention.rotary_ndims]
        key_pass = key[..., self.attention.rotary_ndims :]

        # Compute token offset for rotary embeddings (when decoding)
        seq_len = key.shape[-2]
        offset = 0
        if has_layer_past:
            offset = layer_past[0].shape[-2]
            seq_len += offset
        cos, sin = self.attention.rotary_emb(value, seq_len=seq_len)
        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
        query = torch.cat((query, query_pass), dim=-1)
        key = torch.cat((key, key_pass), dim=-1)

        # Cache QKV values
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention with FlashSelfAttention. The incoming
        # attention_mask / head_mask are not applied here: causal masking is
        # handled inside FlashSelfAttention (causal=True), and padding masks
        # are simply ignored by this wrapper.
        #attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # Stack to [batch, seq_len, 3, num_heads, head_size] and cast to fp16,
        # the layout/dtype FlashSelfAttention expects.
        qkv = torch.concat([query.unsqueeze(2), key.unsqueeze(2), value.unsqueeze(2)], dim = 2).permute(0, 3, 2, 1, 4).half()
        attn_output = self.flash_self_attention(qkv)
        attn_weights = None

        # Reshape outputs
        attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), self.attention.num_attention_heads * self.attention.head_size)
        attn_output = self.attention.dense(attn_output)
        
        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
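
For context, a minimal sketch of how this wrapper would be installed, following the same pattern as the first snippet (the rotary_emb / bias patching from that snippet would still run first when extending the context length):

# Illustrative only: wrap each GPT-NeoX attention layer with the
# rotary-aware wrapper quoted above.
for layer in model.gpt_neox.layers:
    layer.attention = FlashAttentionWrapperWithRotary(layer.attention, max_seqlen=max_positions)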

Thank you,

Enrico

@telxt
Collaborator

telxt commented Mar 20, 2023

Sorry for the late reply. Flash Attention is not straightforwardly compatible with OpenDelta LoRA. LoRA essentially represents the update to a weight matrix of the model (such as the query projection) as the product of two low-rank matrices A and B. Flash Attention needs to operate on the weight as a single merged matrix, whereas OpenDelta's LoRA keeps A and B as a separate branch alongside the frozen weight. Therefore, the current OpenDelta LoRA cannot directly adapt to Flash Attention.
Sorry we could not be of more help. We will conduct further research and try our best to make OpenDelta adapt to Flash Attention in the future.
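
One common workaround, sketched below, is to merge the trained LoRA factors back into the frozen weight before handing the model to Flash Attention, so the fused kernel only ever sees a single dense matrix. This is not something OpenDelta does for you here; the names (lora_A, lora_B, scaling) are illustrative rather than OpenDelta's actual internals:

import torch

@torch.no_grad()
def merge_lora_into_linear(linear, lora_A, lora_B, scaling=1.0):
    # W' = W + scaling * (B @ A): fold the low-rank update into the base
    # weight so the Flash Attention path only sees one dense projection.
    # lora_A: [r, in_features], lora_B: [out_features, r]
    delta = scaling * (lora_B @ lora_A)
    linear.weight.add_(delta.to(linear.weight.dtype))

The trade-off is that after merging, A and B are no longer separate trainable parameters, so this mainly helps for inference (you would have to un-merge before further LoRA training).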
