Support for Phi-3 MLP layer (#84)
* Add support for phi-3 MLP layer

* Updating support for Phi-3 MLP

* Update for Phi-3 MLP testing

* Update for phi-3 mlp layer

* Remove old code for phi-3 mlp layer

* Add type tensor op and quantisation support

* add support for model quantisation and code clean up

* Fix for model quantization

* Add testing for phi-3 mlp quantisation

* Add phi-3 mlp test and enable model profiling toggling

* Update for model profiling toggle

* Add compile config feature

* Fix test for compile config and remove old code

* Fix tests with compile config

* Fix for compiler, updates for tests and examples, doc update

* Update for model examples and remove test code

* Fix for quantization and remove unused code

* Update for quantization of a model

---------

Co-authored-by: SarahByrneIntel <[email protected]>
Co-authored-by: Alessandro Palla <[email protected]>
3 people authored Jul 19, 2024
1 parent 5d3e83a commit 2193535
Showing 23 changed files with 421 additions and 64 deletions.
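
The change that repeats across the files below is the move from passing options such as `dtype` or `training` directly to `intel_npu_acceleration_library.compile` (or to `NPUModelForCausalLM.from_pretrained` and friends) to bundling them in a `CompilerConfig` object. A minimal before/after sketch of that pattern, using a small placeholder module that is not part of this commit:

```python
import torch
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig

# Small placeholder module, just to have something to compile
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256), torch.nn.ReLU(), torch.nn.Linear(256, 128)
)

# Old API (removed by this commit): options passed as keyword arguments
# optimized_model = intel_npu_acceleration_library.compile(model, dtype=torch.int8)

# New API: options collected in a CompilerConfig and passed as the second argument
compiler_conf = CompilerConfig(dtype=torch.int8)
optimized_model = intel_npu_acceleration_library.compile(model, compiler_conf)
```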
18 changes: 16 additions & 2 deletions docs/source/usage.md
@@ -38,19 +38,33 @@ optimized_model = torch.compile(model, backend="npu")

On Windows, `torch.compile` is not supported yet, so you might want to use the explicit function `intel_npu_acceleration_library.compile`. This also applies if you use a `pytorch` version < 2.0.0.

To do this, you just need to call the `compile` function with your model and the compiler configuration `CompilerConfig` to compile and optimize the model for the NPU.
```python
import intel_npu_acceleration_library
- optimized_model = intel_npu_acceleration_library.compile(model, dtype=torch.int8)
+ from intel_npu_acceleration_library.compiler import CompilerConfig
+ compiler_conf = CompilerConfig(dtype=torch.int8)
+ optimized_model = intel_npu_acceleration_library.compile(model, compiler_conf)

# Use the model as usual

```

To compile and optimize a single layer of a model to be pushed to the NPU as one block, you can set `use_to=True` in the compiler configuration `CompilerConfig`.
```python
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig
compiler_conf = CompilerConfig(use_to=True, dtype=torch.int8)
optimized_block = intel_npu_acceleration_library.compile(single_block, compiler_conf)

```
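
For the Phi-3 MLP case this commit targets, `single_block` could be a single decoder layer's MLP pulled from a Hugging Face checkpoint. A minimal sketch, assuming the usual `transformers` Phi-3 layout (`model.model.layers[i].mlp`), which this commit does not itself specify:

```python
import torch
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig
from transformers import AutoModelForCausalLM

# Load the full model on CPU, then pick one MLP block to offload to the NPU
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.float16
)
single_block = model.model.layers[0].mlp  # assumed attribute path for the Phi-3 MLP

compiler_conf = CompilerConfig(use_to=True, dtype=torch.int8)
optimized_block = intel_npu_acceleration_library.compile(single_block, compiler_conf)
```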

## Training (**Experimental!**)

It is possible to use the Intel® NPU Acceleration Library to train a model. As before, you just need to call the `compile` function, this time with `training=True`. This allows you to use the same training script you use on other devices with very minimal modifications.

```python
import intel_npu_acceleration_library
- compiled_model = intel_npu_acceleration_library.compile(model, dtype=torch.float32, training=True)
+ from intel_npu_acceleration_library.compiler import CompilerConfig
+ compiler_conf = CompilerConfig(dtype=torch.float32, training=True)
+ compiled_model = intel_npu_acceleration_library.compile(model, compiler_conf)
```
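
A minimal sketch of such a training step with the compiled model from the snippet above, assuming a plain PyTorch loop; the optimizer, loss, and tensor shapes below are placeholders:

```python
import torch

optimizer = torch.optim.SGD(compiled_model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# Placeholder batch; shapes depend on the model being trained
x = torch.rand((64, 28 * 28), dtype=torch.float32)
y = torch.randint(0, 10, (64,))

pred = compiled_model(x)  # forward pass through the compiled model
loss = loss_fn(pred, y)
loss.backward()           # gradients flow as with a regular PyTorch module
optimizer.step()
optimizer.zero_grad()
```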
6 changes: 4 additions & 2 deletions examples/compile_model.py
@@ -5,6 +5,7 @@


from intel_npu_acceleration_library import compile
from intel_npu_acceleration_library.compiler import CompilerConfig
from sklearn.metrics import r2_score
import intel_npu_acceleration_library
import pytest
@@ -41,12 +42,13 @@ def forward(self, x):
print(
"Windows do not support torch.compile, fallback to intel_npu_acceleration_library.compile"
)
- compiled_model = intel_npu_acceleration_library.compile(model)
+ compiler_conf = CompilerConfig()
+ compiled_model = intel_npu_acceleration_library.compile(model, compiler_conf)
else:
compiled_model = torch.compile(model, backend="npu")

# Get the NPU output
with torch.no_grad():
y = compiled_model(x)

print(f"Reference vs actual R2 score: {r2_score(y_ref, y):.2f}")
print(f"Reference vs actual R2 score: {r2_score(y_ref.numpy(), y.numpy()):.2f}")
4 changes: 3 additions & 1 deletion examples/llama.py
@@ -5,11 +5,13 @@

from transformers import AutoTokenizer, TextStreamer
from intel_npu_acceleration_library import NPUModelForCausalLM, int4
from intel_npu_acceleration_library.compiler import CompilerConfig

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

+ compiler_conf = CompilerConfig(dtype=int4)
model = NPUModelForCausalLM.from_pretrained(
- model_id, use_cache=True, dtype=int4, attn_implementation="sdpa"
+ model_id, use_cache=True, config=compiler_conf, attn_implementation="sdpa"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
6 changes: 5 additions & 1 deletion examples/llama3.py
@@ -5,10 +5,14 @@

from transformers import AutoTokenizer, TextStreamer
from intel_npu_acceleration_library import NPUModelForCausalLM, int4
from intel_npu_acceleration_library.compiler import CompilerConfig

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

- model = NPUModelForCausalLM.from_pretrained(model_id, dtype=int4, use_cache=True).eval()
+ compiler_conf = CompilerConfig(dtype=int4)
+ model = NPUModelForCausalLM.from_pretrained(
+ model_id, use_cache=True, config=compiler_conf
+ ).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

4 changes: 3 additions & 1 deletion examples/llava.py
@@ -12,6 +12,7 @@
TextStreamer,
)
from transformers.feature_extraction_utils import BatchFeature
from intel_npu_acceleration_library.compiler import CompilerConfig
import intel_npu_acceleration_library
import torch

@@ -21,7 +22,8 @@
# Load model
model = LlavaForConditionalGeneration.from_pretrained(checkpoint)

- model = intel_npu_acceleration_library.compile(model)
+ compiler_conf = CompilerConfig()
+ model = intel_npu_acceleration_library.compile(model, compiler_conf)

image_processor = CLIPImageProcessor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
4 changes: 3 additions & 1 deletion examples/phi-2.py
@@ -7,12 +7,14 @@
from langchain.chains import LLMChain
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, pipeline, TextStreamer
from intel_npu_acceleration_library.compiler import CompilerConfig
import intel_npu_acceleration_library as npu_lib

model_id = "microsoft/Phi-2"

+ compiler_conf = CompilerConfig(dtype=npu_lib.int4)
model = npu_lib.NPUModelForCausalLM.from_pretrained(
- model_id, use_cache=True, dtype=npu_lib.int4
+ model_id, use_cache=True, config=compiler_conf
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
4 changes: 3 additions & 1 deletion examples/phi-3.py
@@ -5,15 +5,17 @@

import torch
from transformers import AutoTokenizer, pipeline, TextStreamer
from intel_npu_acceleration_library.compiler import CompilerConfig
import intel_npu_acceleration_library as npu_lib
import warnings

torch.random.manual_seed(0)

+ compiler_conf = CompilerConfig(dtype=npu_lib.int4)
model = npu_lib.NPUModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
+ config=compiler_conf,
torch_dtype="auto",
- dtype=npu_lib.int4,
)

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
6 changes: 5 additions & 1 deletion examples/t5.py
@@ -5,10 +5,14 @@

from transformers import AutoTokenizer, TextStreamer
from intel_npu_acceleration_library import NPUModelForSeq2SeqLM
from intel_npu_acceleration_library.compiler import CompilerConfig

model_id = "google/flan-t5-small"

- model = NPUModelForSeq2SeqLM.from_pretrained(model_id, use_cache=True).eval()
+ compiler_conf = CompilerConfig()
+ model = NPUModelForSeq2SeqLM.from_pretrained(
+ model_id, use_cache=True, config=compiler_conf
+ ).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
4 changes: 3 additions & 1 deletion examples/tiny_llama_chat.py
@@ -4,6 +4,7 @@
#

from transformers import pipeline, TextStreamer, set_seed
from intel_npu_acceleration_library.compiler import CompilerConfig
import intel_npu_acceleration_library
import torch
import os
@@ -15,7 +16,8 @@
"text-generation", model=model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
print("Compiling the model for NPU...")
- pipe.model = intel_npu_acceleration_library.compile(pipe.model, dtype=torch.int8)
+ compiler_conf = CompilerConfig(dtype=torch.int8)
+ pipe.model = intel_npu_acceleration_library.compile(pipe.model, compiler_conf)

streamer = TextStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True)

5 changes: 3 additions & 2 deletions examples/train_mnist.py
@@ -7,6 +7,7 @@
import torch
from torch import nn
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
@@ -90,8 +91,8 @@ def test_loop(dataloader, model, loss_fn):


model = NeuralNetwork()

- model = intel_npu_acceleration_library.compile(model, torch.float32, training=True)
+ compiler_conf = CompilerConfig(dtype=torch.float32, training=True)
+ model = intel_npu_acceleration_library.compile(model, compiler_conf)

learning_rate = 1e-3
batch_size = 64
12 changes: 12 additions & 0 deletions intel_npu_acceleration_library/backend/tensor.py
@@ -948,6 +948,18 @@ def to(self, dtype: NPUDtype) -> "Tensor":
"""
return generate_op([self], "to", dtype)

def type(self, dtype: NPUDtype) -> "Tensor":
"""
Convert the tensor to the specified data type.
Args:
dtype (NPUDtype): The data type to convert the tensor to.
Returns:
Tensor: The converted tensor.
"""
return self.to(dtype)

@classmethod
def __torch_function__(
cls: Any,