Skip to content

Commit

Permalink
Merge branch 'master' into fix_instance_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
masakistan authored Oct 15, 2020
2 parents cea3239 + 58ace0a commit 24877e8
Show file tree
Hide file tree
Showing 19 changed files with 848 additions and 497 deletions.
18 changes: 9 additions & 9 deletions doc/support_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ Notes:
|Celu|-|-|-|-|-|-|-|-|-|-|-|**12**|12|Celu|
|Clip|**1**|1|1|1|1|**6**|6|6|6|6|**11**|**12**|**13**|Clip|
|Compress|-|-|-|-|-|-|-|-|**9**|9|**11**|11|11|Compress|
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|**11**|11|**13**:small_red_triangle:|Concat|
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|**11**|11|**13**|Concat|
|ConcatFromSequence|-|-|-|-|-|-|-|-|-|-|**11**:small_orange_diamond:|11:small_orange_diamond:|11:small_orange_diamond:|ConcatFromSequence|
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|**11**|**12**|**13**:small_red_triangle:|Constant|
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|**11**|**12**|**13**|Constant|
|ConstantOfShape|-|-|-|-|-|-|-|-|**9**|9|9|9|9|ConstantOfShape|
|Conv|**1**|1|1|1|1|1|1|1|1|1|**11**|11|11|Conv|
|ConvInteger|-|-|-|-|-|-|-|-|-|**10**|10|10|10|ConvInteger|
Expand Down Expand Up @@ -68,26 +68,26 @@ Notes:
|GlobalAveragePool|**1**|1|1|1|1|1|1|1|1|1|1|1|1|GlobalAveragePool|
|GlobalLpPool|**1**|**2**|2|2|2|2|2|2|2|2|2|2|2|GlobalLpPool|
|GlobalMaxPool|**1**|1|1|1|1|1|1|1|1|1|1|1|1|GlobalMaxPool|
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**:small_red_triangle:|Greater|
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**|Greater|
|GreaterOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**|12|GreaterOrEqual|
|HardSigmoid|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|HardSigmoid|
|Hardmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Hardmax|
|Identity|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**:small_red_triangle:|Identity|
|If|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|If|
|If|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|If|
|InstanceNormalization|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|InstanceNormalization|
|IsInf|-|-|-|-|-|-|-|-|-|**10**|10|10|10|IsInf|
|IsNaN|-|-|-|-|-|-|-|-|**9**|9|9|9|**13**:small_red_triangle:|IsNaN|
|LRN|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**|LRN|
|LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|LSTM|
|LeakyRelu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|LeakyRelu|
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**:small_red_triangle:|Less|
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**|Less|
|LessOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**|12|LessOrEqual|
|Log|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Log|
|LogSoftmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|LogSoftmax|
|Loop|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Loop|
|Loop|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|Loop|
|LpNormalization|**1**|1|1|1|1|1|1|1|1|1|1|1|1|LpNormalization|
|LpPool|**1**|**2**|2|2|2|2|2|2|2|2|**11**|11|11|LpPool|
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|9|**13**:small_red_triangle:|MatMul|
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|9|**13**|MatMul|
|MatMulInteger|-|-|-|-|-|-|-|-|-|**10**|10|10|10|MatMulInteger|
|Max|**1**|1|1|1|1|**6**|6|**8**|8|8|8|**12**|**13**|Max|
|MaxPool|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**8**:small_orange_diamond:|8:small_orange_diamond:|**10**:small_orange_diamond:|**11**:small_orange_diamond:|**12**:small_orange_diamond:|12:small_orange_diamond:|MaxPool|
Expand Down Expand Up @@ -164,7 +164,7 @@ Notes:
|Sqrt|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Sqrt|
|Squeeze|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Squeeze|
|StringNormalizer|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|10:small_red_triangle:|10:small_red_triangle:|10:small_red_triangle:|StringNormalizer|
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**:small_red_triangle:|Sub|
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**|Sub|
|Sum|**1**|1|1|1|1|**6**|6|**8**|8|8|8|8|**13**:small_red_triangle:|Sum|
|Tan|-|-|-|-|-|-|**7**|7|7|7|7|7|7|Tan|
|Tanh|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Tanh|
Expand All @@ -179,7 +179,7 @@ Notes:
|Where|-|-|-|-|-|-|-|-|**9**|9|9|9|9|Where|
|Xor|**1**|1|1|1|1|1|**7**|7|7|7|7|7|7|Xor|

ONNX-TF Supported Operators / ONNX Operators: 103 / 162
ONNX-TF Supported Operators / ONNX Operators: 105 / 162

Notes:
1. Cast: Cast string to data types other than float32/float64/int32/int64 is not supported in Tensorflow
Expand Down
45 changes: 9 additions & 36 deletions onnx_tf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,12 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
# initialized: A list of names of the initialized tensors.

if graph_def.initializer:
input_dict_items = cls._onnx_initializer_to_input_dict_items(
graph_def.initializer)
initialized = {init.name for init in graph_def.initializer}
else:
input_dict_items = []
initialized = set()

input_dict = dict()

module = BackendTFModule(handlers, opset, strict, graph_def, cls)
signatures = dict()
for value_info in graph_def.input:
Expand All @@ -146,7 +145,7 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
shape=shape
) if value_info.name not in input_tensor_dict else input_tensor_dict[
value_info.name]
input_dict_items.append((value_info_name, x))
input_dict[value_info.name] = x

tf_rep = TensorflowRep()
tf_rep.inputs = [
Expand All @@ -159,8 +158,7 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
tf_rep.tf_module = module
tf_rep.signatures = signatures
tf_rep.tensor_dict = module.gen_tensor_dict(
input_dict_items) if gen_tensor_dict else None

input_dict) if gen_tensor_dict else None
return tf_rep

@classmethod
Expand Down Expand Up @@ -288,55 +286,30 @@ def supports_device(cls, device):
@classmethod
def onnx_graph_to_tensorflow_ops(cls,
subgraph,
input_values,
tensor_dict,
opset=None,
strict=True):
"""
Converts ONNX graph to Tensorflow operations
Args:
subgraph: the ONNX graph to be converted
input_values: dictionary with values/tensors to initialize
the subgraph inputs. if the subgraph.input
are send in as parameters then it is required,
otherwise this can be empty dictionary
tensor_dict: the dictionary that contain values for all the
node.inputs in the subgraph that are not defined
in the subgraph or input_values.
subgraph: the ONNX graph to be converted.
tensor_dict: tensor dict of the subgraph.
opset: opset version of the operator set.
strict: whether to enforce semantic equivalence between the
original model and the converted tensorflow model,
defaults to True (yes, enforce semantic equivalence).
Returns:
array of Tensorflow Tensors
"""
# get the subgraph.input from input_values
subgraph_tensor_dict = input_values.copy()
# get the rest of the subgraph input from tensor_dict
for i in subgraph.input:
if i.name not in subgraph_tensor_dict.keys():
subgraph_tensor_dict[i.name] = tensor_dict[i.name]
# get the required initializer constant node(s) for the subgraph
# Need to get the initializer constant nodes from tensor_dict here
# because input from initializer will not be send in as inputs
# to the subgraph and those nodes are not in the subgraph
nodes_outputs = []
for node in subgraph.node:
for o_name in node.output:
nodes_outputs.append(o_name)
for node in subgraph.node:
for i_name in node.input:
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
):
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
onnx_node = OnnxNode(node)
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
subgraph_tensor_dict,
tensor_dict,
opset=opset,
strict=strict)
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
subgraph_tensor_dict.update(curr_node_output_map)
return subgraph_tensor_dict
tensor_dict.update(curr_node_output_map)
return tensor_dict

@classmethod
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):
Expand Down
38 changes: 27 additions & 11 deletions onnx_tf/backend_tf_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,32 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
self.backend = backend
self.outputs = []

# get initializer from the main graph and all subgraphs in loop or if or scan
# into tensor_dict
def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
if graph.initializer:
graph_tensor_dict.update(
self.backend._onnx_initializer_to_input_dict_items(graph.initializer))
for node in graph.node:
if node.op_type in ['Loop', 'Scan']:
onnx_node = OnnxNode(node)
body = onnx_node.attrs["body"]
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
body, graph_tensor_dict)
elif node.op_type == 'If':
onnx_node = OnnxNode(node)
then_branch = onnx_node.attrs['then_branch']
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
then_branch, graph_tensor_dict)
else_branch = onnx_node.attrs['else_branch']
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
else_branch, graph_tensor_dict)
return graph_tensor_dict

@tf.function
def gen_tensor_dict(self, input_dict_items):
tensor_dict = dict(input_dict_items)
def gen_tensor_dict(self, input_dict):
tensor_dict = self._get_initializer_from_graph_and_subgraphs(
self.graph_def, dict(input_dict))

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
Expand All @@ -31,15 +54,8 @@ def gen_tensor_dict(self, input_dict_items):

@tf.function
def __call__(self, **kwargs):
tensor_dict = kwargs

if self.graph_def.initializer:
input_dict_items = self.backend._onnx_initializer_to_input_dict_items(
self.graph_def.initializer)
else:
input_dict_items = []

tensor_dict.update(input_dict_items)
tensor_dict = self._get_initializer_from_graph_and_subgraphs(
self.graph_def, kwargs)

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
Expand Down
4 changes: 4 additions & 0 deletions onnx_tf/handlers/backend/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ def version_4(cls, node, **kwargs):
@classmethod
def version_11(cls, node, **kwargs):
return cls._common(node, **kwargs)

@classmethod
def version_13(cls, node, **kwargs):
return cls._common(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/backend/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,7 @@ def version_12(cls, node, **kwargs):
inputs=[value],
attrs={"dtype": dtype})
]

@classmethod
def version_13(cls, node, **kwargs):
return cls.version_12(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/backend/greater.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ def version_7(cls, node, **kwargs):
@classmethod
def version_9(cls, node, **kwargs):
return [cls.make_tensor_from_onnx_node(node, **kwargs)]

@classmethod
def version_13(cls, node, **kwargs):
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
15 changes: 8 additions & 7 deletions onnx_tf/handlers/backend/if.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,19 @@ def _common(cls, node, **kwargs):
def true_fn():
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
subgraph=then_branch,
input_values={}, # all inputs of then_branch are in tensor_dict
tensor_dict=kwargs["tensor_dict"],
tensor_dict=dict(kwargs["tensor_dict"]),
opset=current_opset)
return [subgraph_tensor_dict[o.name] for o in then_branch.output]

def false_fn():
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
subgraph=else_branch,
input_values={}, # all inputs of else_branch are in tensor_dict
tensor_dict=kwargs["tensor_dict"],
tensor_dict=dict(kwargs["tensor_dict"]),
opset=current_opset)
return [subgraph_tensor_dict[o.name] for o in else_branch.output]

return [
cls.make_tensor_from_onnx_node(node, inputs=[cond, true_fn, false_fn])
]
return cls.make_tensor_from_onnx_node(node,
inputs=[cond, true_fn, false_fn])

@classmethod
def version_1(cls, node, **kwargs):
Expand All @@ -45,3 +42,7 @@ def version_1(cls, node, **kwargs):
@classmethod
def version_11(cls, node, **kwargs):
return cls._common(node, **kwargs)

@classmethod
def version_13(cls, node, **kwargs):
return cls._common(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/backend/less.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ def version_7(cls, node, **kwargs):
@classmethod
def version_9(cls, node, **kwargs):
return [cls.make_tensor_from_onnx_node(node, **kwargs)]

@classmethod
def version_13(cls, node, **kwargs):
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
Loading

0 comments on commit 24877e8

Please sign in to comment.