Skip to content

Commit

Permalink
Remove assert check for Python 3.12
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandropalla committed Jul 25, 2024
1 parent e0805f1 commit f3b112a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
11 changes: 9 additions & 2 deletions intel_npu_acceleration_library/nn/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from intel_npu_acceleration_library.nn import Linear
from intel_npu_acceleration_library.backend import run_factory, MLP
from functools import partial
from typing import Optional, List, Generator
from typing import Optional, List, Generator, Tuple
from transformers.cache_utils import Cache
import torch
import uuid
Expand Down Expand Up @@ -169,6 +169,9 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in Transformers v4.45
):
"""Torch module forward method.
Expand All @@ -180,6 +183,7 @@ def forward(
output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers.. Defaults to False.
use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned. Defaults to False.
cache_position (Optional[torch.LongTensor], optional): Cache position useful for static cache applications . Defaults to None.
position_embeddings (Optional[Tuple[torch.Tensor, torch.Tensor]], optional): If set to a tuple, it means the `sin` and `cos` are uniformly calculated by the outer `LlamaModel` and passed in. Defaults to None.
Returns:
_type_: result
Expand All @@ -202,7 +206,10 @@ def forward(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings

query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
Expand Down
4 changes: 1 addition & 3 deletions test/python/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def test_torch_compile():
model = NN()
y_ref = model(x.to(torch.float32)).detach()

if (
sys.platform == "win32" and Version(torch.__version__) < Version("2.2.2")
) or sys.version_info >= (3, 12):
if sys.platform == "win32" and Version(torch.__version__) < Version("2.2.2"):
with pytest.raises(RuntimeError) as e:
compiled_model = torch.compile(model, backend="npu")
assert str(e.value) == "Windows not yet supported for torch.compile"
Expand Down

0 comments on commit f3b112a

Please sign in to comment.