diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml
index 7800d9f..3669aa7 100644
--- a/.github/workflows/documentation.yml
+++ b/.github/workflows/documentation.yml
@@ -1,4 +1,5 @@
 name: Documentation
+permissions: read-all
 
 on:
   workflow_dispatch:
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 2f8e10c..b4afaf5 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -1,4 +1,5 @@
 name: Style
+permissions: read-all
 
 on:
   workflow_dispatch:
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 672ff52..5aaa61d 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,4 +1,5 @@
 name: Test
+permissions: read-all
 
 on:
   workflow_dispatch:
diff --git a/README.md b/README.md
index 9a566d5..21dd4cc 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
 The Intel® NPU Acceleration Library is a Python library designed to boost the efficiency
 of your applications by leveraging the power of the Intel Neural Processing Unit (NPU)
 to perform high-speed computations on compatible hardware.
 
-_Note: The **Intel® NPU Acceleration Library** is currently in active development, with our team working to introduce a variety of features that are anticipated to dramatically enhance performance._
+_Note: The **Intel® NPU Acceleration Library** is currently in active development, with our team working to introduce a variety of features that are anticipated to dramatically enhance performance. For performant, production-ready solutions, please refer to alternatives such as [OpenVINO](https://github.com/openvinotoolkit/openvino) or [DirectML](https://devblogs.microsoft.com/directx/introducing-neural-processor-unit-npu-support-in-directml-developer-preview/)._
 
 ## Intel NPU
diff --git a/dev_requirements.txt b/dev_requirements.txt
index 46a974d..ce4a002 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,7 +1,7 @@
 pytest
 pytest-xdist
 pytest-cov
-scikit-learn <= 1.5.0
+scikit-learn <= 1.5.1
 pre-commit; sys_platform == 'darwin'
 sphinx
 breathe
diff --git a/docs/source/usage.md b/docs/source/usage.md
index 62a4cdb..aff2716 100644
--- a/docs/source/usage.md
+++ b/docs/source/usage.md
@@ -38,19 +38,33 @@ optimized_model = torch.compile(model, backend="npu")
 ```
 
 In windows torch.compile is not supported yet. So you might want to use the explicit function `intel_npu_acceleration_library.compile`. This is true also if you use a `pytorch` version < 2.0.0
+To do this, you just need to call the `compile` function with your model and the compiler configuration `CompilerConfig` to compile and optimize the model for the NPU.
 
 ```python
 import intel_npu_acceleration_library
-optimized_model = intel_npu_acceleration_library.compile(model, dtype=torch.int8)
+from intel_npu_acceleration_library.compiler import CompilerConfig
+compiler_conf = CompilerConfig(dtype=torch.int8)
+optimized_model = intel_npu_acceleration_library.compile(model, compiler_conf)
 
 # Use the model as usual
 ```
 
+To compile and optimize a single layer of a model to be pushed to the NPU as one block, you can set `use_to=True` in the compiler configuration `CompilerConfig`.
+```python
+import intel_npu_acceleration_library
+from intel_npu_acceleration_library.compiler import CompilerConfig
+compiler_conf = CompilerConfig(use_to=True, dtype=torch.int8)
+optimized_block = intel_npu_acceleration_library.compile(single_block, compiler_conf)
+
+```
+
 ## Training (**Experimental!**)
 
 It is possible to use Intel® NPU Acceleration Library to train a model. As before you just need to call the `compile` function, this time with `training=True`. This allows to use the same training script you use in other device with a very minimal modifications.
 
 ```python
 import intel_npu_acceleration_library
-compiled_model = intel_npu_acceleration_library.compile(model, dtype=torch.float32, training=True)
+from intel_npu_acceleration_library.compiler import CompilerConfig
+compiler_conf = CompilerConfig(dtype=torch.float32, training=True)
+compiled_model = intel_npu_acceleration_library.compile(model, compiler_conf)
 ```
diff --git a/examples/compile_model.py b/examples/compile_model.py
index 2146fcd..afe51ce 100644
--- a/examples/compile_model.py
+++ b/examples/compile_model.py
@@ -5,6 +5,7 @@
 
 from intel_npu_acceleration_library import compile
+from intel_npu_acceleration_library.compiler import CompilerConfig
 from sklearn.metrics import r2_score
 import intel_npu_acceleration_library
 import pytest
@@ -41,7 +42,8 @@ def forward(self, x):
     print(
         "Windows do not support torch.compile, fallback to intel_npu_acceleration_library.compile"
     )
-    compiled_model = intel_npu_acceleration_library.compile(model)
+    compiler_conf = CompilerConfig()
+    compiled_model = intel_npu_acceleration_library.compile(model, compiler_conf)
 else:
     compiled_model = torch.compile(model, backend="npu")
 
@@ -49,4 +51,4 @@ def forward(self, x):
 with torch.no_grad():
     y = compiled_model(x)
 
-    print(f"Reference vs actual R2 score: {r2_score(y_ref, y):.2f}")
+    print(f"Reference vs actual R2 score: {r2_score(y_ref.numpy(), y.numpy()):.2f}")
diff --git a/examples/llama.py b/examples/llama.py
index 9c2aaba..e4aebb3 100644
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -5,11 +5,13 @@
 
 from transformers import AutoTokenizer, TextStreamer
 from intel_npu_acceleration_library import NPUModelForCausalLM, int4
+from intel_npu_acceleration_library.compiler import CompilerConfig
 
 model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
+compiler_conf = CompilerConfig(dtype=int4)
 model = NPUModelForCausalLM.from_pretrained(
-    model_id, use_cache=True, dtype=int4, attn_implementation="sdpa"
+    model_id, use_cache=True, config=compiler_conf, attn_implementation="sdpa"
 ).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
 tokenizer.pad_token_id = tokenizer.eos_token_id
diff --git a/examples/llama3.py b/examples/llama3.py
index 5a4fb95..9f6ec2a 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -5,10 +5,14 @@
 
 from transformers import AutoTokenizer, TextStreamer
 from intel_npu_acceleration_library import NPUModelForCausalLM, int4
+from intel_npu_acceleration_library.compiler import CompilerConfig
 
 model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 
-model = NPUModelForCausalLM.from_pretrained(model_id, dtype=int4, use_cache=True).eval()
+compiler_conf = CompilerConfig(dtype=int4)
+model = NPUModelForCausalLM.from_pretrained(
+    model_id, use_cache=True, config=compiler_conf
+).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
diff --git a/examples/llava.py b/examples/llava.py
index a8e5545..dafa22d 100644
--- a/examples/llava.py
+++ b/examples/llava.py
@@ -12,6 +12,7 @@
     TextStreamer,
 )
 from transformers.feature_extraction_utils import BatchFeature
+from intel_npu_acceleration_library.compiler import CompilerConfig
 import intel_npu_acceleration_library
 import torch
 
@@ -21,7 +22,8 @@
 
 # Load model
 model = LlavaForConditionalGeneration.from_pretrained(checkpoint)
-model = intel_npu_acceleration_library.compile(model)
+compiler_conf = CompilerConfig()
+model = intel_npu_acceleration_library.compile(model, compiler_conf)
 
 image_processor = CLIPImageProcessor.from_pretrained(checkpoint)
 tokenizer = AutoTokenizer.from_pretrained(checkpoint)
diff --git a/examples/phi-2.py b/examples/phi-2.py
index 8bf59d4..7b4a4ae 100644
--- a/examples/phi-2.py
+++ b/examples/phi-2.py
@@ -7,12 +7,14 @@
 from langchain.chains import LLMChain
 from langchain.llms import HuggingFacePipeline
 from transformers import AutoTokenizer, pipeline, TextStreamer
+from intel_npu_acceleration_library.compiler import CompilerConfig
 import intel_npu_acceleration_library as npu_lib
 
 model_id = "microsoft/Phi-2"
 
+compiler_conf = CompilerConfig(dtype=npu_lib.int4)
 model = npu_lib.NPUModelForCausalLM.from_pretrained(
-    model_id, use_cache=True, dtype=npu_lib.int4
+    model_id, use_cache=True, config=compiler_conf
 ).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
 streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
diff --git a/examples/phi-3.py b/examples/phi-3.py
index 184b428..87ec94b 100644
--- a/examples/phi-3.py
+++ b/examples/phi-3.py
@@ -5,15 +5,17 @@
 
 import torch
 from transformers import AutoTokenizer, pipeline, TextStreamer
+from intel_npu_acceleration_library.compiler import CompilerConfig
 import intel_npu_acceleration_library as npu_lib
 import warnings
 
 torch.random.manual_seed(0)
 
+compiler_conf = CompilerConfig(dtype=npu_lib.int4)
 model = npu_lib.NPUModelForCausalLM.from_pretrained(
     "microsoft/Phi-3-mini-4k-instruct",
+    config=compiler_conf,
     torch_dtype="auto",
-    dtype=npu_lib.int4,
 )
 
 tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
diff --git a/examples/t5.py b/examples/t5.py
index bec55b3..1607e22 100644
--- a/examples/t5.py
+++ b/examples/t5.py
@@ -5,10 +5,14 @@
 
 from transformers import AutoTokenizer, TextStreamer
 from intel_npu_acceleration_library import NPUModelForSeq2SeqLM
+from intel_npu_acceleration_library.compiler import CompilerConfig
 
 model_id = "google/flan-t5-small"
 
-model = NPUModelForSeq2SeqLM.from_pretrained(model_id, use_cache=True).eval()
+compiler_conf = CompilerConfig()
+model = NPUModelForSeq2SeqLM.from_pretrained(
+    model_id, use_cache=True, config=compiler_conf
+).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
 tokenizer.pad_token_id = tokenizer.eos_token_id
 streamer = TextStreamer(tokenizer, skip_special_tokens=True)
diff --git a/examples/tiny_llama_chat.py b/examples/tiny_llama_chat.py
index 13f595c..82a699e 100644
--- a/examples/tiny_llama_chat.py
+++ b/examples/tiny_llama_chat.py
@@ -4,6 +4,7 @@
 #
 
 from transformers import pipeline, TextStreamer, set_seed
+from intel_npu_acceleration_library.compiler import CompilerConfig
 import intel_npu_acceleration_library
 import torch
 import os
@@ -15,7 +16,8 @@
     "text-generation", model=model_id, torch_dtype=torch.bfloat16, device_map="auto"
 )
 print("Compiling the model for NPU...")
-pipe.model = intel_npu_acceleration_library.compile(pipe.model, dtype=torch.int8)
+compiler_conf = CompilerConfig(dtype=torch.int8)
+pipe.model = intel_npu_acceleration_library.compile(pipe.model, compiler_conf)
 
 streamer = TextStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True)
diff --git a/examples/train_mnist.py b/examples/train_mnist.py
index 972eb81..6e14a22 100644
--- a/examples/train_mnist.py
+++ b/examples/train_mnist.py
@@ -7,6 +7,7 @@
 import torch
 from torch import nn
 import intel_npu_acceleration_library
+from intel_npu_acceleration_library.compiler import CompilerConfig
 from torch.utils.data import DataLoader
 from torchvision import datasets
 from torchvision.transforms import ToTensor
@@ -90,8 +91,8 @@ def test_loop(dataloader, model, loss_fn):
 
 model = NeuralNetwork()
 
-
-model = intel_npu_acceleration_library.compile(model, torch.float32, training=True)
+compiler_conf = CompilerConfig(dtype=torch.float32, training=True)
+model = intel_npu_acceleration_library.compile(model, compiler_conf)
 
 learning_rate = 1e-3
 batch_size = 64
diff --git a/intel_npu_acceleration_library/backend/tensor.py b/intel_npu_acceleration_library/backend/tensor.py
index 2236eda..e8cca7f 100644
--- a/intel_npu_acceleration_library/backend/tensor.py
+++ b/intel_npu_acceleration_library/backend/tensor.py
@@ -948,6 +948,18 @@ def to(self, dtype: NPUDtype) -> "Tensor":
         """
         return generate_op([self], "to", dtype)
 
+    def type(self, dtype: NPUDtype) -> "Tensor":
+        """
+        Convert the tensor to the specified data type.
+
+        Args:
+            dtype (NPUDtype): The data type to convert the tensor to.
+
+        Returns:
+            Tensor: The converted tensor.
+        """
+        return self.to(dtype)
+
     @classmethod
     def __torch_function__(
         cls: Any,
diff --git a/intel_npu_acceleration_library/compiler.py b/intel_npu_acceleration_library/compiler.py
index 4e80d04..5952c98 100644
--- a/intel_npu_acceleration_library/compiler.py
+++ b/intel_npu_acceleration_library/compiler.py
@@ -8,23 +8,43 @@
 from transformers.models.gemma.modeling_gemma import GemmaMLP, GemmaAttention
 from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
 from intel_npu_acceleration_library.quantization import quantize_model
-from intel_npu_acceleration_library.dtypes import int8, int4
+from intel_npu_acceleration_library.dtypes import int8, int4, NPUDtype
+from intel_npu_acceleration_library.nn.module import NPUModuleWrapper
 import intel_npu_acceleration_library.nn as nn
 from torch._dynamo import register_backend
 from typing import Union, Callable, Any
 from typing import List
 import torch
+from functools import partial
 
 
-def compile(
-    model: torch.nn.Module, dtype: torch.dtype = torch.float16, training: bool = False
-) -> torch.nn.Module:
+class CompilerConfig:
+    """Configuration class to store the compilation configuration of a model for the NPU."""
+
+    def __init__(
+        self,
+        use_to: bool = False,
+        dtype: Union[torch.dtype, NPUDtype] = torch.float16,
+        training: bool = False,
+    ) -> None:
+        """Initialize the configuration class.
+
+        Args:
+            use_to (bool): Enable model compiling using .to(). Defaults to disabled
+            dtype (Union[torch.dtype, NPUDtype]): The dtype to compile the model with. Defaults to torch.float16
+            training (bool): Enable training. Defaults to disabled
+        """
+        self.use_to = use_to
+        self.dtype = dtype
+        self.training = training
+
+
+def compile(model: torch.nn.Module, config: CompilerConfig) -> torch.nn.Module:
     """Compile a model for the NPU.
 
     Args:
         model (torch.nn.Module): a pytorch nn.Module to compile and optimize for the npu
-        dtype (torch.dtype): the model target datatype, default to torch.float16
-        training (bool): enable training. Default disabled
+        config (CompilerConfig): the compiler configuration
 
     Raises:
         RuntimeError: invalid datatypes
@@ -32,23 +52,29 @@ def compile(
     Returns:
         torch.nn.Module: compiled NPU nn.Module
     """
-    if not (dtype.is_floating_point or dtype in (int8, int4)):
+    if not (config.dtype.is_floating_point or config.dtype in (int8, int4)):
         raise RuntimeError(
-            f"intel-npu-acceleration-library library do not support yet the requeste datatype: {dtype}"
+            f"intel-npu-acceleration-library does not yet support the requested datatype: {config.dtype}"
         )
 
     # Prepare and optimize model for NPU
     with torch.no_grad():
-        # General optimizations
-        apply_general_optimizations(model)
-        if dtype in (int8, int4):
+        # Model lowering to NPU ops
+        if config.use_to:
+            model = model.to("npu")
+        else:
+            # General optimizations
+            apply_general_optimizations(model)
+
+        if config.dtype in (int8, int4):
             # Quantize model
-            model = quantize_model(model, dtype)
+            model = quantize_model(model, config.dtype)
+            weights_quantization(model)
 
-        # Model lowering to NPU ops
-        create_npu_kernels(model)
+        if not config.use_to:
+            create_npu_kernels(model)
 
-    if dtype.is_floating_point and training:
+    if config.dtype.is_floating_point and config.training:
         # Set model to evaluation only as quantized training is not supported yet
         return model
 
@@ -95,13 +121,22 @@ def wrapper(model: torch.nn.Module, *args: Any, **kwargs: Any):
             kwargs (Any): keyword arguments
         """
-        for name, layer in model.named_children():
-            new_layer = func(name, layer, *args, **kwargs)
-            if new_layer:
-                model.add_module(name, new_layer)
-                wrapper(new_layer, *args, **kwargs)
-            else:
-                wrapper(layer, *args, **kwargs)
+        if not isinstance(model, NPUModuleWrapper) or kwargs.get(
+            "ignore_isinstance", False
+        ):
+            for name, layer in model.named_children():
+                new_layer = func(name, layer, *args, **kwargs)
+                if new_layer:
+                    model.add_module(name, new_layer)
+                    if not isinstance(new_layer, NPUModuleWrapper) or kwargs.get(
+                        "ignore_isinstance", False
+                    ):
+                        wrapper(new_layer, *args, **kwargs)
+                else:
+                    if not isinstance(layer, NPUModuleWrapper) or kwargs.get(
+                        "ignore_isinstance", False
+                    ):
+                        wrapper(layer, *args, **kwargs)
 
     return wrapper
@@ -174,9 +209,68 @@ def optimize_llama_attention(
     return None
 
 
+@module_optimization
+def weights_quantization(
+    name: str, layer: torch.nn.Module, ignore_isinstance: bool = True
+) -> Union[torch.nn.Module, None]:
+    """Apply weights quantization.
+
+    Args:
+        name (str): Layer name
+        layer (torch.nn.Module): Original torch.nn.Linear module
+        ignore_isinstance (bool): ignore isinstance check in module_optimization. Defaults to True.
+
+    Raises:
+        RuntimeError: unsupported quantization bits
+
+    Returns:
+        None: Returns None
+    """
+    if isinstance(layer, WeightOnlyLinear):
+        if (layer.bits == 4) or (layer.bits == 8):
+            layer.forward = partial(forward, layer)
+        else:
+            raise RuntimeError(f"Unsupported quantization bits: {layer.bits}")
+    return None
+
+
+def forward(self, input):
+    """Override forward method for WeightOnlyLinear class.
+
+    Args:
+        input: The input tensor.
+
+    Returns:
+        torch.Tensor: The output tensor.
+ """ + if self.bits == 4: + # Unpack the int4 values + lower_int4 = self.qweight & 0x0F + lower_int4 = lower_int4 - (lower_int4 & 0x8) * 2 + upper_int4 = (self.qweight >> 4) & 0x0F + upper_int4 = upper_int4 - (upper_int4 & 0x8) * 2 + + w = torch.stack((lower_int4, upper_int4), dim=2) + w = w.contiguous().view(self.qweight.shape[0], -1) + + elif self.bits == 8: + w = self.qweight.view(torch.int8) + + output = ( + torch.nn.functional.linear(input.to(torch.float16), w.to(torch.float16), None) + * self.scales.T + ) + + if self.bias: + return output + self.bias + + return output + + @register_backend def npu( - gm: Union[torch.nn.Module, torch.fx.GraphModule], example_inputs: List[torch.Tensor] + gm: Union[torch.nn.Module, torch.fx.GraphModule], + example_inputs: List[torch.Tensor], ) -> Union[torch.nn.Module, torch.fx.GraphModule]: """Implement the custom torch 2.0 compile backend for the NPU. @@ -191,4 +285,5 @@ def npu( gm = horizontal_fusion_linear(gm) # For now compile in fp16 - return compile(gm) + config = CompilerConfig() + return compile(gm, config) diff --git a/intel_npu_acceleration_library/modelling.py b/intel_npu_acceleration_library/modelling.py index 420db3c..606cbd4 100644 --- a/intel_npu_acceleration_library/modelling.py +++ b/intel_npu_acceleration_library/modelling.py @@ -4,6 +4,7 @@ # from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM import intel_npu_acceleration_library as npu_lib +from intel_npu_acceleration_library.compiler import CompilerConfig from functools import partialmethod from typing import Type, Any, Tuple, Optional import hashlib @@ -62,8 +63,7 @@ class NPUModel: @staticmethod def from_pretrained( model_name_or_path: str, - dtype: torch.dtype = torch.float16, - training: bool = False, + config: CompilerConfig, transformers_class: Optional[Type] = None, export=True, *args: Any, @@ -73,8 +73,7 @@ def from_pretrained( Args: model_name_or_path (str): model name or path - dtype (torch.dtype, optional): compilation dtype. Defaults to torch.float16. - training (bool, optional): enable training. Defaults to False. + config (CompilerConfig): compiler configuration transformers_class (Optional[Type], optional): base class to use. Must have a `from_pretrained` method. Defaults to None. export (bool, optional): enable the caching of the model. Defaults to True. 
             args (Any): positional arguments
@@ -91,18 +90,18 @@ def from_pretrained(
             raise RuntimeError(f"Invalid transformer class {type(transformers_class)}")
         # get the model cache dir and path from the name and arguments
         model_dir_path, model_path = get_model_path(
-            model_name_or_path, dtype, training, *args, **kwargs
+            model_name_or_path, config.dtype, config.training, *args, **kwargs
         )
         if os.path.isdir(model_dir_path) and os.path.isfile(model_path):
             # Model already exist so I can load it directly
             return torch.load(model_path)
         else:
             # Model does not exists, so I need to compile it first
-            print(f"Compiling model {model_name_or_path} {dtype} for the NPU")
+            print(f"Compiling model {model_name_or_path} {config.dtype} for the NPU")
             model = transformers_class.from_pretrained(
                 model_name_or_path, *args, **kwargs
             )
-            model = npu_lib.compile(model, dtype, training)
+            model = npu_lib.compile(model, config)
             if export:
                 if kwargs.get("trust_remote_code", False):
                     raise AttributeError(
diff --git a/intel_npu_acceleration_library/nn/autograd.py b/intel_npu_acceleration_library/nn/autograd.py
index 5f5f5ca..2211343 100644
--- a/intel_npu_acceleration_library/nn/autograd.py
+++ b/intel_npu_acceleration_library/nn/autograd.py
@@ -63,6 +63,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Iterable[Union[torch.Tensor, Non
         dl_dx = run_matmul(grad_output, torch.transpose(w, -1, -2))
 
         dl_dw = run_matmul(
-            torch.transpose(grad_output, -1, -2), torch.transpose(x, -1, -2)
+            torch.transpose(grad_output, -1, -2),
+            torch.transpose(x, -1, -2).to(torch.float16),
         )
         return dl_dx, dl_dw, None
diff --git a/intel_npu_acceleration_library/nn/module.py b/intel_npu_acceleration_library/nn/module.py
index ef23c8e..e861260 100644
--- a/intel_npu_acceleration_library/nn/module.py
+++ b/intel_npu_acceleration_library/nn/module.py
@@ -4,6 +4,7 @@
 #
 from intel_npu_acceleration_library.backend import NNFactory, Tensor
 from typing import MutableMapping, Sequence, Any, List
+from torch.profiler import record_function
 import numpy as np
 import torch
 
@@ -104,12 +105,17 @@ def patch_modules(module: torch.nn.Module, model: NNFactory):
 class Module(torch.nn.Module):
     """A PyTorch module that runs on the NPU."""
 
-    def __init__(self) -> None:
-        """Initialize the module."""
+    def __init__(self, profile: bool = False) -> None:
+        """Initialize the module.
+
+        Args:
+            profile (bool): Enable model profiling. Defaults to False.
+        """
         super().__init__()
         self._nn_factory_cache: MutableMapping[str, NNFactory] = {}
         self._npu_inference = False
         self.npu_top_level_module = True
+        self.profile = profile
 
     def extract_tensors_from_arguments(
         self, args: Sequence[Any]
@@ -170,7 +176,7 @@ def create_model(
         Returns:
             NNFactory: The model.
         """
-        model = NNFactory()
+        model = NNFactory(profile=self.profile)
 
         def create_args_from_list(args: Sequence[Any]) -> Sequence[Any]:
             """Create arguments from a list.
@@ -249,7 +255,8 @@ def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
             # Run the model by replacing the forward method with the factory_forward
             old_forward = self.forward
             self.forward = self.factory_forward  # type: ignore
-            out = super()._call_impl(*args, **kwargs)
+            with record_function(f"npu_{self.__class__.__name__}"):
+                out = super()._call_impl(*args, **kwargs)
 
             # Restore the original forward method
             self.forward = old_forward  # type: ignore
@@ -322,7 +329,8 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
         Returns:
             torch.Tensor: The output tensor.
""" - return self.module(*args, **kwargs) + with record_function(f"npu_{self.module.__class__.__name__}"): + return self.module(*args, **kwargs) def convert_to_npu_module(module: torch.nn.Module) -> Module: diff --git a/script/export.py b/script/export.py index 892711e..4f63f71 100644 --- a/script/export.py +++ b/script/export.py @@ -5,6 +5,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig from intel_npu_acceleration_library.compiler import compile +from intel_npu_acceleration_library.compiler import CompilerConfig +from intel_npu_acceleration_library.dtypes import int8, int4 import argparse import torch import os @@ -41,15 +43,19 @@ def export(model_id, dtype, output): if dtype == "fp16": print(f"Compiling model {model_id}") - torch_dtype = torch.float16 + dtype = torch.float16 elif dtype == "int8": print(f"Quantizing & Compiling model {model_id}") - torch_dtype = torch.int8 + dtype = int8 + elif dtype == "int4": + print(f"Quantizing & Compiling model {model_id}") + dtype = int4 else: raise RuntimeError(f"Invalid dtype {dtype}") with torch.no_grad(): - compile(model, dtype=torch_dtype) + compiler_conf = CompilerConfig(dtype=dtype) + compile(model, compiler_conf) filename = os.path.join(PATH, "model.pth") os.makedirs(PATH, exist_ok=True) diff --git a/script/profile_llm.py b/script/profile_llm.py index 6a69089..cdf7c76 100644 --- a/script/profile_llm.py +++ b/script/profile_llm.py @@ -6,6 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM from intel_npu_acceleration_library.nn.llm import generate_with_static_shape from intel_npu_acceleration_library.dtypes import int8, int4 +from intel_npu_acceleration_library.compiler import CompilerConfig from torch.profiler import profile, ProfilerActivity import intel_npu_acceleration_library @@ -52,7 +53,8 @@ def main( if not disable_intel_npu_acceleration_library: if not compiled: - model = intel_npu_acceleration_library.compile(model, dtype) + compiler_conf = CompilerConfig(dtype=dtype) + model = intel_npu_acceleration_library.compile(model, compiler_conf) intel_npu_acceleration_library.nn.llm.warm_up_decoder_model( tokenizer, model, context_size ) diff --git a/script/profile_mlp.py b/script/profile_mlp.py new file mode 100644 index 0000000..4c64fc1 --- /dev/null +++ b/script/profile_mlp.py @@ -0,0 +1,126 @@ +# +# Copyright © 2024 Intel Corporation +# SPDX-License-Identifier: Apache 2.0 +# + +from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP +from intel_npu_acceleration_library.dtypes import int8, int4 +from intel_npu_acceleration_library.compiler import CompilerConfig +from torch.profiler import profile, ProfilerActivity +from sklearn.metrics import r2_score +import intel_npu_acceleration_library +import argparse +import torch +import numpy as np + + +def main( + seq_len=128, + hidden_size=256, + intermediate_size=512, + dtype="float16", + _profile=False, +): + + conf = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + conf.num_hidden_layers = 1 + conf.hidden_size = hidden_size + conf.intermediate_size = intermediate_size + + # Define a single Phi-3 MLP layer + mlp = Phi3MLP(conf) + + hidden_states = torch.rand((seq_len, conf.hidden_size)) + + reference = mlp(hidden_states.to(torch.float32)).to(torch.float16) + + if dtype == "float16": + dtype = torch.float16 + elif dtype == "int8": + dtype = int8 + elif dtype == "int4": + dtype = int4 + else: + raise RuntimeError(f"Invalid dtype: {dtype}") + + # Compile model + compiler_conf = CompilerConfig(use_to=True, dtype=dtype) 
+    model = intel_npu_acceleration_library.compile(mlp, compiler_conf)
+    if _profile:
+        model.profile = True
+
+    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
+        for _ in range(1000):
+            results = model(hidden_states)
+
+    print(
+        prof.key_averages(group_by_input_shape=True).table(
+            sort_by="cpu_time_total", row_limit=20
+        )
+    )
+
+    prof.export_chrome_trace("trace.json")
+
+    results = results.detach().numpy()
+    reference = reference.detach().numpy()
+
+    assert results.shape == reference.shape, "Output shape mismatch"
+    assert np.isfinite(reference).all(), "Pytorch Reference contains NaN or Inf"
+    assert np.isfinite(results).all(), "NPU output contains NaN or Inf"
+
+    if dtype == int4:
+        assert 1 - r2_score(reference, results) < 0.05
+    else:
+        assert 1 - r2_score(reference, results) < 0.001
+
+
+def define_and_parse_args():
+    parser = argparse.ArgumentParser(description="Profiling a MLP layer in the NPU")
+    parser.add_argument(
+        "--seq-len",
+        type=int,
+        default=128,
+        help="Sequence length (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--hidden-size",
+        type=int,
+        default=256,
+        help="Hidden size (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--intermediate-size",
+        type=int,
+        default=512,
+        help="Intermediate size (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--dtype",
+        default="float16",
+        choices=["float16", "int8", "int4"],
+        help="Select the target dtype (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        default=False,
+        help="Enable the profiling (default: False)",
+    )
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = define_and_parse_args()
+
+    print(
+        f"Profiling with sequence length {args.seq_len}, hidden size {args.hidden_size}, intermediate size {args.intermediate_size}, dtype {args.dtype}"
+    )
+
+    main(
+        seq_len=args.seq_len,
+        hidden_size=args.hidden_size,
+        intermediate_size=args.intermediate_size,
+        dtype=args.dtype,
+        _profile=args.profile,
+    )
diff --git a/test/python/test_compile.py b/test/python/test_compile.py
index 07fb144..faf3d28 100644
--- a/test/python/test_compile.py
+++ b/test/python/test_compile.py
@@ -4,6 +4,7 @@
 #
 
 from intel_npu_acceleration_library.compiler import compile
+from intel_npu_acceleration_library.compiler import CompilerConfig
 from intel_npu_acceleration_library.dtypes import int4
 from sklearn.metrics import r2_score
 import intel_npu_acceleration_library
@@ -39,7 +40,8 @@ def test_compilation(dtype):
 
     y_ref = model(x).detach()
 
-    compiled_model = compile(model, dtype)
+    compiler_conf = CompilerConfig(dtype=dtype)
+    compiled_model = compile(model, compiler_conf)
 
     assert compiled_model
 
@@ -104,7 +106,8 @@ def test_compile_training(dtype):
 
     model = NN()
 
-    compiled_model = compile(model, dtype, training=True)
+    compiler_conf = CompilerConfig(dtype=dtype, training=True)
+    compiled_model = compile(model, compiler_conf)
 
     for name, layer in compiled_model.named_children():
         if dtype == torch.int8:
@@ -118,7 +121,8 @@ def test_compile_inference(dtype):
 
     model = NN()
 
-    compiled_model = compile(model, dtype)
+    compiler_conf = CompilerConfig(dtype=dtype)
+    compiled_model = compile(model, compiler_conf)
 
     for name, layer in compiled_model.named_children():
         assert layer.training == False
diff --git a/test/python/test_conv.py b/test/python/test_conv.py
index 5a0ec5b..6fa94a6 100644
--- a/test/python/test_conv.py
+++ b/test/python/test_conv.py
@@ -5,6 +5,7 @@
 
 import intel_npu_acceleration_library
+from intel_npu_acceleration_library.compiler import CompilerConfig
 from sklearn.metrics import r2_score
 import pytest
 import torch
@@ -71,7 +72,8 @@ def test_conv(
         conv.conv.weight.data *= 128
     y_ref = conv(X)
 
-    npu_conv = intel_npu_acceleration_library.compile(conv, dtype)
+    compiler_conf = CompilerConfig(dtype=dtype)
+    npu_conv = intel_npu_acceleration_library.compile(conv, compiler_conf)
     y = npu_conv(X)
 
     assert y.dtype == y_ref.dtype
diff --git a/test/python/test_llm.py b/test/python/test_llm.py
index 8e4dbf0..49e2952 100644
--- a/test/python/test_llm.py
+++ b/test/python/test_llm.py
@@ -5,11 +5,16 @@
 
 from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaConfig
 from transformers.models.phi.modeling_phi import PhiConfig, PhiMLP
+from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from intel_npu_acceleration_library.dtypes import int8, int4
+from intel_npu_acceleration_library.compiler import CompilerConfig
 from sklearn.metrics import r2_score
+from torch.profiler import profile, ProfilerActivity
 import intel_npu_acceleration_library
 import pytest
 import torch
+import numpy as np
 
 
 @pytest.fixture
@@ -34,7 +39,8 @@ def tokenizer():
 
 @pytest.mark.parametrize("model_seq_length", [128, 256])
 def test_warm_up(tokenizer, model, model_seq_length):
-    compiled_model = intel_npu_acceleration_library.compile(model)
+    compiler_conf = CompilerConfig()
+    compiled_model = intel_npu_acceleration_library.compile(model, compiler_conf)
     intel_npu_acceleration_library.nn.llm.warm_up_decoder_model(
         tokenizer, compiled_model, model_seq_length
     )
@@ -45,7 +51,10 @@ def test_compilation(tokenizer, decoder_model, dtype):
     prefill = tokenizer("test sentence", return_tensors="pt")["input_ids"].to("cpu")
 
     y_ref = decoder_model(prefill).logits.detach()
 
-    compiled_model = intel_npu_acceleration_library.compile(decoder_model, dtype=dtype)
+    compiler_conf = CompilerConfig(dtype=dtype)
+    compiled_model = intel_npu_acceleration_library.compile(
+        decoder_model, compiler_conf
+    )
 
     assert compiled_model
 
@@ -76,3 +85,53 @@ def test_phi2_mlp(seq_len, hidden_size, intermediate_size):
     out = model(x)
 
     assert 1 - r2_score(reference.numpy(), out.numpy()) < 0.001
+
+
+@torch.no_grad
+@pytest.mark.parametrize("seq_len", [16, 128, 256])
+@pytest.mark.parametrize("hidden_size", [256, 512])
+@pytest.mark.parametrize("intermediate_size", [512])
+@pytest.mark.parametrize("dtype", ["float16", "int8", "int4"])
+def test_phi3_mlp_compile(seq_len, hidden_size, intermediate_size, dtype):
+    conf = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+    conf.num_hidden_layers = 1
+    conf.hidden_size = hidden_size
+    conf.intermediate_size = intermediate_size
+
+    if dtype == "int8":
+        dtype = int8
+    elif dtype == "int4":
+        dtype = int4
+    else:
+        dtype = torch.float16
+
+    mlp = Phi3MLP(conf)
+
+    hidden_states = torch.rand((seq_len, conf.hidden_size))
+
+    reference = mlp(hidden_states.to(torch.float32)).to(torch.float16).detach().numpy()
+
+    compiler_conf = CompilerConfig(use_to=True, dtype=dtype)
+    model = intel_npu_acceleration_library.compile(mlp, compiler_conf)
+
+    assert model
+
+    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
+        out = model(hidden_states)
+
+    print(
+        prof.key_averages(group_by_input_shape=True).table(
+            sort_by="cpu_time_total", row_limit=20
+        )
+    )
+
+    out = out.detach().numpy()
+
+    assert out.shape == reference.shape, "Output shape mismatch"
+    assert np.isfinite(reference).all(), "Pytorch Reference contains NaN or Inf"
+    assert np.isfinite(out).all(), "NPU output contains NaN or Inf"
+
+    if dtype == int4:
+        assert 1 - r2_score(reference, out) < 0.05
+    else:
+        assert 1 - r2_score(reference, out) < 0.001
diff --git a/test/python/test_optimizations.py b/test/python/test_optimizations.py
index 0f02f07..b3c5b97 100644
--- a/test/python/test_optimizations.py
+++ b/test/python/test_optimizations.py
@@ -7,6 +7,7 @@
 from transformers.models.llama.modeling_llama import LlamaConfig, LlamaMLP, LlamaModel
 from transformers.models.gemma.modeling_gemma import GemmaConfig, GemmaMLP, GemmaModel
 from intel_npu_acceleration_library.optimizations import horizontal_fusion_linear
+from intel_npu_acceleration_library.compiler import CompilerConfig
 from sklearn.metrics import r2_score
 import torch.nn as nn
 import intel_npu_acceleration_library
@@ -142,7 +143,8 @@ def test_model(model_name, hidden_size, intermediate_size, sequence_length, bias
 
     reference = model(example_input)[0]
 
-    optimized = intel_npu_acceleration_library.compile(model, torch.float16)
+    compiler_conf = CompilerConfig(dtype=torch.float16)
+    optimized = intel_npu_acceleration_library.compile(model, compiler_conf)
 
     output = optimized(example_input)[0]
diff --git a/test/python/test_quantization.py b/test/python/test_quantization.py
index 50044b2..c0a1c27 100644
--- a/test/python/test_quantization.py
+++ b/test/python/test_quantization.py
@@ -4,6 +4,7 @@
 #
 
 from sklearn.metrics import r2_score
+from intel_npu_acceleration_library.compiler import CompilerConfig
 import numpy as np
 import intel_npu_acceleration_library
 import pytest
@@ -88,7 +89,9 @@ def test_compiled_quantized(batch, inC, outC):
     model = NN(inC, outC)
 
     y_ref = model(X.to(torch.float32)).detach()
-    compiled_model = intel_npu_acceleration_library.compile(model, torch.int8)
+
+    compiler_conf = CompilerConfig(dtype=torch.int8)
+    compiled_model = intel_npu_acceleration_library.compile(model, compiler_conf)
     assert compiled_model
 
     y1 = compiled_model(X).detach()
diff --git a/test/python/test_training.py b/test/python/test_training.py
index aa8f390..adc398d 100644
--- a/test/python/test_training.py
+++ b/test/python/test_training.py
@@ -6,6 +6,7 @@
 
 from sklearn.metrics import r2_score
 from intel_npu_acceleration_library import compile
+from intel_npu_acceleration_library.compiler import CompilerConfig
 import torch
 import pytest
 import copy
@@ -28,12 +29,14 @@ def forward(self, x):
 
 @pytest.fixture
 def model_no_bias():
-    return compile(NN(inc=in_c, outc=out_c, bias=False))
+    compiler_conf = CompilerConfig()
+    return compile(NN(inc=in_c, outc=out_c, bias=False), compiler_conf)
 
 
 @pytest.fixture
 def model():
-    return compile(NN(inc=in_c, outc=out_c, bias=True))
+    compiler_conf = CompilerConfig()
+    return compile(NN(inc=in_c, outc=out_c, bias=True), compiler_conf)
 
 
 def test_parameters(model, model_no_bias):
@@ -48,7 +51,8 @@ def test_gradient():
     cpu_model.load_state_dict(copy.deepcopy(npu_model.state_dict()))
 
     # Compile one of the model on npu
-    compile(npu_model, training=True)
+    compiler_conf = CompilerConfig(training=True)
+    compile(npu_model, compiler_conf)
 
     x = torch.rand([batch, in_c]).half()
     yref = torch.rand([batch, in_c]).half()