Support for Phi-3 MLP layer #84

Merged

Changes from all commits · 26 commits
22c627c
Add support for phi-3 MLP layer
Jul 1, 2024
ea4b27a
Updating support for Phi-3 MLP
Jul 2, 2024
39c070c
Update for Phi-3 MLP testing
Jul 3, 2024
2042fab
Merge branch 'main' into sarah/feature/phi3MLP_layer
Jul 5, 2024
5660cc3
Merge branch 'intel:main' into sarah/feature/phi3MLP_layer
SarahByrneIntel Jul 5, 2024
727454e
Update for phi-3 mlp layer
Jul 8, 2024
00a64f0
Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara…
Jul 8, 2024
100fe88
Merge branch 'intel:main' into sarah/feature/phi3MLP_layer
SarahByrneIntel Jul 8, 2024
ea4ea19
Remove old code for phi-3 mlp layer
Jul 8, 2024
53c7b0d
Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara…
Jul 8, 2024
1fef8a4
Add type tensor op and quantisation support
Jul 12, 2024
cc5d373
add support for model quantisation and code clean up
Jul 15, 2024
ff47c1d
Merge branch 'main' into sarah/feature/phi3MLP_layer
SarahByrneIntel Jul 15, 2024
d2fe9fe
Fix for model quantization
Jul 15, 2024
b7825e7
Add testing for phi-3 mlp quantisation
Jul 16, 2024
c652859
Add phi-3 mlp test and enable model profiling toggling
Jul 17, 2024
786c663
Update for model profiling toggle
Jul 17, 2024
003d639
Add compile config feature
Jul 18, 2024
c63c223
Fix test for compile config and remove old code
Jul 18, 2024
e652eaa
Fix tests with compile config
Jul 18, 2024
7f2faf9
Fix for compiler, updates for tests and examples, doc update
Jul 18, 2024
4b5f857
Update for model examples and remove test code
Jul 18, 2024
2718e13
Merge branch 'main' into sarah/feature/phi3MLP_layer
alessandropalla Jul 19, 2024
ae1fd61
Fix for quantization and remove unused code
Jul 19, 2024
5d578a1
Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara…
Jul 19, 2024
2890299
Update for quantization of a model
Jul 19, 2024
18 changes: 16 additions & 2 deletions docs/source/usage.md
@@ -38,19 +38,33 @@ optimized_model = torch.compile(model, backend="npu")

On Windows, `torch.compile` is not supported yet, so you might want to use the explicit function `intel_npu_acceleration_library.compile`. This is also true if you use a `pytorch` version < 2.0.0.

To do this, you just need to call the `compile` function with your model and a `CompilerConfig` compiler configuration to compile and optimize the model for the NPU.
```python
import intel_npu_acceleration_library
optimized_model = intel_npu_acceleration_library.compile(model, dtype=torch.int8)
from intel_npu_acceleration_library.compiler import CompilerConfig
compiler_conf = CompilerConfig(dtype=torch.int8)
optimized_model = intel_npu_acceleration_library.compile(model, compiler_conf)

# Use the model as usual

```

To compile and optimize a single layer of a model so that it is pushed to the NPU as one block, you can set `use_to=True` in the `CompilerConfig` compiler configuration.
```python
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig
compiler_conf = CompilerConfig(use_to=True, dtype=torch.int8)
optimized_block = intel_npu_acceleration_library.compile(single_block, compiler_conf)

```

## Training (**Experimental!**)

It is possible to use the Intel® NPU Acceleration Library to train a model. As before, you just need to call the `compile` function, this time with `training=True`. This allows you to use the same training script you use on other devices with minimal modifications.

```python
import intel_npu_acceleration_library
compiled_model = intel_npu_acceleration_library.compile(model, dtype=torch.float32, training=True)
from intel_npu_acceleration_library.compiler import CompilerConfig
compiler_conf = CompilerConfig(dtype=torch.float32, training=True)
compiled_model = intel_npu_acceleration_library.compile(model, compiler_conf)
```
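
Because the usage.md hunk above interleaves removed and added lines, here is a consolidated sketch of the new `CompilerConfig`-based API as it reads after this change. The `model` and `single_block` objects below are placeholder stand-ins, not anything from the library or this PR; substitute your own modules.

```python
import torch
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig

# Placeholder modules; replace with the real model / layer you want to offload.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256), torch.nn.GELU(), torch.nn.Linear(256, 128)
)
single_block = model[0]

# Whole-model compilation for the NPU with int8 weights.
compiler_conf = CompilerConfig(dtype=torch.int8)
optimized_model = intel_npu_acceleration_library.compile(model, compiler_conf)

# Compile one layer so it is pushed to the NPU as a single block.
block_conf = CompilerConfig(use_to=True, dtype=torch.int8)
optimized_block = intel_npu_acceleration_library.compile(single_block, block_conf)

# Experimental training support: keep float32 and set training=True.
train_conf = CompilerConfig(dtype=torch.float32, training=True)
trainable_model = intel_npu_acceleration_library.compile(model, train_conf)
```

From there the docs expect an ordinary PyTorch loop (forward, loss, backward, optimizer step) against `trainable_model`, as in the `examples/train_mnist.py` diff further down.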
6 changes: 4 additions & 2 deletions examples/compile_model.py
@@ -5,6 +5,7 @@


from intel_npu_acceleration_library import compile
from intel_npu_acceleration_library.compiler import CompilerConfig
from sklearn.metrics import r2_score
import intel_npu_acceleration_library
import pytest
@@ -41,12 +42,13 @@ def forward(self, x):
print(
"Windows do not support torch.compile, fallback to intel_npu_acceleration_library.compile"
)
compiled_model = intel_npu_acceleration_library.compile(model)
compiler_conf = CompilerConfig()
compiled_model = intel_npu_acceleration_library.compile(model, compiler_conf)
else:
compiled_model = torch.compile(model, backend="npu")

# Get the NPU output
with torch.no_grad():
y = compiled_model(x)

print(f"Reference vs actual R2 score: {r2_score(y_ref, y):.2f}")
print(f"Reference vs actual R2 score: {r2_score(y_ref.numpy(), y.numpy()):.2f}")
4 changes: 3 additions & 1 deletion examples/llama.py
@@ -5,11 +5,13 @@

from transformers import AutoTokenizer, TextStreamer
from intel_npu_acceleration_library import NPUModelForCausalLM, int4
from intel_npu_acceleration_library.compiler import CompilerConfig

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

compiler_conf = CompilerConfig(dtype=int4)
model = NPUModelForCausalLM.from_pretrained(
model_id, use_cache=True, dtype=int4, attn_implementation="sdpa"
model_id, use_cache=True, config=compiler_conf, attn_implementation="sdpa"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
6 changes: 5 additions & 1 deletion examples/llama3.py
@@ -5,10 +5,14 @@

from transformers import AutoTokenizer, TextStreamer
from intel_npu_acceleration_library import NPUModelForCausalLM, int4
from intel_npu_acceleration_library.compiler import CompilerConfig

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

model = NPUModelForCausalLM.from_pretrained(model_id, dtype=int4, use_cache=True).eval()
compiler_conf = CompilerConfig(dtype=int4)
model = NPUModelForCausalLM.from_pretrained(
model_id, use_cache=True, config=compiler_conf
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

4 changes: 3 additions & 1 deletion examples/llava.py
@@ -12,6 +12,7 @@
TextStreamer,
)
from transformers.feature_extraction_utils import BatchFeature
from intel_npu_acceleration_library.compiler import CompilerConfig
import intel_npu_acceleration_library
import torch

@@ -21,7 +22,8 @@
# Load model
model = LlavaForConditionalGeneration.from_pretrained(checkpoint)

model = intel_npu_acceleration_library.compile(model)
compiler_conf = CompilerConfig()
model = intel_npu_acceleration_library.compile(model, compiler_conf)

image_processor = CLIPImageProcessor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
4 changes: 3 additions & 1 deletion examples/phi-2.py
@@ -7,12 +7,14 @@
from langchain.chains import LLMChain
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, pipeline, TextStreamer
from intel_npu_acceleration_library.compiler import CompilerConfig
import intel_npu_acceleration_library as npu_lib

model_id = "microsoft/Phi-2"

compiler_conf = CompilerConfig(dtype=npu_lib.int4)
model = npu_lib.NPUModelForCausalLM.from_pretrained(
model_id, use_cache=True, dtype=npu_lib.int4
model_id, use_cache=True, config=compiler_conf
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
4 changes: 3 additions & 1 deletion examples/phi-3.py
@@ -5,15 +5,17 @@

import torch
from transformers import AutoTokenizer, pipeline, TextStreamer
from intel_npu_acceleration_library.compiler import CompilerConfig
import intel_npu_acceleration_library as npu_lib
import warnings

torch.random.manual_seed(0)

compiler_conf = CompilerConfig(dtype=npu_lib.int4)
model = npu_lib.NPUModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
config=compiler_conf,
torch_dtype="auto",
dtype=npu_lib.int4,
)

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
6 changes: 5 additions & 1 deletion examples/t5.py
@@ -5,10 +5,14 @@

from transformers import AutoTokenizer, TextStreamer
from intel_npu_acceleration_library import NPUModelForSeq2SeqLM
from intel_npu_acceleration_library.compiler import CompilerConfig

model_id = "google/flan-t5-small"

model = NPUModelForSeq2SeqLM.from_pretrained(model_id, use_cache=True).eval()
compiler_conf = CompilerConfig()
model = NPUModelForSeq2SeqLM.from_pretrained(
model_id, use_cache=True, config=compiler_conf
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
4 changes: 3 additions & 1 deletion examples/tiny_llama_chat.py
@@ -4,6 +4,7 @@
#

from transformers import pipeline, TextStreamer, set_seed
from intel_npu_acceleration_library.compiler import CompilerConfig
import intel_npu_acceleration_library
import torch
import os
@@ -15,7 +16,8 @@
"text-generation", model=model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
print("Compiling the model for NPU...")
pipe.model = intel_npu_acceleration_library.compile(pipe.model, dtype=torch.int8)
compiler_conf = CompilerConfig(dtype=torch.int8)
pipe.model = intel_npu_acceleration_library.compile(pipe.model, compiler_conf)

streamer = TextStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True)

5 changes: 3 additions & 2 deletions examples/train_mnist.py
@@ -7,6 +7,7 @@
import torch
from torch import nn
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
@@ -90,8 +91,8 @@ def test_loop(dataloader, model, loss_fn):


model = NeuralNetwork()

model = intel_npu_acceleration_library.compile(model, torch.float32, training=True)
compiler_conf = CompilerConfig(dtype=torch.float32, training=True)
model = intel_npu_acceleration_library.compile(model, compiler_conf)

learning_rate = 1e-3
batch_size = 64
12 changes: 12 additions & 0 deletions intel_npu_acceleration_library/backend/tensor.py
@@ -948,6 +948,18 @@ def to(self, dtype: NPUDtype) -> "Tensor":
"""
return generate_op([self], "to", dtype)

def type(self, dtype: NPUDtype) -> "Tensor":
"""
Convert the tensor to the specified data type.

Args:
dtype (NPUDtype): The data type to convert the tensor to.

Returns:
Tensor: The converted tensor.
"""
return self.to(dtype)

@classmethod
def __torch_function__(
cls: Any,
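
The `Tensor.type` method added to `intel_npu_acceleration_library/backend/tensor.py` above is a thin alias for `Tensor.to`, which the surrounding code shows emitting a `"to"` op into the NPU graph via `generate_op`. Presumably the alias exists so that code using the common PyTorch idiom `x.type(dtype)` behaves the same as `x.to(dtype)` on the backend tensor. A minimal stand-alone sketch of the delegation pattern (a toy stand-in class, not the library's actual `Tensor`):

```python
class _TensorSketch:
    """Toy stand-in for the backend Tensor; only the to/type delegation is mirrored."""

    def __init__(self, dtype: str) -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "_TensorSketch":
        # The real Tensor.to emits a "to" conversion node into the NPU graph
        # via generate_op([self], "to", dtype); here we just return a new sketch.
        return _TensorSketch(dtype)

    def type(self, dtype: str) -> "_TensorSketch":
        # Added in this PR: .type() simply delegates to .to().
        return self.to(dtype)


t = _TensorSketch("float16")
assert t.type("int8").dtype == t.to("int8").dtype  # both produce an int8 tensor
```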