Skip to content

Commit a826e2b

Browse files
committed
add data_format selection support to ocr
1 parent db0ad17 commit a826e2b

17 files changed

+443
-93
lines changed

configs/cls/cls_mv3.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ Architecture:
2323
name: MobileNetV3
2424
scale: 0.35
2525
model_name: small
26+
data_format: NHWC
2627
Neck:
2728
Head:
2829
name: ClsHead
2930
class_dim: 2
31+
data_format: NHWC
3032

3133
Loss:
3234
name: ClsLoss

configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,16 @@ Architecture:
3535
scale: 0.5
3636
model_name: large
3737
disable_se: true
38+
data_format: NHWC
3839
Neck:
3940
name: RSEFPN
4041
out_channels: 96
4142
shortcut: True
43+
data_format: NHWC
4244
Head:
4345
name: DBHead
4446
k: 50
47+
data_format: NHWC
4548
Student2:
4649
pretrained:
4750
model_type: det
@@ -52,13 +55,16 @@ Architecture:
5255
scale: 0.5
5356
model_name: large
5457
disable_se: true
58+
data_format: NHWC
5559
Neck:
5660
name: RSEFPN
5761
out_channels: 96
5862
shortcut: True
63+
data_format: NHWC
5964
Head:
6065
name: DBHead
6166
k: 50
67+
data_format: NHWC
6268
Teacher:
6369
freeze_params: true
6470
return_all_feats: false
@@ -68,13 +74,16 @@ Architecture:
6874
name: ResNet_vd
6975
in_channels: 3
7076
layers: 50
77+
data_format: NHWC
7178
Neck:
7279
name: LKPAN
7380
out_channels: 256
81+
data_format: NHWC
7482
Head:
7583
name: DBHead
7684
kernel_list: [7,2,2]
7785
k: 50
86+
data_format: NHWC
7887

7988
Loss:
8089
name: CombinedLoss

configs/det/det_mv3_db.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ Architecture:
2525
name: MobileNetV3
2626
scale: 0.5
2727
model_name: large
28+
data_format: NHWC
2829
Neck:
2930
name: DBFPN
3031
out_channels: 256
32+
data_format: NHWC
3133
Head:
3234
name: DBHead
3335
k: 50
36+
data_format: NHWC
3437

3538
Loss:
3639
name: DBLoss
@@ -64,7 +67,7 @@ Metric:
6467
Train:
6568
dataset:
6669
name: SimpleDataSet
67-
data_dir: ./train_data/icdar2015/text_localization/
70+
data_dir: ./
6871
label_file_list:
6972
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
7073
ratio_list: [1.0]
@@ -107,7 +110,7 @@ Train:
107110
Eval:
108111
dataset:
109112
name: SimpleDataSet
110-
data_dir: ./train_data/icdar2015/text_localization/
113+
data_dir: ./
111114
label_file_list:
112115
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
113116
transforms:

configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ Architecture:
4444
last_conv_stride: [1, 2]
4545
last_pool_type: avg
4646
last_pool_kernel_size: [2, 2]
47+
data_format: 'NHWC'
4748
Head:
4849
name: MultiHead
4950
head_list:
@@ -59,6 +60,7 @@ Architecture:
5960
- SARHead:
6061
enc_dim: 512
6162
max_text_length: *max_text_length
63+
data_format: 'NHWC'
6264

6365
Loss:
6466
name: MultiLoss

configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml

100644100755
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Architecture:
5454
last_conv_stride: [1, 2]
5555
last_pool_type: avg
5656
last_pool_kernel_size: [2, 2]
57+
data_format: 'NHWC'
5758
Head:
5859
name: MultiHead
5960
head_list:
@@ -69,6 +70,7 @@ Architecture:
6970
- SARHead:
7071
enc_dim: 512
7172
max_text_length: *max_text_length
73+
data_format: 'NHWC'
7274
Student:
7375
pretrained:
7476
freeze_params: false
@@ -82,6 +84,7 @@ Architecture:
8284
last_conv_stride: [1, 2]
8385
last_pool_type: avg
8486
last_pool_kernel_size: [2, 2]
87+
data_format: 'NHWC'
8588
Head:
8689
name: MultiHead
8790
head_list:
@@ -97,6 +100,7 @@ Architecture:
97100
- SARHead:
98101
enc_dim: 512
99102
max_text_length: *max_text_length
103+
data_format: 'NHWC'
100104
Loss:
101105
name: CombinedLoss
102106
loss_config_list:

ppocr/modeling/backbones/det_mobilenet_v3.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ def make_divisible(v, divisor=8, min_value=None):
3535

3636
class MobileNetV3(nn.Layer):
3737
def __init__(
38-
self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
38+
self,
39+
in_channels=3,
40+
model_name="large",
41+
scale=0.5,
42+
disable_se=False,
43+
data_format="NCHW",
44+
**kwargs,
3945
):
4046
"""
4147
the MobilenetV3 backbone network for detection module.
@@ -46,6 +52,7 @@ def __init__(
4652

4753
self.disable_se = disable_se
4854

55+
self.nchw = data_format == "NCHW"
4956
if model_name == "large":
5057
cfg = [
5158
# k, exp, c, se, nl, s,
@@ -102,6 +109,7 @@ def __init__(
102109
groups=1,
103110
if_act=True,
104111
act="hardswish",
112+
data_format=data_format,
105113
)
106114

107115
self.stages = []
@@ -125,6 +133,7 @@ def __init__(
125133
stride=s,
126134
use_se=se,
127135
act=nl,
136+
data_format=data_format,
128137
)
129138
)
130139
inplanes = make_divisible(scale * c)
@@ -139,6 +148,7 @@ def __init__(
139148
groups=1,
140149
if_act=True,
141150
act="hardswish",
151+
data_format=data_format,
142152
)
143153
)
144154
self.stages.append(nn.Sequential(*block_list))
@@ -147,6 +157,8 @@ def __init__(
147157
self.add_sublayer(sublayer=stage, name="stage{}".format(i))
148158

149159
def forward(self, x):
160+
if not self.nchw:
161+
x = x.transpose([0, 2, 3, 1])
150162
x = self.conv(x)
151163
out_list = []
152164
for stage in self.stages:
@@ -166,6 +178,7 @@ def __init__(
166178
groups=1,
167179
if_act=True,
168180
act=None,
181+
data_format="NCHW",
169182
):
170183
super(ConvBNLayer, self).__init__()
171184
self.if_act = if_act
@@ -178,9 +191,12 @@ def __init__(
178191
padding=padding,
179192
groups=groups,
180193
bias_attr=False,
194+
data_format=data_format,
181195
)
182196

183-
self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
197+
self.bn = nn.BatchNorm(
198+
num_channels=out_channels, act=None, data_layout=data_format
199+
)
184200

185201
def forward(self, x):
186202
x = self.conv(x)
@@ -210,6 +226,7 @@ def __init__(
210226
stride,
211227
use_se,
212228
act=None,
229+
data_format="NCHW",
213230
):
214231
super(ResidualUnit, self).__init__()
215232
self.if_shortcut = stride == 1 and in_channels == out_channels
@@ -223,6 +240,7 @@ def __init__(
223240
padding=0,
224241
if_act=True,
225242
act=act,
243+
data_format=data_format,
226244
)
227245
self.bottleneck_conv = ConvBNLayer(
228246
in_channels=mid_channels,
@@ -233,9 +251,10 @@ def __init__(
233251
groups=mid_channels,
234252
if_act=True,
235253
act=act,
254+
data_format=data_format,
236255
)
237256
if self.if_se:
238-
self.mid_se = SEModule(mid_channels)
257+
self.mid_se = SEModule(mid_channels, data_format=data_format)
239258
self.linear_conv = ConvBNLayer(
240259
in_channels=mid_channels,
241260
out_channels=out_channels,
@@ -244,6 +263,7 @@ def __init__(
244263
padding=0,
245264
if_act=False,
246265
act=None,
266+
data_format=data_format,
247267
)
248268

249269
def forward(self, inputs):
@@ -258,22 +278,24 @@ def forward(self, inputs):
258278

259279

260280
class SEModule(nn.Layer):
261-
def __init__(self, in_channels, reduction=4):
281+
def __init__(self, in_channels, reduction=4, data_format="NCHW"):
262282
super(SEModule, self).__init__()
263-
self.avg_pool = nn.AdaptiveAvgPool2D(1)
283+
self.avg_pool = nn.AdaptiveAvgPool2D(1, data_format=data_format)
264284
self.conv1 = nn.Conv2D(
265285
in_channels=in_channels,
266286
out_channels=in_channels // reduction,
267287
kernel_size=1,
268288
stride=1,
269289
padding=0,
290+
data_format=data_format,
270291
)
271292
self.conv2 = nn.Conv2D(
272293
in_channels=in_channels // reduction,
273294
out_channels=in_channels,
274295
kernel_size=1,
275296
stride=1,
276297
padding=0,
298+
data_format=data_format,
277299
)
278300

279301
def forward(self, inputs):

0 commit comments

Comments
 (0)