Skip to content

Commit 58a2c2f

Browse files
committed
- Modify some examples.
- Rename the folder `tests` to `examples` to avoid misleading. - Update the `readme.md`
1 parent c6d273a commit 58a2c2f

17 files changed

+179
-199
lines changed

examples/metric_recorder.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2021/1/4
3+
# @Author : Lart Pang
4+
# @GitHub : https://github.com/lartpang
5+
6+
import numpy as np
7+
8+
from py_sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure
9+
10+
11+
def ndarray_to_basetype(data):
12+
"""
13+
将单独的ndarray,或者tuple,list或者dict中的ndarray转化为基本数据类型,
14+
即列表(.tolist())和python标量
15+
"""
16+
17+
def _to_list_or_scalar(item):
18+
listed_item = item.tolist()
19+
if isinstance(listed_item, list) and len(listed_item) == 1:
20+
listed_item = listed_item[0]
21+
return listed_item
22+
23+
if isinstance(data, (tuple, list)):
24+
results = [_to_list_or_scalar(item) for item in data]
25+
elif isinstance(data, dict):
26+
results = {k: _to_list_or_scalar(item) for k, item in data.items()}
27+
else:
28+
assert isinstance(data, np.ndarray)
29+
results = _to_list_or_scalar(data)
30+
return results
31+
32+
33+
class CalTotalMetric(object):
34+
def __init__(self):
35+
"""
36+
用于统计各种指标的类
37+
https://github.com/lartpang/Py-SOD-VOS-EvalToolkit/blob/81ce89da6813fdd3e22e3f20e3a09fe1e4a1a87c/utils/recorders/metric_recorder.py
38+
"""
39+
self.mae = MAE()
40+
self.fm = Fmeasure()
41+
self.sm = Smeasure()
42+
self.em = Emeasure()
43+
self.wfm = WeightedFmeasure()
44+
45+
def step(self, pre: np.ndarray, gt: np.ndarray):
46+
assert pre.shape == gt.shape
47+
assert pre.dtype == np.uint8
48+
assert gt.dtype == np.uint8
49+
50+
self.mae.step(pre, gt)
51+
self.sm.step(pre, gt)
52+
self.fm.step(pre, gt)
53+
self.em.step(pre, gt)
54+
self.wfm.step(pre, gt)
55+
56+
def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
57+
"""
58+
返回指标计算结果:
59+
60+
- 曲线数据(sequential): fm/em/p/r
61+
- 数值指标(numerical): SM/MAE/maxE/avgE/adpE/maxF/avgF/adpF/wFm
62+
"""
63+
fm_info = self.fm.get_results()
64+
fm = fm_info["fm"]
65+
pr = fm_info["pr"]
66+
wfm = self.wfm.get_results()["wfm"]
67+
sm = self.sm.get_results()["sm"]
68+
em = self.em.get_results()["em"]
69+
mae = self.mae.get_results()["mae"]
70+
71+
sequential_results = {
72+
"fm": np.flip(fm["curve"]),
73+
"em": np.flip(em["curve"]),
74+
"p": np.flip(pr["p"]),
75+
"r": np.flip(pr["r"]),
76+
}
77+
numerical_results = {
78+
"SM": sm,
79+
"MAE": mae,
80+
"maxE": em["curve"].max(),
81+
"avgE": em["curve"].mean(),
82+
"adpE": em["adp"],
83+
"maxF": fm["curve"].max(),
84+
"avgF": fm["curve"].mean(),
85+
"adpF": fm["adp"],
86+
"wFm": wfm,
87+
}
88+
if num_bits is not None and isinstance(num_bits, int):
89+
numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
90+
if not return_ndarray:
91+
sequential_results = ndarray_to_basetype(sequential_results)
92+
numerical_results = ndarray_to_basetype(numerical_results)
93+
return {"sequential": sequential_results, "numerical": numerical_results}
94+
95+
96+
if __name__ == "__main__":
97+
data_loader = ...
98+
model = ...
99+
100+
cal_total_seg_metrics = CalTotalMetric()
101+
for batch in data_loader:
102+
seg_preds = model(batch)
103+
for seg_pred in seg_preds:
104+
mask_array = ...
105+
cal_total_seg_metrics.step(seg_pred, mask_array)
106+
fixed_seg_results = cal_total_seg_metrics.get_results()
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

examples/test_metrics.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2020/11/21
3+
# @Author : Lart Pang
4+
# @GitHub : https://github.com/lartpang
5+
6+
import os
7+
8+
import cv2
9+
from tqdm import tqdm
10+
11+
# pip install pysodmetrics
12+
from py_sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure
13+
14+
FM = Fmeasure()
15+
WFM = WeightedFmeasure()
16+
SM = Smeasure()
17+
EM = Emeasure()
18+
MAE = MAE()
19+
20+
data_root = "./test_data"
21+
mask_root = os.path.join(data_root, "masks")
22+
pred_root = os.path.join(data_root, "preds")
23+
mask_name_list = sorted(os.listdir(mask_root))
24+
for mask_name in tqdm(mask_name_list, total=len(mask_name_list)):
25+
mask_path = os.path.join(mask_root, mask_name)
26+
pred_path = os.path.join(pred_root, mask_name)
27+
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
28+
pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
29+
FM.step(pred=pred, gt=mask)
30+
WFM.step(pred=pred, gt=mask)
31+
SM.step(pred=pred, gt=mask)
32+
EM.step(pred=pred, gt=mask)
33+
MAE.step(pred=pred, gt=mask)
34+
35+
fm = FM.get_results()["fm"]
36+
wfm = WFM.get_results()["wfm"]
37+
sm = SM.get_results()["sm"]
38+
em = EM.get_results()["em"]
39+
mae = MAE.get_results()["mae"]
40+
41+
results = {
42+
"Smeasure": sm.round(3),
43+
"wFmeasure": wfm.round(3),
44+
"MAE": mae.round(3),
45+
"adpEm": em["adp"].round(3),
46+
"meanEm": em["curve"].mean().round(3),
47+
"maxEm": em["curve"].max().round(3),
48+
"adpFm": fm["adp"].round(3),
49+
"meanFm": fm["curve"].mean().round(3),
50+
"maxFm": fm["curve"].max().round(3),
51+
}
52+
53+
print(results)

py_sod_metrics/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from py_sod_metrics.sod_metrics import *
2+
3+
__version__ = "1.2.2"

0 commit comments

Comments
 (0)