This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Commit 432b3e6

1 parent 86e4e1e commit 432b3e6

File tree

  • intel_npu_acceleration_library/nn/llm.py

1 file changed: +7 −2 lines

intel_npu_acceleration_library/nn/llm.py

Lines changed: 7 additions & 2 deletions
@@ -12,7 +12,7 @@
 from intel_npu_acceleration_library.nn import Linear
 from intel_npu_acceleration_library.backend import run_factory, MLP
 from functools import partial
-from typing import Optional, List, Generator
+from typing import Optional, List, Generator, Tuple
 from transformers.cache_utils import Cache
 import torch
 import uuid
@@ -169,6 +169,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in Transformers v4.45
     ):
         """Torch module forward method.

@@ -180,6 +181,7 @@ def forward(
             output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers.. Defaults to False.
             use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned. Defaults to False.
             cache_position (Optional[torch.LongTensor], optional): Cache position useful for static cache applications . Defaults to None.
+            position_embeddings (Optional[Tuple[torch.Tensor, torch.Tensor]], optional): If set to a tuple, it means the `sin` and `cos` are uniformly calculated by the outer `LlamaModel` and passed in. Defaults to None.

         Returns:
             _type_: result
@@ -202,7 +204,10 @@ def forward(
             bsz, q_len, self.num_key_value_heads, self.head_dim
         ).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids)
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings

         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids
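What the change does in practice: a minimal, self-contained sketch of the new cos/sin dispatch. RotaryStub and AttentionSketch below are illustrative stand-ins, not the library's real classes; only the branching mirrors the patched llm.py. When position_embeddings is None, the layer computes cos/sin from its own rotary_emb as before; when the outer LlamaModel passes the precomputed tuple, it is reused directly.

# Sketch only: stand-in classes, same branch as the patched forward.
from typing import Optional, Tuple

import torch


class RotaryStub(torch.nn.Module):
    """Placeholder for the attention module's self.rotary_emb."""

    def forward(self, value_states, position_ids):
        # Return dummy cos/sin tensors, one row per position.
        shape = position_ids.shape + (value_states.shape[-1],)
        return torch.ones(shape), torch.zeros(shape)


class AttentionSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rotary_emb = RotaryStub()

    def forward(
        self,
        value_states: torch.Tensor,
        position_ids: torch.LongTensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        # Prefer the (cos, sin) tuple computed once by the outer model;
        # otherwise compute it locally, as before this commit.
        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        return cos, sin


attn = AttentionSketch()
values = torch.randn(1, 8, 16, 64)              # (batch, kv_heads, seq, head_dim)
pos_ids = torch.arange(16).unsqueeze(0)         # (1, seq)
cos, sin = attn(values, pos_ids)                # old path: computed by the layer
cos2, sin2 = attn(values, pos_ids, (cos, sin))  # new path: passed in by the model

In the library's patched forward, the resulting cos and sin then feed apply_rotary_pos_emb exactly as before, so only the source of the rotary embeddings changes.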

0 commit comments