Create python bindings
alessandropalla committed Jul 16, 2024
1 parent f2ff0cd commit 9760d57
Showing 7 changed files with 193 additions and 27 deletions.
3 changes: 2 additions & 1 deletion include/intel_npu_acceleration_library/inference.h
@@ -19,6 +19,7 @@
#include <vector>
#include "intel_npu_acceleration_library/common.h"
#include "intel_npu_acceleration_library/parameters.h"
#include "intel_npu_acceleration_library/tensor.h"

namespace intel_npu_acceleration_library {

@@ -81,7 +82,7 @@ class OVInferenceModel {
// set latency hint
core.set_property(ov::cache_dir("cache"));
core.set_property(device, ov::hint::performance_mode(ov::hint::PerformanceMode::THROUGHPUT));
-        // core.set_property("NPU", ov::log::level(ov::log::Level::DEBUG));
+        core.set_property("NPU", ov::log::level(ov::log::Level::DEBUG));
if (device == "NPU") {
core.set_property(device, intel_npu_acceleration_library::npu_compiler_type("DRIVER"));
if (profile) {
48 changes: 48 additions & 0 deletions include/intel_npu_acceleration_library/tensor.h
@@ -0,0 +1,48 @@
//
// Copyright © 2024 Intel Corporation
// SPDX-License-Identifier: Apache 2.0
//

#include "intel_npu_acceleration_library/common.h"

namespace intel_npu_acceleration_library {

/**
* @brief Class representing an NPU tensor
*
*/
class Tensor {
private:
ov::intel_npu::level_zero::ZeroBufferTensor _remote_tensor;
void* data_ptr;

public:
/**
* @brief Construct a new Tensor object
*
* @param dtype tensor datatype
* @param shape tensor shape
* @param data pointer to tensor data
* @param tensor_type tensor type; one of INPUT, OUTPUT, or BINDED
* @param device target device for the tensor
*/
Tensor(ov::element::Type_t dtype, ov::Shape shape, void* data,
ov::intel_npu::TensorType tensor_type = ov::intel_npu::TensorType::INPUT, std::string device = "NPU") {
ov::Core core;
auto context = core.get_default_context(device).as<ov::intel_npu::level_zero::ZeroContext>();
_remote_tensor = context.create_l0_host_tensor(dtype, shape, tensor_type);
data_ptr = _remote_tensor.get();
std::memcpy(data_ptr, data, _remote_tensor.get_byte_size());
}

/**
* @brief Get the data pointer
*
* @return void*
*/
void* data() {
return data_ptr;
}
};

} // namespace intel_npu_acceleration_library
9 changes: 9 additions & 0 deletions intel_npu_acceleration_library/backend/bindings.py
@@ -88,6 +88,15 @@ def init_common(lib: ctypes.CDLL):

lib.compressToI4.argtypes = [c_i8_array, c_u8_array, ctypes.c_int]

# Remote tensors
lib.to_npu.argtypes = [ctypes.c_int, c_u32_array, ctypes.c_char_p, ctypes.c_void_p]
lib.to_npu.restype = handler

lib.remote_tensor_data.argtypes = [handler]
lib.remote_tensor_data.restype = ctypes.c_void_p

lib.del_remote_tensor.argtypes = [handler]
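
The three entry points above map one-to-one onto the C functions added in src/bindings.cpp below. A minimal sketch of driving them directly through ctypes, assuming `lib` is the loaded backend CDLL and that `c_u32_array`, as its use above suggests, accepts a uint32 numpy array (normally `RemoteTensor.from_torch` in tensor.py wraps these calls):

import ctypes
import numpy as np

data = np.ones((2, 3), dtype=np.float32)
shape = np.array(data.shape, dtype=np.uint32)

# Allocate a level-zero host tensor and copy `data` into it
rt = lib.to_npu(shape.size, shape, ctypes.c_char_p(b"float32"),
                data.ctypes.data_as(ctypes.c_void_p))
ptr = lib.remote_tensor_data(rt)  # raw void* into the host-visible buffer
lib.del_remote_tensor(rt)         # release the underlying tensor when done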


def init_network_factory(lib: ctypes.CDLL):
"""Initialize Netowrk factory bindings.
Expand Down
28 changes: 2 additions & 26 deletions intel_npu_acceleration_library/backend/factory.py
@@ -7,7 +7,7 @@
from intel_npu_acceleration_library.backend.ops import get_supported_ops
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
from intel_npu_acceleration_library.backend.tensor import Tensor
-from intel_npu_acceleration_library.dtypes import int4, bfloat16
+from intel_npu_acceleration_library.dtypes import get_backend_dtype
from typing import Optional, Tuple, Any, Union, Sequence, TypeVar, Callable, cast, List
from functools import partial
import numpy.typing as npt
@@ -115,34 +115,10 @@ def get_backend_dtype(self, dtype) -> ctypes.c_char_p:
Args:
dtype: numpy dtype
Raises:
RuntimeError: Unsupported datatype
Returns:
ctypes.c_char_p: string representation of the dtype
"""
-        if dtype in [np.int8, torch.int8]:
-            str_dtype = "int8"
-        elif dtype == np.uint8 or dtype == int4:
-            # u8 represents packed i4 dtypes
-            str_dtype = "int4"
-        elif dtype in [np.int16, torch.int16]:
-            str_dtype = "int16"
-        elif dtype in [np.int32, torch.int32]:
-            str_dtype = "int32"
-        elif dtype in [np.int64, torch.int64]:
-            str_dtype = "int64"
-        elif dtype in [np.float16, torch.float16]:
-            str_dtype = "float16"
-        elif dtype in [np.float32, torch.float32]:
-            str_dtype = "float32"
-        elif dtype in [np.float64, torch.float64]:
-            str_dtype = "float64"
-        elif dtype in [bfloat16, torch.bfloat16]:
-            str_dtype = "bfloat16"
-        else:
-            raise RuntimeError(f"DType is not supported {dtype}")
-        return ctypes.c_char_p(str_dtype.encode())
+        return get_backend_dtype(dtype)

@return_tensor
def parameter(
76 changes: 76 additions & 0 deletions intel_npu_acceleration_library/backend/tensor.py
@@ -16,14 +16,90 @@
int32,
int64,
NPUDtype,
get_backend_dtype,
)
from dataclasses import dataclass
import functools
from math import prod
import numpy as np
import ctypes
import torch


class RemoteTensor(torch.Tensor):
"""
Represent a remote tensor object.
Attrs:
_remote_tensor (ctypes._Pointer): The pointer to the underlying remote tensor.
Methods:
from_torch(x: torch.Tensor): Create a remote tensor from a torch tensor.
"""

_remote_tensor = None

@staticmethod
def __new__(cls, x: Any, remote_tensor: ctypes._Pointer, *args: Any, **kwargs: Any):
"""
Create a new remote tensor object.
Args:
x (Any): tensor input
remote_tensor (ctypes._Pointer): remote tensor pointer
args (Any): additional arguments
kwargs (Any): additional keyword arguments
Returns:
RemoteTensor: a RemoteTensor object
"""
return super().__new__(cls, x, *args, **kwargs)

def __init__(self, x: Any, remote_tensor: ctypes._Pointer):
"""
Initialize the remote tensor object.
Args:
x (Any): tensor input
remote_tensor (ctypes._Pointer): remote tensor pointer
"""
self._remote_tensor = remote_tensor

# def __del__(self):
# if self._remote_tensor and backend_lib:
# backend_lib.del_remote_tensor(self._remote_tensor)

@staticmethod
def from_torch(x: torch.Tensor) -> "RemoteTensor":
"""
Create a remote tensor from a torch tensor.
Args:
x (torch.Tensor): The torch tensor.
Returns:
RemoteTensor: The remote tensor.
"""
shape_arr = np.array(x.shape, dtype=np.uint32)
dtype_str = get_backend_dtype(x.dtype)
p = ctypes.cast(x.data_ptr(), ctypes.c_void_p)

rt = backend_lib.to_npu(shape_arr.size, shape_arr, dtype_str, p)

pointer = ctypes.cast(
backend_lib.remote_tensor_data(rt),
ctypes.POINTER(ctypes.c_uint8),
)

arr = (pointer._type_ * prod(x.shape) * x.element_size()).from_address(
ctypes.addressof(pointer.contents)
)

pt_tensor = torch.frombuffer(arr, dtype=x.dtype).view(*x.shape)

return RemoteTensor(pt_tensor, rt)
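
A short usage sketch, assuming an NPU device and the compiled backend library are available. The returned object behaves as a regular torch.Tensor whose storage is the level-zero host buffer that the C++ constructor copied the data into:

import torch

x = torch.rand(4, 8, dtype=torch.float16)
npu_x = RemoteTensor.from_torch(x)  # copies x into an L0 host buffer
assert isinstance(npu_x, torch.Tensor)
assert torch.equal(npu_x, x)        # same contents, NPU-visible memory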


@dataclass
class Tensor:
"""
37 changes: 37 additions & 0 deletions intel_npu_acceleration_library/dtypes.py
@@ -7,6 +7,7 @@
from typing import Union
import numpy as np
import torch
import ctypes


@dataclass(frozen=True)
@@ -81,6 +82,42 @@ def __repr__(self) -> str:
return self.name


def get_backend_dtype(dtype) -> ctypes.c_char_p:
"""Get the string representation of the dtype.
Args:
dtype: numpy or torch dtype
Raises:
RuntimeError: Unsupported datatype
Returns:
ctypes.c_char_p: string representation of the dtype
"""
if dtype in [np.int8, torch.int8]:
str_dtype = "int8"
elif dtype == np.uint8 or dtype == int4:
# u8 represents packed i4 dtypes
str_dtype = "int4"
elif dtype in [np.int16, torch.int16]:
str_dtype = "int16"
elif dtype in [np.int32, torch.int32]:
str_dtype = "int32"
elif dtype in [np.int64, torch.int64]:
str_dtype = "int64"
elif dtype in [np.float16, torch.float16]:
str_dtype = "float16"
elif dtype in [np.float32, torch.float32]:
str_dtype = "float32"
elif dtype in [np.float64, torch.float64]:
str_dtype = "float64"
elif dtype in [bfloat16, torch.bfloat16]:
str_dtype = "bfloat16"
else:
raise RuntimeError(f"DType is not supported {dtype}")
return ctypes.c_char_p(str_dtype.encode())
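
For example (`.value` recovers the byte string from the returned ctypes.c_char_p):

assert get_backend_dtype(np.float32).value == b"float32"
assert get_backend_dtype(torch.bfloat16).value == b"bfloat16"
assert get_backend_dtype(int4).value == b"int4"  # packed i4 rides on uint8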


float16 = NPUDtype(
"fp16",
16,
19 changes: 19 additions & 0 deletions src/bindings.cpp
@@ -17,6 +17,25 @@ intel_npu_acceleration_library_DLL_API uint32_t getNPUDriverVersion() {
return intel_npu_acceleration_library::driver_version(core);
}

// ######################## Remote Tensors ########################

intel_npu_acceleration_library_DLL_API intel_npu_acceleration_library::Tensor* to_npu(size_t size,
unsigned int* shape_data,
char* dtype, void* data) {
ov::element::Type_t ov_dtype = intel_npu_acceleration_library::dtype_from_string(std::string(dtype));
std::vector<size_t> shape(shape_data, shape_data + size);

return new intel_npu_acceleration_library::Tensor(ov_dtype, shape, data);
}

intel_npu_acceleration_library_DLL_API void* remote_tensor_data(intel_npu_acceleration_library::Tensor* rt) {
return rt->data();
}

intel_npu_acceleration_library_DLL_API void del_remote_tensor(intel_npu_acceleration_library::Tensor* rt) {
delete rt;
}
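
Ownership note: `to_npu` allocates the Tensor with `new`, so the Python side owns the returned handle and must eventually call `del_remote_tensor` (the `__del__` in RemoteTensor above is still commented out). A hedged sketch of one alternative, tying the release to the wrapper's lifetime with `weakref.finalize` (hypothetical helper, not part of this commit):

import weakref

def attach_finalizer(wrapper: "RemoteTensor", remote_tensor) -> None:
    # Release the C++ tensor when the Python wrapper is garbage-collected,
    # avoiding __del__ ordering issues at interpreter shutdown.
    weakref.finalize(wrapper, backend_lib.del_remote_tensor, remote_tensor)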

// ######################## Compression ########################

intel_npu_acceleration_library_DLL_API void compressToI4(const int8_t* src, uint8_t* dst, size_t size) {
