forked from zhenhuaw-me/tflite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_mobilenet.py
174 lines (119 loc) · 5.51 KB
/
test_mobilenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import tflite
# Example of parsing a TFLite model with the `tflite` Python package.
# With this package, a single `import tflite` is enough to reach every class;
# otherwise, you would need to import each class individually before using it.
def test_mobilenet():
    """Walk through a quantized MobileNet V1 TFLite model using `tflite`.

    Loads ``mobilenet_v1_1.0_224_quant.tflite`` from ``../assets/tests``
    (relative to this file's directory). It then checks, in order:

    - the model;
    - its single subgraph;
    - the first operator (a Conv2D) and its options;
    - that operator's weight tensor;
    - the tensor's quantization parameters;
    - the tensor's raw data buffer.

    Each section doubles as an example of the corresponding `tflite`
    accessor API.

    Raises:
        OSError: if the model file cannot be opened.
        AssertionError: if any checked value differs from the expected
            value for this model file.
    """
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    # Join path components instead of concatenating '/' so the path is
    # built with the platform's separator.
    tflm_dir = os.path.abspath(
        os.path.join(cur_dir, os.pardir, 'assets', 'tests'))
    tflm_name = 'mobilenet_v1_1.0_224_quant.tflite'
    path = os.path.join(tflm_dir, tflm_name)
    with open(path, 'rb') as f:
        model_bytes = f.read()
    # Parse the whole file as a flatbuffer, starting at byte offset 0.
    model = tflite.Model.GetRootAsModel(model_bytes, 0)

    ############# model #####################################################
    # Version of the TFLite schema the model was serialized with.
    assert model.Version() == 3
    # Strings are stored as bytes and need to be decoded.
    # The description is useful when exchanging models.
    assert model.Description().decode('utf-8') == 'TOCO Converted.'
    # Number of distinct operator codes used in this model.
    assert model.OperatorCodesLength() == 5
    # A model may have multiple subgraphs.
    assert model.SubgraphsLength() == 1
    # Number of tensor data buffers.
    assert model.BuffersLength() == 90

    ############# subgraph ##################################################
    # Choose one subgraph.
    graph = model.Subgraphs(0)
    # Subgraph inputs/outputs are stored as indices into the tensor list.
    assert graph.InputsLength() == 1
    assert graph.OutputsLength() == 1
    assert graph.InputsAsNumpy()[0] == 88
    assert graph.OutputsAsNumpy()[0] == 87
    # Every array can be dumped as a NumPy array or accessed element-wise.
    assert graph.Inputs(0) == 88
    assert graph.Outputs(0) == 87
    # The name may be used for debugging or to tell subgraphs apart in
    # models that contain several; this model leaves it unset.
    assert graph.Name() is None
    # Operators in the subgraph.
    assert graph.OperatorsLength() == 31
    # Use the first operator.
    op = graph.Operators(0)

    ############# operator type #############################################
    # The operator type is stored as an index into the model's operator codes.
    op_code = model.OperatorCodes(op.OpcodeIndex())
    # The first operator is a convolution.
    assert op_code.BuiltinCode() == tflite.BuiltinOperator.CONV_2D
    # Custom operators need more handling, which is not covered here.
    assert op_code.BuiltinCode() != tflite.BuiltinOperator.CUSTOM

    ############# the operator ##############################################
    # The first operator is a convolution.
    # Its inputs are: data, weight and bias.
    assert op.InputsLength() == 3
    assert op.OutputsLength() == 1
    # The data input of the first Conv2D is the model's input.
    assert op.Inputs(0) == 88
    assert op.Inputs(0) == graph.Inputs(0)
    # Each operator type has its own dedicated options table.
    assert op.BuiltinOptionsType() == tflite.BuiltinOptions.Conv2DOptions
    op_opt = op.BuiltinOptions()

    ############# operator option ###########################################
    # BuiltinOptions() returns a generic table; re-read it as Conv2DOptions
    # by pointing a typed accessor at the same bytes and position.
    opt = tflite.Conv2DOptions()
    opt.Init(op_opt.Bytes, op_opt.Pos)
    # The options.
    assert opt.Padding() == tflite.Padding.SAME
    assert opt.StrideW() == 2
    assert opt.StrideH() == 2
    assert opt.DilationWFactor() == 1
    assert opt.DilationHFactor() == 1
    # Check the fused activation function type, if any.
    assert opt.FusedActivationFunction() == tflite.ActivationFunctionType.NONE

    ############# tensor ####################################################
    # Check the weight tensor of the first convolution.
    tensor_index = op.Inputs(1)
    # Use `graph.Tensors(index)` to get the tensor object.
    tensor = graph.Tensors(tensor_index)
    # View the shape.
    assert tensor.ShapeLength() == 4
    assert tensor.ShapeAsNumpy()[1] == 3
    # Every array can be dumped as a NumPy array or accessed element-wise.
    assert tensor.Shape(1) == 3
    # The data type is stored as an enum value.
    assert tensor.Type() == tflite.TensorType.UINT8
    # The name is stored as bytes; decode it.
    assert (tensor.Name().decode('utf-8') ==
            'MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/FakeQuantWithMinMaxVars')
    # The tensor's data buffer is also referenced by index.
    assert tensor.Buffer() == 66
    # Quantization parameters are present because this model is quantized.
    assert tensor.Quantization()
    assert not tensor.IsVariable()

    ############# quant #####################################################
    # Quantization parameters of the tensor (present for quantized models).
    quant = tensor.Quantization()
    # Scale and zero point. Exact float equality is intended: these compare
    # the stored values themselves.
    assert quant.ScaleAsNumpy()[0] == 0.02182667888700962
    assert quant.ZeroPointAsNumpy()[0] == 151
    # Min/max are also available.
    assert quant.MinAsNumpy()[0] == -3.265998125076294
    assert quant.MaxAsNumpy()[0] == 2.2779781818389893
    # Every array can be dumped as a NumPy array or accessed element-wise.
    assert quant.Scale(0) == 0.02182667888700962
    assert quant.ZeroPoint(0) == 151
    assert quant.Min(0) == -3.265998125076294
    assert quant.Max(0) == 2.2779781818389893

    ############# memory ####################################################
    # Get the buffer object holding the tensor's data.
    buffer = model.Buffers(tensor.Buffer())
    assert buffer.DataLength() == 864
    assert buffer.DataAsNumpy()[0] == 151
    # Every array can be dumped as a NumPy array or accessed element-wise.
    assert buffer.Data(0) == 151
    # The NumPy view is flat (1-D). It is not reshaped to the tensor's
    # 4-D shape.
    npa = buffer.DataAsNumpy()
    assert npa.shape == (864,)
if __name__ == '__main__':
    # Allow the walkthrough to be executed directly as a script.
    test_mobilenet()