Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented ONNX MatMul converter. #58

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nngen/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from . import reduce
from . import conv
from . import gemm
from . import matmul
from . import pool
from . import pad
from . import act_func
Expand Down Expand Up @@ -46,6 +47,7 @@
'ArgMin': reduce.ArgMin,
'Conv': conv.Conv,
'Gemm': gemm.Gemm,
'MatMul': matmul.MatMul,
'AveragePool': pool.AveragePool,
'GlobalAveragePool': pool.GlobalAveragePool,
'MaxPool': pool.MaxPool,
Expand Down
131 changes: 131 additions & 0 deletions nngen/onnx/matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import collections

import nngen.storage as storage
import nngen.operator as operator
import nngen.dtype_list as dtype_list

from . import util
from . import flatten
from . import reshape


def MatMul(visitor, node,
           batchnorm_scale=None, batchnorm_bias=None, act_func=None):
    """Convert an ONNX MatMul node into an nngen ``matmul`` operator.

    Args:
        visitor: the ONNX-to-nngen graph visitor. Provides the parsed model
            (``visitor.model``), the consumer map (``visitor.consumers``),
            per-value dtype overrides (``visitor.value_dtypes``), defaults
            (``default_scale_dtype`` etc.), and the variable registry.
        node: the ONNX MatMul node to convert.
        batchnorm_scale: per-channel scale folded in from a downstream
            BatchNormalization node, or None if there is none to fold.
        batchnorm_bias: per-channel bias folded in from a downstream
            BatchNormalization node, or None.
        act_func: activation function to fuse into the operator, or None.

    Returns:
        The ``nngen.operator.matmul`` object representing this node.
    """

    # input, filter
    srcs = []

    for i, src in enumerate(node.input):
        src_node = util.search_node_from_model(visitor.model, src)

        if src_node is None:
            # src is not produced by a node (e.g. an initializer / graph
            # input); fall through to the generic visit below.
            pass
        else:
            # If the activation input comes from a Flatten/Reshape whose
            # only consumer is this MatMul, convert it without the layout
            # transpose so the original layout information is kept for the
            # weight-permutation step further down.
            if (i == 0 and src_node.op_type == 'Flatten' and
                    len(visitor.consumers[src]) == 1):

                src_obj = flatten.Flatten(visitor, src_node, no_transpose=True)
                srcs.append(src_obj)
                continue

            if (i == 0 and src_node.op_type == 'Reshape' and
                    len(visitor.consumers[src]) == 1):

                shape = visitor.visit(src_node.input[1])
                # Only a 2D reshape (batch, features) qualifies for the
                # no-transpose shortcut.
                if len(shape) == 2:
                    src_obj = reshape.Reshape(visitor, src_node, no_transpose=True)
                    srcs.append(src_obj)
                    continue

        src_obj = visitor.visit(src)
        srcs.append(src_obj)

    input = srcs[0]
    filter = srcs[1]
    # Transpose the ONNX weight (in_features, out_features) so the stored
    # value is (out_features, in_features); note the operator is created
    # with transposed_b=True below to match. After the assignment,
    # srcs[1].value is already the transposed array, so filter.shape
    # reflects the transposed layout.
    filter.value = srcs[1].value.T
    filter.shape = srcs[1].value.shape

    # Layout information of the activation before any preceding
    # Flatten/Reshape (e.g. 'NHWC' vs ONNX's 'NCHW').
    orig_shape = input.get_original_shape()
    orig_layout = input.get_original_layout()
    orig_onnx_layout = input.get_original_onnx_layout()

    if orig_layout is None:
        # No layout information available; leave the weight as-is.
        pass
    elif orig_layout == orig_onnx_layout:
        # Layouts already agree; no permutation needed.
        pass
    else:
        # The weight layout of MatMul is identical to nngen.matmul.
        # However, MatMul assumes the values before the Reshape operator
        # have a different layout (most ONNX models use 'NCHW'), while
        # nngen uses its own layout. Thus the flattened input axis of the
        # weight is permuted to match nngen's activation layout:
        # un-flatten the input axis, transpose to the nngen layout, then
        # flatten again.
        shape = ([filter.shape[0]] +
                 [orig_shape[orig_layout.index(s)] for s in orig_onnx_layout[1:]])
        reshape_value = filter.value.reshape(shape)
        perm = [orig_onnx_layout.index(s) for s in orig_layout]
        transpose_value = reshape_value.transpose(perm)
        new_value = transpose_value.reshape([filter.shape[0], -1])
        filter.value = new_value

    # ONNX MatMul itself has no bias input; a third src would come from a
    # fused pattern. NOTE(review): with plain MatMul, len(srcs) == 2, so
    # bias is normally None here — confirm against the fusing caller.
    bias = srcs[2] if len(srcs) > 2 else None

    name = util.get_name(node)

    # Create a scale variable: the folded batch-norm scale if present,
    # otherwise the neutral scalar [1].
    scale_name = '_'.join(['onnx', name, 'matmul.scale'])
    scale_dtype = visitor.default_scale_dtype
    scale_shape = batchnorm_scale.shape if batchnorm_scale is not None else (1,)
    scale = storage.variable(dtype=scale_dtype, shape=scale_shape, name=scale_name)
    scale_value = batchnorm_scale if batchnorm_scale is not None else [1]
    scale.set_value(scale_value)
    visitor.variables[scale_name] = scale

    if bias is None and batchnorm_bias is not None:
        # No existing bias: create one holding the folded batch-norm bias.
        # The division compensates for the scale multiplication applied by
        # the operator (bias is added before scaling).
        bias_name = '_'.join(['onnx', name, 'matmul.bias'])
        bias_dtype = visitor.default_bias_dtype
        bias_shape = batchnorm_bias.shape
        bias = storage.variable(dtype=bias_dtype, shape=bias_shape, name=bias_name)
        bias_value = batchnorm_bias / batchnorm_scale
        bias.set_value(bias_value)
        visitor.variables[bias_name] = bias

    elif bias is not None and batchnorm_bias is not None:
        # Merge the folded batch-norm bias into the existing bias.
        bias.dtype = visitor.default_bias_dtype
        bias_value = batchnorm_bias / batchnorm_scale + bias.value
        bias.set_value(bias_value)

    elif bias is not None:
        bias.dtype = visitor.default_bias_dtype

    # No output right-shift by default; quantization may adjust this later.
    rshift_out = 0

    # Per-node dtype override, falling back to the model-wide default.
    if name in visitor.value_dtypes:
        dtype = visitor.value_dtypes[name]
    else:
        dtype = visitor.default_operator_dtype

    # Widen the accumulator to avoid overflow of the dot-product sum.
    if dtype.width >= 16:
        sum_dtype = dtype_list.dtype_int(dtype.width * 4)
    else:
        sum_dtype = dtype_list.int32

    args = [input, filter]

    kwargs = collections.OrderedDict()
    kwargs['bias'] = bias
    kwargs['scale'] = scale
    kwargs['transposed_a'] = False
    # The weight value was transposed above, so tell matmul to treat the
    # second operand as transposed.
    kwargs['transposed_b'] = True
    kwargs['rshift_out'] = rshift_out
    kwargs['act_func'] = act_func
    kwargs['dtype'] = dtype
    kwargs['sum_dtype'] = sum_dtype
    kwargs['name'] = name

    c = operator.matmul(*args, **kwargs)

    return c
4 changes: 2 additions & 2 deletions nngen/verify/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from . import concat


def normalize(x, y, z, shamt, asymmetric_clip=False,
              dtype=None, sum_dtype=None, name=None, par=1,
              x_dtype=None, y_dtype=None, z_dtype=None, shamt_dtype=None):
    """Verify-mode normalize: compute x * y + z, arithmetic-right-shift by
    shamt, then clip to the output dtype range.

    Args:
        x, y, z: input operands (value, scale, bias).
        shamt: right-shift amount applied after the multiply-add.
        asymmetric_clip: if True, use the full asymmetric two's-complement
            clipping range (e.g. [-128, 127]) instead of a symmetric one.
        dtype, sum_dtype: output and accumulator dtypes (None = defaults).
        name: operator name.
        par: parallelism degree.
        x_dtype, y_dtype, z_dtype, shamt_dtype: input dtypes.

    Returns:
        The result of basic.multiply_add_rshift_clip.
    """
    # Thin alias: normalize is multiply-add-rshift-clip under another name.
    return basic.multiply_add_rshift_clip(x, y, z, shamt, asymmetric_clip,
                                          dtype, sum_dtype, name, par,
                                          x_dtype, y_dtype, z_dtype, shamt_dtype)

Expand Down
31 changes: 31 additions & 0 deletions tests/onnx_matrix_matmul/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Pick the generator script: any .py file except pytest tests and parsetab.py.
TARGET=$(shell ls *.py | grep -v test | grep -v parsetab.py)
ARGS=

PYTHON=python3
#PYTHON=python
#OPT=-m pdb
#OPT=-m cProfile -s time
#OPT=-m cProfile -o profile.rslt
# Verilog simulator passed to pytest via --sim.
SIMTYPE=iverilog

.PHONY: all
all: test

# Run the generator script directly.
.PHONY: run
run:
	$(PYTHON) $(OPT) $(TARGET) $(ARGS)

# Run the pytest-based hardware simulation test.
.PHONY: test
test:
	$(PYTHON) -m pytest -vv --sim $(SIMTYPE)

# Generate Verilog and syntax-check it with Icarus Verilog (no simulation).
.PHONY: check
check:
	$(PYTHON) $(OPT) $(TARGET) $(ARGS) > tmp.v
	iverilog -tnull -Wall tmp.v
	rm -f tmp.v

# Remove all generated artifacts.
.PHONY: clean
clean:
	rm -rf *.pyc __pycache__ parsetab.py .cache *.out *.png *.dot tmp.v uut.vcd
	rm -rf *.onnx
Loading