From f3b112af4f4f78f971f527a87922a44706982870 Mon Sep 17 00:00:00 2001
From: Alessandro Palla
Date: Thu, 25 Jul 2024 10:43:51 +0100
Subject: [PATCH] Remove assert check for python 3.12

---
 intel_npu_acceleration_library/nn/llm.py | 11 +++++++++--
 test/python/test_compile.py              |  4 +---
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/intel_npu_acceleration_library/nn/llm.py b/intel_npu_acceleration_library/nn/llm.py
index 8cf6cd3..eeee94d 100644
--- a/intel_npu_acceleration_library/nn/llm.py
+++ b/intel_npu_acceleration_library/nn/llm.py
@@ -12,7 +12,7 @@
 from intel_npu_acceleration_library.nn import Linear
 from intel_npu_acceleration_library.backend import run_factory, MLP
 from functools import partial
-from typing import Optional, List, Generator
+from typing import Optional, List, Generator, Tuple
 from transformers.cache_utils import Cache
 import torch
 import uuid
@@ -169,6 +169,9 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[
+            Tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # will become mandatory in Transformers v4.45
     ):
         """Torch module forward method.
 
@@ -180,6 +183,7 @@ def forward(
             output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers.. Defaults to False.
             use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned. Defaults to False.
             cache_position (Optional[torch.LongTensor], optional): Cache position useful for static cache applications . Defaults to None.
+            position_embeddings (Optional[Tuple[torch.Tensor, torch.Tensor]], optional): If set to a tuple, it means the `sin` and `cos` are uniformly calculated by the outer `LlamaModel` and passed in. Defaults to None.
 
         Returns:
             _type_: result
@@ -202,7 +206,10 @@ def forward(
             bsz, q_len, self.num_key_value_heads, self.head_dim
         ).transpose(1, 2)
 
-        cos, sin = self.rotary_emb(value_states, position_ids)
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
 
         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids
diff --git a/test/python/test_compile.py b/test/python/test_compile.py
index 63ce124..a103705 100644
--- a/test/python/test_compile.py
+++ b/test/python/test_compile.py
@@ -89,9 +89,7 @@ def test_torch_compile():
     model = NN()
     y_ref = model(x.to(torch.float32)).detach()
 
-    if (
-        sys.platform == "win32" and Version(torch.__version__) < Version("2.2.2")
-    ) or sys.version_info >= (3, 12):
+    if sys.platform == "win32" and Version(torch.__version__) < Version("2.2.2"):
         with pytest.raises(RuntimeError) as e:
             compiled_model = torch.compile(model, backend="npu")
             assert str(e.value) == "Windows not yet supported for torch.compile"
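
Note (not part of the patch): the llm.py hunks let the attention forward accept the rotary (cos, sin) pair precomputed by the outer LlamaModel through the new position_embeddings argument, falling back to the layer's own rotary_emb when it is not supplied (per the inline comment, the argument becomes mandatory in Transformers v4.45). Below is a minimal sketch of that fallback only; ToyRotaryEmbedding and ToyAttention are hypothetical stand-ins, not the library classes.

from typing import Optional, Tuple

import torch


class ToyRotaryEmbedding(torch.nn.Module):
    """Hypothetical stand-in that returns dummy (cos, sin) tensors."""

    def __init__(self, head_dim: int):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        seq_len = position_ids.shape[-1]
        # Real rotary embeddings depend on position_ids; constants keep the sketch short.
        cos = torch.ones(seq_len, self.head_dim, dtype=x.dtype)
        sin = torch.zeros(seq_len, self.head_dim, dtype=x.dtype)
        return cos, sin


class ToyAttention(torch.nn.Module):
    def __init__(self, head_dim: int = 64):
        super().__init__()
        self.rotary_emb = ToyRotaryEmbedding(head_dim)

    def forward(
        self,
        value_states: torch.Tensor,
        position_ids: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        # Same pattern as the patch: prefer the (cos, sin) pair precomputed by
        # the caller, fall back to the local rotary embedding otherwise.
        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        return cos, sin


attn = ToyAttention()
values = torch.randn(1, 8, 4, 64)           # (batch, heads, seq, head_dim)
position_ids = torch.arange(4).unsqueeze(0)

# Older callers: no position_embeddings, so cos/sin are computed locally.
cos_a, sin_a = attn(values, position_ids)

# Newer callers: pass the precomputed (cos, sin) tuple straight through.
precomputed = attn.rotary_emb(values, position_ids)
cos_b, sin_b = attn(values, position_ids, position_embeddings=precomputed)
assert torch.equal(cos_a, cos_b) and torch.equal(sin_a, sin_b)

The test_compile.py hunk simply drops the sys.version_info >= (3, 12) escape hatch, so the torch.compile path is exercised on Python 3.12 as well; the RuntimeError expectation now applies only to Windows with torch older than 2.2.2.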