fix: add position_embeddings args to LlamaAttention
nagic0 authored Jul 24, 2024
1 parent 86e4e1e commit 432b3e6
Showing 1 changed file with 7 additions and 2 deletions.
intel_npu_acceleration_library/nn/llm.py (7 additions & 2 deletions)
@@ -12,7 +12,7 @@
 from intel_npu_acceleration_library.nn import Linear
 from intel_npu_acceleration_library.backend import run_factory, MLP
 from functools import partial
-from typing import Optional, List, Generator
+from typing import Optional, List, Generator, Tuple
 from transformers.cache_utils import Cache
 import torch
 import uuid
@@ -169,6 +169,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in Transformers v4.45
     ):
         """Torch module forward method.
@@ -180,6 +181,7 @@ def forward(
             output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers. Defaults to False.
             use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned. Defaults to False.
             cache_position (Optional[torch.LongTensor], optional): Cache position useful for static cache applications. Defaults to None.
+            position_embeddings (Optional[Tuple[torch.Tensor, torch.Tensor]], optional): Precomputed `(cos, sin)` rotary embeddings, calculated once by the outer `LlamaModel` and passed in; if None, they are computed inside this layer. Defaults to None.
         Returns:
             _type_: result
@@ -202,7 +204,10 @@ def forward(
             bsz, q_len, self.num_key_value_heads, self.head_dim
         ).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids)
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings

         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids
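For context, the change follows the fallback pattern that newer `transformers` Llama code expects: the caller may pass a precomputed `(cos, sin)` tuple as `position_embeddings`, and only when it is absent does the attention layer run its own `rotary_emb`. The sketch below is a minimal, self-contained illustration of that fallback, not the library's actual module; the `toy_rotary` helper, its tensor shapes, and the function names are assumptions made for the example.

from typing import Optional, Tuple

import torch


def toy_rotary(
    position_ids: torch.Tensor, head_dim: int, base: float = 10000.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build (cos, sin) rotary tables for the given positions (illustrative only)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = position_ids[..., None].float() * inv_freq  # (batch, seq, head_dim/2)
    angles = torch.cat([angles, angles], dim=-1)          # (batch, seq, head_dim)
    return angles.cos(), angles.sin()


def resolve_position_embeddings(
    position_ids: torch.Tensor,
    head_dim: int,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Mirror the patched forward(): prefer the caller's (cos, sin), else compute locally."""
    if position_embeddings is None:
        # Old code path: each attention layer derives cos/sin from its own rotary embedding.
        cos, sin = toy_rotary(position_ids, head_dim)
    else:
        # New code path: the outer model computes the embeddings once and hands them to
        # every layer (the commit comment notes this becomes mandatory in Transformers v4.45).
        cos, sin = position_embeddings
    return cos, sin


if __name__ == "__main__":
    pos = torch.arange(8).unsqueeze(0)  # (batch=1, seq_len=8)
    # Legacy call site: no precomputed embeddings.
    cos1, sin1 = resolve_position_embeddings(pos, head_dim=64)
    # New call site: the caller precomputes and forwards them.
    cos2, sin2 = resolve_position_embeddings(pos, head_dim=64, position_embeddings=(cos1, sin1))
    assert torch.allclose(cos1, cos2) and torch.allclose(sin1, sin2)
    print(cos1.shape, sin1.shape)  # torch.Size([1, 8, 64]) torch.Size([1, 8, 64])

Either call style yields the same `(cos, sin)` tables, which is why the patch stays backward compatible with call sites that do not yet supply `position_embeddings`.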
