
Commit
Merge branch 'main' into henry
henry8248 authored Jul 8, 2024
2 parents 8d195d7 + 6f8e11d commit 0479893
Showing 11 changed files with 884 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -91,7 +91,7 @@ Site map

.. toctree::
developer.md
adding_operation.md
adding_operations.md
:maxdepth: 1
:caption: Development guide:

582 changes: 582 additions & 0 deletions examples/NPU compilation tutorial.ipynb

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions examples/cpp/main.cpp
@@ -17,12 +17,15 @@ int main() {
// create parameter
auto input = factory->parameter({batch, inC}, ov::element::f16);
auto weights = factory->parameter({outC, inC}, ov::element::f16);
auto bias = factory->parameter({1, outC}, ov::element::f16);

// create matmul
auto matmul = factory->matmul(input, weights);
auto matmul_bias = factory->eltwise_add(matmul, bias);
factory->result(matmul_bias);

// Compile the model
factory->compile(matmul);
factory->compile();

// Save OV model
std::cout << "Saving model to matmul.xml" << std::endl;
@@ -31,14 +34,17 @@ int main() {
// Here you can create float16 buffers and run inference by using them:
half_ptr input_buffer = new uint16_t[batch * inC];
half_ptr weights_buffer = new uint16_t[outC * inC];
half_ptr bias_buffer = new uint16_t[outC];
half_ptr output_buffer = new uint16_t[batch * outC];

memset(input_buffer, 0, 128 * 256 * sizeof(uint16_t));
memset(weights_buffer, 0, 128 * 256 * sizeof(uint16_t));
memset(output_buffer, 0, 128 * 512 * sizeof(uint16_t));
memset(input_buffer, 0, batch * inC * sizeof(uint16_t));
memset(weights_buffer, 0, outC * inC * sizeof(uint16_t));
memset(output_buffer, 0, batch * outC * sizeof(uint16_t));
memset(bias_buffer, 0, outC * sizeof(uint16_t));

factory->setInputTensor(input_buffer, 0);
factory->setInputTensor(weights_buffer, 1);
factory->setInputTensor(bias_buffer, 2);
factory->setOutputTensor(output_buffer, 0);

// Run inference
Expand All @@ -49,6 +55,7 @@ int main() {

delete[] input_buffer;
delete[] weights_buffer;
delete[] bias_buffer;
delete[] output_buffer;
return 0;
}
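For comparison, here is a rough Python equivalent of the updated C++ example, sketched from the NNFactory API exercised in test_op.py later in this diff. The constants, shapes, and exact call signatures are assumptions for illustration, not part of this commit:

import numpy as np
import torch
from intel_npu_acceleration_library.backend import NNFactory

batch, inC, outC = 128, 256, 512  # assumed to match the C++ constants

model = NNFactory()
x = model.parameter((batch, inC), np.float16)
w = model.parameter((outC, inC), np.float16)
b = model.parameter((1, outC), np.float16)
y = torch.nn.functional.linear(x, w, b)  # matmul + eltwise_add, see nn/functional.py below
model.compile()

# zero-filled inputs, mirroring the memset calls above
out = model(
    torch.zeros(batch, inC, dtype=torch.float16),
    torch.zeros(outC, inC, dtype=torch.float16),
    torch.zeros(1, outC, dtype=torch.float16),
)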
16 changes: 16 additions & 0 deletions include/intel_npu_acceleration_library/nn_factory.h
@@ -90,6 +90,22 @@ class ModelFactory : public intel_npu_acceleration_library::OVInferenceModel {
return matmul.get();
}

/**
* @brief Create a new linear operation
*
* @param input matmul lhs input
* @param weights matmul rhs input, a.k.a. weights
* @param bias matmul bias input
* @return ov::op::Op*
*/
ov::op::Op* linear(ov::op::Op* input, ov::op::Op* weights, ov::op::Op* bias) {
auto mm_op = matmul(input, weights);
if (bias != nullptr) {
return eltwise_add(mm_op, bias);
}
return mm_op;
}

/**
* @brief Create a new convolution operation
*
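The new linear() helper simply folds the common matmul-plus-optional-bias pattern into one call. A hedged Python sketch of the same composition, assuming the Python factory exposes matmul and eltwise_add as methods the way the ops table below suggests:

def linear(factory, input_node, weights_node, bias_node=None):
    # matmul first, then fuse the bias with an eltwise_add,
    # exactly as in the C++ helper above
    mm = factory.matmul(input_node, weights_node)
    if bias_node is not None:
        return factory.eltwise_add(mm, bias_node)
    return mm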
21 changes: 21 additions & 0 deletions intel_npu_acceleration_library/backend/factory.py
@@ -211,6 +211,27 @@ def constant(
self._mm, shape_ptr.size, shape_ptr, self.get_backend_dtype(data.dtype), dst
)

@return_tensor
def matmul(
self,
input_node: ctypes._Pointer,
weights_node: ctypes._Pointer,
trA: bool = False,
trB: bool = True,
) -> ctypes._Pointer:
"""Generate a matrix multiplication layer.
Args:
input_node (ctypes._Pointer): layer input node
weights_node (ctypes._Pointer): weights node
trA (bool): transpose input node
trB (bool): transpose weights node
Returns:
ctypes._Pointer: output node
"""
return backend_lib.matmul(self._mm, input_node, weights_node, trA, trB)

@return_tensor
def convolution(
self,
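The trA/trB flags let callers fold a transpose into the matmul node instead of materializing a separate transpose op. A short usage sketch (shapes are illustrative assumptions):

import numpy as np
from intel_npu_acceleration_library.backend import NNFactory

model = NNFactory()
a = model.parameter((16, 128), np.float16)   # [M, K]
w = model.parameter((64, 128), np.float16)   # [N, K], nn.Linear weight layout
y = model.matmul(a, w, trA=False, trB=True)  # computes a @ w.T -> [M, N]
model.compile()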
2 changes: 1 addition & 1 deletion intel_npu_acceleration_library/backend/ops.py
@@ -33,7 +33,7 @@ def get_supported_ops() -> List[SupportedOp]:
"""
supported_ops = [
SupportedOp(name="result", inputs=1),
SupportedOp(name="matmul", inputs=2),
SupportedOp(name="matmul", inputs=2, parameters=[ctypes.c_bool, ctypes.c_bool]),
SupportedOp(name="eltwise_add", inputs=2),
SupportedOp(name="eltwise_mul", inputs=2),
SupportedOp(name="eltwise_div", inputs=2),
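Declaring the two c_bool parameters here is what lets the Python side pass trA/trB straight through ctypes to the new C++ signature in src/bindings.cpp. A minimal sketch of how such a table entry can drive prototype generation; the helper below is hypothetical, not this library's actual wiring:

import ctypes

def bind_op(lib: ctypes.CDLL, name: str, inputs: int, parameters=()):
    fn = getattr(lib, name)
    # factory handle, then one pointer per input node, then the extra
    # scalar parameters declared in the SupportedOp entry
    fn.argtypes = [ctypes.c_void_p] * (1 + inputs) + list(parameters)
    fn.restype = ctypes.c_void_p
    return fn

# matmul = bind_op(backend_lib, "matmul", inputs=2,
#                  parameters=[ctypes.c_bool, ctypes.c_bool])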
122 changes: 120 additions & 2 deletions intel_npu_acceleration_library/backend/tensor.py
@@ -166,6 +166,10 @@ def __add__(self, other) -> "Tensor":
Returns:
Tensor: The result of the addition.
"""
if isinstance(other, (int, float)):
other = self.factory.constant(
torch.tensor([other], dtype=self.dtype.torch_dtype)
)
return generate_op([self, other], "eltwise_add")

def __sub__(self, other) -> "Tensor":
@@ -178,6 +182,10 @@ def __sub__(self, other) -> "Tensor":
Returns:
Tensor: The result of the subtraction.
"""
if isinstance(other, (int, float)):
other = self.factory.constant(
torch.tensor([other], dtype=self.dtype.torch_dtype)
)
return generate_op([self, -other], "eltwise_add")

def __mul__(self, other) -> "Tensor":
Expand All @@ -190,6 +198,10 @@ def __mul__(self, other) -> "Tensor":
Returns:
Tensor: The result of the multiplication.
"""
if isinstance(other, (int, float)):
other = self.factory.constant(
torch.tensor([other], dtype=self.dtype.torch_dtype)
)
return generate_op([self, other], "eltwise_mul")

def __truediv__(self, other) -> "Tensor":
Expand All @@ -202,8 +214,76 @@ def __truediv__(self, other) -> "Tensor":
Returns:
Tensor: The result of the division.
"""
if isinstance(other, (int, float)):
other = self.factory.constant(
torch.tensor([other], dtype=self.dtype.torch_dtype)
)
return generate_op([self, other], "eltwise_div")

def __radd__(self, other) -> "Tensor":
"""
Add two tensors element-wise.
Args:
other (Tensor): The tensor to be added.
Returns:
Tensor: The result of the addition.
"""
if isinstance(other, (int, float)):
other = self.factory.constant(
torch.tensor([other], dtype=self.dtype.torch_dtype)
)
return generate_op([other, self], "eltwise_add")

def __rsub__(self, other) -> "Tensor":
"""
Subtract two tensors element-wise.
Args:
other (Tensor): The tensor to be subtracted.
Returns:
Tensor: The result of the subtraction.
"""
if isinstance(other, (int, float)):
other = self.factory.constant(
torch.tensor([other], dtype=self.dtype.torch_dtype)
)
return generate_op([other, -self], "eltwise_add")

def __rmul__(self, other) -> "Tensor":
"""
Multiply two tensors element-wise.
Args:
other (Tensor): The tensor to be multiplied.
Returns:
Tensor: The result of the multiplication.
"""
if isinstance(other, (int, float)):
other = self.factory.constant(
torch.tensor([other], dtype=self.dtype.torch_dtype)
)
return generate_op([other, self], "eltwise_mul")

def __rtruediv__(self, other) -> "Tensor":
"""
Divide two tensors element-wise.
Args:
other (Tensor): The tensor to be divided.
Returns:
Tensor: The result of the division.
"""
if isinstance(other, (int, float)):
other = self.factory.constant(
torch.tensor([other], dtype=self.dtype.torch_dtype)
)
return generate_op([other, self], "eltwise_div")
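Together with the scalar promotion added to the forward operators above, these reflected variants let plain Python numbers sit on either side of an expression. Sketch (factory and parameter setup assumed):

x = model.parameter((16, 128), np.float16)
y = 2.0 * x + 1.0    # __rmul__ then __radd__: scalars become constant nodes
z = 1.0 / (x - 0.5)  # __rtruediv__, with __sub__ promoting 0.5 the same way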

def __neg__(self) -> "Tensor":
"""
Negate the tensor.
@@ -436,7 +516,7 @@ def __matmul__(self, other) -> "Tensor":
Returns:
Tensor: The result of the matrix multiplication.
"""
return generate_op([self, other], "matmul")
return generate_op([self, other], "matmul", False, False)

def acos(self) -> "Tensor":
"""
@@ -819,6 +899,43 @@ def sum(
sum = sum.to(dtype)
return sum

def chunk(
self,
chunks: int,
dim: int = 0,
) -> Union["Tensor", list]:
"""
Return the list of tensor chunks.
Args:
chunks (int): The number of chunks to return.
dim (int): The dimension along which to split the tensor. Default is 0.
Returns:
Union["Tensor", list]: The resulting list of split tensors or a single tensor.
Raises:
ValueError: The input chunks value is not valid.
"""
if chunks <= 0:
raise ValueError("The input chunks value is not valid.")
if chunks == 1:
return self
tensors = []
remainder = self.shape[dim] % chunks
chunk_size = self.shape[dim] // chunks + (1 if remainder > 0 else 0)
num_dims = self.dim()

start_idx = 0
for _ in range(chunks):
indexes = [slice(None)] * num_dims
end_idx = start_idx + chunk_size
end_idx = end_idx if end_idx < self.shape[dim] else self.shape[dim]
indexes[dim] = slice(start_idx, end_idx)
tensors.append(self.__getitem__(tuple(indexes)))
start_idx = end_idx
return tensors
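The new chunk() mirrors torch.chunk semantics, including the shorter trailing chunk when the dimension does not divide evenly. For example (setup assumed as above):

x = model.parameter((10, 128), np.float16)
parts = x.chunk(3, dim=0)  # slices of shape (4, 128), (4, 128), (2, 128)
whole = x.chunk(1)         # chunks == 1 returns the tensor itself, not a list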

def to(self, dtype: NPUDtype) -> "Tensor":
"""
Convert the tensor to the specified data type.
@@ -920,7 +1037,8 @@ def generate_op(
):
raise ValueError("All tensors must be from the same factory")

factory = tensors[0].factory
# Get the first factory from the tensors
factory = [t for t in tensors if isinstance(t, Tensor)][0].factory

# Replace the tensors that are not from the factory with constant tensors if they are coming from pytorch
tensors = [
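This generate_op change matters because, after the scalar promotion above, some operands are plain Python numbers or torch tensors rather than NPU Tensors; the factory must therefore come from the first operand that actually is a Tensor. A sketch under those assumptions:

x = model.parameter((16, 16), np.float16)  # NPU Tensor, carries the factory
w = torch.rand(16, 16).to(torch.float16)   # plain torch tensor
y = x + w  # w is folded into a constant on x's factory before eltwise_add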
30 changes: 28 additions & 2 deletions intel_npu_acceleration_library/nn/functional.py
@@ -310,7 +310,7 @@ def sign(x: Tensor, out: Optional[Tensor] = None) -> Tensor:

@implements(torch.nn.functional.linear)
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
"""Return the sign of a tensor element-wise.
"""Apply a linear transformation to the incoming data: y = x * A^T + b.
Args:
input (Tensor): The input tensor.
@@ -320,12 +320,38 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
Returns:
Tensor: Output tensor.
"""
mm = generate_op([input, weight], "matmul")
mm = generate_op([input, weight], "matmul", False, True)
if bias is not None:
return generate_op([mm, bias], "eltwise_add")
return mm
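Passing trB=True here matches torch.nn.Linear's [out_features, in_features] weight layout, so the graph computes x @ A^T without an explicit transpose node. Usage sketch (shapes assumed):

x = model.parameter((8, 128), np.float16)
w = model.parameter((64, 128), np.float16)  # nn.Linear weight layout
b = model.parameter((1, 64), np.float16)
y = torch.nn.functional.linear(x, w, b)     # dispatches to the override above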


@implements(torch.addmm)
def addmm(
input: Tensor,
mat1: Tensor,
mat2: Tensor,
beta: float = 1,
alpha: float = 1,
out: Optional[Tensor] = None,
) -> Tensor:
"""Return the addmm of a tensor element-wise.
Args:
input (Tensor): The input tensor.
mat1 (Tensor): The first matrix tensor.
mat2 (Tensor): The second matrix tensor.
beta (float): The beta value. Defaults to 1.
alpha (float): The alpha value. Defaults to 1.
out (Optional[Tensor], optional): Output tensor. Defaults to None.
Returns:
Tensor: Output tensor.
"""
out = beta * input + alpha * (mat1 @ mat2)
return out
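Note that addmm needs no dedicated kernel: beta * input goes through the new __rmul__, mat1 @ mat2 through __matmul__ (with trA=trB=False), and the final addition through __add__, so the whole expression lowers to the existing eltwise and matmul ops. The test added below exercises exactly this path.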


@implements(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(
query: torch.Tensor,
4 changes: 2 additions & 2 deletions src/bindings.cpp
@@ -153,8 +153,8 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* constant(intel_npu_accelerati
}

intel_npu_acceleration_library_DLL_API ov::op::Op* matmul(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* in0, ov::op::Op* in1) {
return factory->matmul(in0, in1);
ov::op::Op* in0, ov::op::Op* in1, bool trA, bool trB) {
return factory->matmul(in0, in1, trA, trB);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* eltwise_add(intel_npu_acceleration_library::ModelFactory* factory,
27 changes: 27 additions & 0 deletions test/python/test_op.py
@@ -507,3 +507,30 @@ def test_logsoftmax(batch, hidden_dim, axis):
assert np.isfinite(out).all(), "NPU output contains NaN or Inf"

assert 1 - r2_score(reference, out) < 0.01


@pytest.mark.parametrize("batch", [16, 128])
@pytest.mark.parametrize("hidden_dim", [128, 256])
@pytest.mark.parametrize("channels", [128, 256])
@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0])
# @pytest.mark.parametrize("beta", [0, 0.5, 1.0])
def test_addmm(batch, hidden_dim, channels, alpha, beta=1):
torch.manual_seed(42)
m1 = torch.rand((1, channels)).to(torch.float16)
m2 = torch.rand((batch, hidden_dim)).to(torch.float16)
m3 = torch.rand((hidden_dim, channels)).to(torch.float16)

reference = torch.addmm(m1, m2, m3, alpha=alpha, beta=beta).numpy()

model = NNFactory()
par1 = model.parameter(m1.shape, np.float16)
par2 = model.parameter(m2.shape, np.float16)
par3 = model.parameter(m3.shape, np.float16)
out = torch.addmm(par1, par2, par3, alpha=alpha, beta=beta)
model.compile()

assert out.shape == list(reference.shape)

result = model(m1, m2, m3, alpha=alpha, beta=beta).detach().numpy()

assert 1 - r2_score(reference.flatten(), result.flatten()) < 0.01
