diff --git a/intel_npu_acceleration_library/nn/llm.py b/intel_npu_acceleration_library/nn/llm.py
index 8cf6cd3..ff97a15 100644
--- a/intel_npu_acceleration_library/nn/llm.py
+++ b/intel_npu_acceleration_library/nn/llm.py
@@ -12,7 +12,7 @@
 from intel_npu_acceleration_library.nn import Linear
 from intel_npu_acceleration_library.backend import run_factory, MLP
 from functools import partial
-from typing import Optional, List, Generator
+from typing import Optional, List, Generator, Tuple
 from transformers.cache_utils import Cache
 import torch
 import uuid
@@ -169,6 +169,7 @@ def forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
     cache_position: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in Transformers v4.45
 ):
     """Torch module forward method.
 
@@ -180,6 +181,7 @@ def forward(
         output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers.. Defaults to False.
         use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned. Defaults to False.
         cache_position (Optional[torch.LongTensor], optional): Cache position useful for static cache applications . Defaults to None.
+        position_embeddings (Optional[Tuple[torch.Tensor, torch.Tensor]], optional): If set, the `sin` and `cos` rotary embeddings are computed once by the outer `LlamaModel` and passed in. Defaults to None.
 
     Returns:
         _type_: result
@@ -202,7 +204,10 @@ def forward(
         bsz, q_len, self.num_key_value_heads, self.head_dim
     ).transpose(1, 2)
 
-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
 
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
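
Note for reviewers: below is a minimal, self-contained sketch of the fallback semantics introduced by this patch. `toy_rotary_emb` and `select_cos_sin` are illustrative stand-ins (not library or Transformers APIs); they only show that passing a precomputed `(cos, sin)` tuple is equivalent to letting each attention layer call `self.rotary_emb` itself.

```python
import torch


# Hypothetical stand-in for `self.rotary_emb` (a LlamaRotaryEmbedding);
# both function names here are illustrative, not part of the library API.
def toy_rotary_emb(value_states, position_ids, head_dim=8):
    # Classic rotary-embedding construction: per-position angles -> cos/sin.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = position_ids[..., None].float() * inv_freq    # (bsz, seq_len, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                # (bsz, seq_len, head_dim)
    return emb.cos(), emb.sin()


def select_cos_sin(position_embeddings, value_states, position_ids):
    # Same branch as the patched forward(): prefer the precomputed pair.
    if position_embeddings is None:
        cos, sin = toy_rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    return cos, sin


# Usage: the outer model computes (cos, sin) once and hands the tuple to every
# attention layer; the result matches the legacy per-layer computation.
position_ids = torch.arange(4)[None, :]          # (1, 4)
value_states = torch.randn(1, 2, 4, 8)           # (bsz, n_kv_heads, seq, head_dim)
shared = toy_rotary_emb(value_states, position_ids)

cos_new, _ = select_cos_sin(shared, value_states, position_ids)
cos_old, _ = select_cos_sin(None, value_states, position_ids)
assert torch.equal(cos_new, cos_old)
```

Keeping the `position_embeddings is None` branch preserves the old per-layer path for existing callers, while newer `LlamaModel` versions that precompute the pair can pass it down, in line with the v4.45 note in the added comment.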