Skip to content

Commit b733575

Browse files
committed
🐞 fix(Smeasure中的质心计算): 原始实现在输入过大时会溢出
原本直接基于np.sum()的实现在输入尺寸过大的时候会出现数值溢出。现在基于np.argwhere()的实现方式则避免了这一问题。 关于质心计算的更多细节可见文档:https://www.yuque.com/lart/blog/gpbigm
1 parent 4aa253a commit b733575

File tree

3 files changed

+37
-31
lines changed

3 files changed

+37
-31
lines changed

examples/test_metrics.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# @GitHub : https://github.com/lartpang
55

66
import os
7+
import sys
8+
from pprint import pprint
79

810
import cv2
9-
from tqdm import tqdm
1011

11-
# pip install pysodmetrics
12+
sys.path.append("..")
1213
from py_sod_metrics import MAE, Emeasure, Fmeasure, Smeasure, WeightedFmeasure
1314

1415
FM = Fmeasure()
@@ -21,7 +22,8 @@
2122
mask_root = os.path.join(data_root, "masks")
2223
pred_root = os.path.join(data_root, "preds")
2324
mask_name_list = sorted(os.listdir(mask_root))
24-
for mask_name in tqdm(mask_name_list, total=len(mask_name_list)):
25+
for i, mask_name in enumerate(mask_name_list):
26+
print(f"[{i}] Processing {mask_name}...")
2527
mask_path = os.path.join(mask_root, mask_name)
2628
pred_path = os.path.join(pred_root, mask_name)
2729
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
@@ -50,24 +52,30 @@
5052
"maxFm": fm["curve"].max(),
5153
}
5254

53-
print(results)
54-
# 'Smeasure': 0.9029763868504661,
55-
# 'wFmeasure': 0.5579812753638986,
56-
# 'MAE': 0.03705558476661653,
57-
# 'adpEm': 0.9408760066970631,
58-
# 'meanEm': 0.9566258293508715,
59-
# 'maxEm': 0.966954482892271,
60-
# 'adpFm': 0.5816750824038355,
61-
# 'meanFm': 0.577051059518767,
62-
# 'maxFm': 0.5886784581120638
55+
default_results = {
56+
"v1_2_3": {
57+
"Smeasure": 0.9029763868504661,
58+
"wFmeasure": 0.5579812753638986,
59+
"MAE": 0.03705558476661653,
60+
"adpEm": 0.9408760066970631,
61+
"meanEm": 0.9566258293508715,
62+
"maxEm": 0.966954482892271,
63+
"adpFm": 0.5816750824038355,
64+
"meanFm": 0.577051059518767,
65+
"maxFm": 0.5886784581120638,
66+
},
67+
"v1_3_0": {
68+
"Smeasure": 0.9029761578759272,
69+
"wFmeasure": 0.5579812753638986,
70+
"MAE": 0.03705558476661653,
71+
"adpEm": 0.9408760066970617,
72+
"meanEm": 0.9566258293508704,
73+
"maxEm": 0.9669544828922699,
74+
"adpFm": 0.5816750824038355,
75+
"meanFm": 0.577051059518767,
76+
"maxFm": 0.5886784581120638,
77+
},
78+
}
6379

64-
# version 1.2.3
65-
# 'Smeasure': 0.9029763868504661,
66-
# 'wFmeasure': 0.5579812753638986,
67-
# 'MAE': 0.03705558476661653,
68-
# 'adpEm': 0.9408760066970631,
69-
# 'meanEm': 0.9566258293508715,
70-
# 'maxEm': 0.966954482892271,
71-
# 'adpFm': 0.5816750824038355,
72-
# 'meanFm': 0.577051059518767,
73-
# 'maxFm': 0.5886784581120638
80+
pprint(results)
81+
pprint({k: default_value - results[k] for k, default_value in default_results["v1_3_0"].items()})

py_sod_metrics/sod_metrics.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -266,19 +266,17 @@ def centroid(self, matrix: np.ndarray) -> tuple:
266266
so there is no need to use the redundant addition operation when dividing the region later,
267267
because the sequence generated by ``1:X`` in matlab will contain ``X``.
268268
269-
:param matrix: a data array
269+
:param matrix: a bool data array
270270
:return: the centroid coordinate
271271
"""
272272
h, w = matrix.shape
273-
if matrix.sum() == 0:
273+
area_object = np.count_nonzero(matrix)
274+
if area_object == 0:
274275
x = np.round(w / 2)
275276
y = np.round(h / 2)
276277
else:
277-
area_object = np.sum(matrix)
278-
row_ids = np.arange(h)
279-
col_ids = np.arange(w)
280-
x = np.round(np.sum(np.sum(matrix, axis=0) * col_ids) / area_object)
281-
y = np.round(np.sum(np.sum(matrix, axis=1) * row_ids) / area_object)
278+
# More details can be found at: https://www.yuque.com/lart/blog/gpbigm
279+
y, x = np.argwhere(matrix).mean(axis=0).round()
282280
return int(x) + 1, int(y) + 1
283281

284282
def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x: int, y: int) -> dict:

version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.3.0
1+
1.3.1

0 commit comments

Comments
 (0)