forked from huoyijie/AdvancedEAST
-
Notifications
You must be signed in to change notification settings - Fork 1
/
nms.py
86 lines (75 loc) · 2.95 KB
/
nms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# coding=utf-8
import numpy as np
import cfg
def should_merge(region, i, j):
    """Tell whether pixel (i, j) should join ``region``.

    A pixel joins a region exactly when its left-hand neighbour, the pixel
    at (i, j - 1), is already a member of that region.

    Args:
        region: set of (row, col) tuples.
        i: row index of the candidate pixel.
        j: column index of the candidate pixel.

    Returns:
        True if (i, j - 1) is in ``region``, otherwise False.
    """
    left_pixel = (i, j - 1)
    return left_pixel in region
def region_neighbor(region_set):
    """Return the set of pixels bordering ``region_set`` from the next row.

    The result contains every region pixel shifted down by one row
    (row + 1, same column). It also contains two extra pixels, both located
    one row past the region's minimum row index: one just left of the
    region's smallest column and one just right of its largest column.

    Args:
        region_set: non-empty set of (row, col) tuples. An empty set makes
            the min/max reductions fail.

    Returns:
        A set of (row, col) tuples whose entries are NumPy integer scalars.
    """
    pixels = np.array(list(region_set))
    rows = pixels[:, 0]
    cols = pixels[:, 1]
    next_row = rows.min() + 1
    result = set(zip(rows + 1, cols))
    # Extend the band diagonally at both horizontal ends of the region.
    result |= {(next_row, cols.min() - 1), (next_row, cols.max() + 1)}
    return result
def region_group(region_list):
    """Partition region indices into groups of mutually connected regions.

    Indices are consumed in ascending order. Each still-unassigned index
    seeds a new group, which ``rec_region_merge`` grows; that call also
    removes every absorbed index from the pending list.

    Args:
        region_list: list of pixel sets.

    Returns:
        A list of groups, each a list of indices into ``region_list``.
    """
    pending = list(range(len(region_list)))
    groups = []
    while pending:
        seed = pending.pop(0)
        if pending:
            groups.append(rec_region_merge(region_list, seed, pending))
        else:
            # Nothing left to merge with: the seed forms its own group.
            groups.append([seed])
    return groups
def rec_region_merge(region_list, m, S):
    """Collect region ``m`` and every region transitively adjacent to it.

    Two regions are adjacent when either one's ``region_neighbor`` band
    intersects the other region. Every absorbed index is removed from ``S``
    in place, so the caller sees only the indices that are still unassigned.

    An explicit stack replaces the original recursion. Long chains of
    adjacent regions therefore cannot hit Python's recursion limit. The
    visiting order is unchanged: a depth-first pre-order in which a region's
    adjacent regions are all claimed from ``S`` before the first of them is
    expanded.

    Args:
        region_list: list of pixel sets.
        m: index of the seed region.
        S: list of candidate region indices. It is mutated in place.

    Returns:
        A list of region indices, starting with ``m``.
    """
    rows = []
    stack = [m]
    while stack:
        current = stack.pop()
        rows.append(current)
        if not S:
            # No candidates left, so skip the neighbour computation.
            continue
        current_region = region_list[current]
        # Computed once per region instead of once per candidate.
        current_band = region_neighbor(current_region)
        adjacent = [
            n for n in S
            if not current_band.isdisjoint(region_list[n])
            or not region_neighbor(region_list[n]).isdisjoint(current_region)
        ]
        # Region `current` and region n touch: claim n for this group.
        for n in adjacent:
            S.remove(n)
        # Push in reverse so adjacent[0] is expanded next. This matches the
        # order of the original recursive implementation.
        stack.extend(reversed(adjacent))
    return rows
def nms(predict, activation_pixels, threshold=cfg.side_vertex_pixel_threshold):
    """Group activated pixels into text regions and average their vertex votes.

    Each activated pixel joins every existing region that already holds its
    left neighbour (i, j - 1). If no region holds it, the pixel starts a new
    region. For row-major ordered input (such as ``np.where`` output) each
    region is therefore a horizontal run. Regions are then grouped with
    ``region_group``.

    Within each group, a pixel votes only when both conditions hold:

    * its channel-1 value is at least ``threshold``;
    * its channel-2 value lies outside
      ``[cfg.trunc_threshold, 1 - cfg.trunc_threshold)``.

    The rounded channel-2 value selects the vertex pair ``ith`` (rows
    ``2*ith:2*ith+2`` of the quad). The pixel adds its score-weighted
    position (pixel centre plus channels 3:7 reshaped to (2, 2)) to that pair.

    Args:
        predict: per-pixel prediction array indexed as ``predict[i, j, c]``.
            Channels 1, 2 and 3:7 are read (presumably an (H, W, 7) network
            output; confirm against the caller).
        activation_pixels: pair ``(rows, cols)`` of equal-length index
            sequences (presumably the result of ``np.where``).
        threshold: minimum channel-1 score for a pixel to vote.

    Returns:
        ``(score_list, quad_list)`` with shapes ``(G, 4)`` and ``(G, 4, 2)``
        for ``G`` groups. ``score_list`` is the accumulated vote weight per
        vertex. ``quad_list`` is the score-weighted mean vertex coordinates
        (divided by weight + ``cfg.epsilon``).
    """
    region_list = []
    # Maps each pixel to the indices of every region that contains it. This
    # turns the left-neighbour test into one dict lookup instead of a scan
    # over all regions, and yields exactly the same regions and indices.
    owners = {}
    for i, j in zip(activation_pixels[0], activation_pixels[1]):
        pixel = (i, j)
        left_owners = owners.get((i, j - 1))
        pixel_owners = owners.setdefault(pixel, set())
        if left_owners:
            # FIXME: overlapping text regions. A pixel may border several
            # regions; for now it is merged into all of them.
            for k in left_owners:
                region_list[k].add(pixel)
            pixel_owners.update(left_owners)
        else:
            pixel_owners.add(len(region_list))
            region_list.append({pixel})

    D = region_group(region_list)
    quad_list = np.zeros((len(D), 4, 2))
    score_list = np.zeros((len(D), 4))
    trunc_low = cfg.trunc_threshold
    trunc_high = 1 - cfg.trunc_threshold
    for g_th, group in enumerate(D):
        total_score = np.zeros((4, 2))
        for row in group:
            for i, j in region_list[row]:
                score = predict[i, j, 1]
                if score >= threshold:
                    # Scalar index: converting a 1-element slice (2:3) with
                    # int() is deprecated in NumPy >= 1.25.
                    ith_score = predict[i, j, 2]
                    if not (trunc_low <= ith_score < trunc_high):
                        ith = int(np.around(ith_score))
                        total_score[ith * 2:(ith + 1) * 2] += score
                        # Pixel centre in input-image coordinates.
                        px = (j + 0.5) * cfg.pixel_size
                        py = (i + 0.5) * cfg.pixel_size
                        p_v = [px, py] + np.reshape(predict[i, j, 3:7], (2, 2))
                        quad_list[g_th, ith * 2:(ith + 1) * 2] += score * p_v
        score_list[g_th] = total_score[:, 0]
        quad_list[g_th] /= (total_score + cfg.epsilon)
    return score_list, quad_list