Skip to content

Commit b228260

Browse files
committed
✨ feat:
1. Update the metrics for binary image. 2. Use unittest library to check the resutls. 3. Update texamples.
1 parent f3756f2 commit b228260

File tree

3 files changed

+378
-271
lines changed

3 files changed

+378
-271
lines changed

examples/metric_recorder.py

Lines changed: 141 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@ def _to_list_or_scalar(item):
3939
}
4040

4141

42-
class CalTotalMetricV1:
42+
class MetricRecorderV1:
4343
def __init__(self):
4444
"""
4545
用于统计各种指标的类
4646
https://github.com/lartpang/Py-SOD-VOS-EvalToolkit/blob/81ce89da6813fdd3e22e3f20e3a09fe1e4a1a87c/utils/recorders/metric_recorder.py
47+
48+
主要应用于旧版本实现中的五个指标,即mae/fm/sm/em/wfm。推荐使用V2版本。
4749
"""
4850
self.mae = INDIVADUAL_METRIC_MAPPING["mae"]()
4951
self.fm = INDIVADUAL_METRIC_MAPPING["fm"]()
@@ -103,46 +105,76 @@ def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
103105

104106

105107
BINARY_CLASSIFICATION_METRIC_MAPPING = {
106-
"fmeasure": {
107-
"handler": py_sod_metrics.FmeasureHandler,
108-
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True, beta=0.3),
109-
},
110-
"precision": {
111-
"handler": py_sod_metrics.PrecisionHandler,
112-
"kwargs": dict(with_dynamic=True, with_adaptive=False, with_binary=False),
113-
},
114-
"recall": {
115-
"handler": py_sod_metrics.RecallHandler,
116-
"kwargs": dict(with_dynamic=True, with_adaptive=False, with_binary=False),
117-
},
118-
"iou": {
119-
"handler": py_sod_metrics.IOUHandler,
120-
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True),
121-
},
122-
"dice": {
123-
"handler": py_sod_metrics.DICEHandler,
124-
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True),
125-
},
126-
"specificity": {
127-
"handler": py_sod_metrics.SpecificityHandler,
128-
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True),
129-
},
130-
"ber": {
131-
"handler": py_sod_metrics.BERHandler,
132-
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True),
133-
}
108+
# 灰度数据指标
109+
"fm": py_sod_metrics.FmeasureHandler(with_adaptive=True, with_dynamic=True, beta=0.3),
110+
"f1": py_sod_metrics.FmeasureHandler(with_adaptive=True, with_dynamic=True, beta=0.1),
111+
"pre": py_sod_metrics.PrecisionHandler(with_adaptive=True, with_dynamic=True),
112+
"rec": py_sod_metrics.RecallHandler(with_adaptive=True, with_dynamic=True),
113+
"iou": py_sod_metrics.IOUHandler(with_adaptive=True, with_dynamic=True),
114+
"dice": py_sod_metrics.DICEHandler(with_adaptive=True, with_dynamic=True),
115+
"spec": py_sod_metrics.SpecificityHandler(with_adaptive=True, with_dynamic=True),
116+
"ber": py_sod_metrics.BERHandler(with_adaptive=True, with_dynamic=True),
117+
# 二值化数据指标的特殊情况一:各个样本独立计算指标后取平均
118+
"sample_bifm": py_sod_metrics.FmeasureHandler(
119+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=0.3
120+
),
121+
"sample_bif1": py_sod_metrics.FmeasureHandler(
122+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=1
123+
),
124+
"sample_bipre": py_sod_metrics.PrecisionHandler(
125+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
126+
),
127+
"sample_birec": py_sod_metrics.RecallHandler(
128+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
129+
),
130+
"sample_biiou": py_sod_metrics.IOUHandler(
131+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
132+
),
133+
"sample_bidice": py_sod_metrics.DICEHandler(
134+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
135+
),
136+
"sample_bispec": py_sod_metrics.SpecificityHandler(
137+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
138+
),
139+
"sample_biber": py_sod_metrics.BERHandler(
140+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
141+
),
142+
# 二值化数据指标的特殊情况二:汇总所有样本的tp、fp、tn、fn后整体计算指标
143+
"overall_bifm": py_sod_metrics.FmeasureHandler(
144+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=0.3
145+
),
146+
"overall_bif1": py_sod_metrics.FmeasureHandler(
147+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=1
148+
),
149+
"overall_bipre": py_sod_metrics.PrecisionHandler(
150+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
151+
),
152+
"overall_birec": py_sod_metrics.RecallHandler(
153+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
154+
),
155+
"overall_biiou": py_sod_metrics.IOUHandler(
156+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
157+
),
158+
"overall_bidice": py_sod_metrics.DICEHandler(
159+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
160+
),
161+
"overall_bispec": py_sod_metrics.SpecificityHandler(
162+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
163+
),
164+
"overall_biber": py_sod_metrics.BERHandler(
165+
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
166+
),
134167
}
135168

136169

137-
class CalTotalMetricV2:
138-
# 'fm' is replaced by 'fmeasure' in BINARY_CLASSIFICATION_METRIC_MAPPING
170+
class MetricRecorderV2:
139171
suppoted_metrics = ["mae", "em", "sm", "wfm"] + sorted(
140-
BINARY_CLASSIFICATION_METRIC_MAPPING.keys()
172+
[k for k in BINARY_CLASSIFICATION_METRIC_MAPPING.keys() if not k.startswith(('sample_', 'overall_'))]
141173
)
142174

143-
def __init__(self, metric_names=None):
175+
def __init__(self, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
144176
"""
145-
用于统计各种指标的类
177+
用于统计各种指标的类,支持更多的指标,更好的兼容性。
146178
"""
147179
if not metric_names:
148180
metric_names = self.suppoted_metrics
@@ -161,24 +193,18 @@ def __init__(self, metric_names=None):
161193
has_existed = True
162194
metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING[metric_name]
163195
self.metric_objs["fmeasurev2"].add_handler(
164-
# instantiate inside the class instead of outside the class
165-
metric_handler["handler"](**metric_handler["kwargs"])
196+
handler_name=metric_name,
197+
metric_handler=metric_handler["handler"](**metric_handler["kwargs"]),
166198
)
167199

168-
def update(self, pre: np.ndarray, gt: np.ndarray):
200+
def step(self, pre: np.ndarray, gt: np.ndarray):
169201
assert pre.shape == gt.shape, (pre.shape, gt.shape)
170202
assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype)
171203

172204
for m_obj in self.metric_objs.values():
173205
m_obj.step(pre, gt)
174206

175-
def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
176-
"""
177-
返回指标计算结果:
178-
179-
- 曲线数据(sequential)
180-
- 数值指标(numerical)
181-
"""
207+
def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
182208
sequential_results = {}
183209
numerical_results = {}
184210
for m_name, m_obj in self.metric_objs.items():
@@ -187,15 +213,12 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
187213
for _name, results in info.items():
188214
dynamic_results = results.get("dynamic")
189215
adaptive_results = results.get("adaptive")
190-
binary_results = results.get('binary')
191216
if dynamic_results is not None:
192217
sequential_results[_name] = np.flip(dynamic_results)
193218
numerical_results[f"max{_name}"] = dynamic_results.max()
194219
numerical_results[f"avg{_name}"] = dynamic_results.mean()
195220
if adaptive_results is not None:
196221
numerical_results[f"adp{_name}"] = adaptive_results
197-
if binary_results is not None:
198-
numerical_results[f"bi{_name}"] = binary_results
199222
else:
200223
results = info[m_name]
201224
if m_name in ("wfm", "sm", "mae"):
@@ -204,9 +227,9 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
204227
sequential_results[m_name] = np.flip(results["curve"])
205228
numerical_results.update(
206229
{
207-
"maxe": results["curve"].max(),
208-
"avge": results["curve"].mean(),
209-
"adpe": results["adp"],
230+
"maxem": results["curve"].max(),
231+
"avgem": results["curve"].mean(),
232+
"adpem": results["adp"],
210233
}
211234
)
212235
else:
@@ -219,15 +242,81 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
219242
numerical_results = ndarray_to_basetype(numerical_results)
220243
return {"sequential": sequential_results, "numerical": numerical_results}
221244

245+
def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
246+
return self.get_all_results(num_bits=num_bits, return_ndarray=return_ndarray)["numerical"]
247+
248+
249+
class BinaryMetricRecorder:
250+
suppoted_metrics = ["mae", "sm", "wfm"] + sorted(
251+
[k for k in BINARY_CLASSIFICATION_METRIC_MAPPING.keys() if k.startswith(('sample_', 'overall_'))]
252+
)
253+
254+
def __init__(self, metric_names=("bif1", "biprecision", "birecall", "biiou")):
255+
"""
256+
用于统计各种指标的类,主要适用于对单通道灰度图计算二值图像的指标。
257+
"""
258+
if not metric_names:
259+
metric_names = self.suppoted_metrics
260+
assert all(
261+
[m in self.suppoted_metrics for m in metric_names]
262+
), f"Only support: {self.suppoted_metrics}"
263+
264+
self.metric_objs = {}
265+
has_existed = False
266+
for metric_name in metric_names:
267+
if metric_name in INDIVADUAL_METRIC_MAPPING:
268+
self.metric_objs[metric_name] = INDIVADUAL_METRIC_MAPPING[metric_name]()
269+
else: # metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING
270+
if not has_existed: # only init once
271+
self.metric_objs["fmeasurev2"] = py_sod_metrics.FmeasureV2()
272+
has_existed = True
273+
metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING[metric_name]
274+
self.metric_objs["fmeasurev2"].add_handler(
275+
handler_name=metric_name,
276+
metric_handler=metric_handler["handler"](**metric_handler["kwargs"]),
277+
)
278+
279+
def step(self, pre: np.ndarray, gt: np.ndarray):
280+
assert pre.shape == gt.shape, (pre.shape, gt.shape)
281+
assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype)
282+
283+
for m_obj in self.metric_objs.values():
284+
m_obj.step(pre, gt)
285+
286+
def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
287+
numerical_results = {}
288+
for m_name, m_obj in self.metric_objs.items():
289+
info = m_obj.get_results()
290+
if m_name == "fmeasurev2":
291+
for _name, results in info.items():
292+
binary_results = results.get("binary")
293+
if binary_results is not None:
294+
numerical_results[_name] = binary_results
295+
else:
296+
results = info[m_name]
297+
if m_name in ("mae", "sm", "wfm"):
298+
numerical_results[m_name] = results
299+
else:
300+
raise NotImplementedError(m_name)
301+
302+
if num_bits is not None and isinstance(num_bits, int):
303+
numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
304+
if not return_ndarray:
305+
numerical_results = ndarray_to_basetype(numerical_results)
306+
return {"numerical": numerical_results}
307+
308+
def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
309+
return self.get_all_results(num_bits=num_bits, return_ndarray=return_ndarray)["numerical"]
310+
222311

223312
if __name__ == "__main__":
224313
data_loader = ...
225314
model = ...
226315

227-
cal_total_seg_metrics = CalTotalMetricV1()
316+
cal_total_seg_metrics = MetricRecorderV2()
228317
for batch in data_loader:
229318
seg_preds = model(batch)
230319
for seg_pred in seg_preds:
231320
mask_array = ...
232321
cal_total_seg_metrics.step(seg_pred, mask_array)
233-
fixed_seg_results = cal_total_seg_metrics.get_results()
322+
fixed_seg_results = cal_total_seg_metrics.show()

0 commit comments

Comments
 (0)