@@ -39,11 +39,13 @@ def _to_list_or_scalar(item):
39
39
}
40
40
41
41
42
- class CalTotalMetricV1 :
42
+ class MetricRecorderV1 :
43
43
def __init__ (self ):
44
44
"""
45
45
用于统计各种指标的类
46
46
https://github.com/lartpang/Py-SOD-VOS-EvalToolkit/blob/81ce89da6813fdd3e22e3f20e3a09fe1e4a1a87c/utils/recorders/metric_recorder.py
47
+
48
+ 主要应用于旧版本实现中的五个指标,即mae/fm/sm/em/wfm。推荐使用V2版本。
47
49
"""
48
50
self .mae = INDIVADUAL_METRIC_MAPPING ["mae" ]()
49
51
self .fm = INDIVADUAL_METRIC_MAPPING ["fm" ]()
@@ -103,46 +105,76 @@ def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
103
105
104
106
105
107
BINARY_CLASSIFICATION_METRIC_MAPPING = {
106
- "fmeasure" : {
107
- "handler" : py_sod_metrics .FmeasureHandler ,
108
- "kwargs" : dict (with_dynamic = True , with_adaptive = True , with_binary = True , beta = 0.3 ),
109
- },
110
- "precision" : {
111
- "handler" : py_sod_metrics .PrecisionHandler ,
112
- "kwargs" : dict (with_dynamic = True , with_adaptive = False , with_binary = False ),
113
- },
114
- "recall" : {
115
- "handler" : py_sod_metrics .RecallHandler ,
116
- "kwargs" : dict (with_dynamic = True , with_adaptive = False , with_binary = False ),
117
- },
118
- "iou" : {
119
- "handler" : py_sod_metrics .IOUHandler ,
120
- "kwargs" : dict (with_dynamic = True , with_adaptive = True , with_binary = True ),
121
- },
122
- "dice" : {
123
- "handler" : py_sod_metrics .DICEHandler ,
124
- "kwargs" : dict (with_dynamic = True , with_adaptive = True , with_binary = True ),
125
- },
126
- "specificity" : {
127
- "handler" : py_sod_metrics .SpecificityHandler ,
128
- "kwargs" : dict (with_dynamic = True , with_adaptive = True , with_binary = True ),
129
- },
130
- "ber" : {
131
- "handler" : py_sod_metrics .BERHandler ,
132
- "kwargs" : dict (with_dynamic = True , with_adaptive = True , with_binary = True ),
133
- }
108
+ # 灰度数据指标
109
+ "fm" : py_sod_metrics .FmeasureHandler (with_adaptive = True , with_dynamic = True , beta = 0.3 ),
110
+ "f1" : py_sod_metrics .FmeasureHandler (with_adaptive = True , with_dynamic = True , beta = 0.1 ),
111
+ "pre" : py_sod_metrics .PrecisionHandler (with_adaptive = True , with_dynamic = True ),
112
+ "rec" : py_sod_metrics .RecallHandler (with_adaptive = True , with_dynamic = True ),
113
+ "iou" : py_sod_metrics .IOUHandler (with_adaptive = True , with_dynamic = True ),
114
+ "dice" : py_sod_metrics .DICEHandler (with_adaptive = True , with_dynamic = True ),
115
+ "spec" : py_sod_metrics .SpecificityHandler (with_adaptive = True , with_dynamic = True ),
116
+ "ber" : py_sod_metrics .BERHandler (with_adaptive = True , with_dynamic = True ),
117
+ # 二值化数据指标的特殊情况一:各个样本独立计算指标后取平均
118
+ "sample_bifm" : py_sod_metrics .FmeasureHandler (
119
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True , beta = 0.3
120
+ ),
121
+ "sample_bif1" : py_sod_metrics .FmeasureHandler (
122
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True , beta = 1
123
+ ),
124
+ "sample_bipre" : py_sod_metrics .PrecisionHandler (
125
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
126
+ ),
127
+ "sample_birec" : py_sod_metrics .RecallHandler (
128
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
129
+ ),
130
+ "sample_biiou" : py_sod_metrics .IOUHandler (
131
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
132
+ ),
133
+ "sample_bidice" : py_sod_metrics .DICEHandler (
134
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
135
+ ),
136
+ "sample_bispec" : py_sod_metrics .SpecificityHandler (
137
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
138
+ ),
139
+ "sample_biber" : py_sod_metrics .BERHandler (
140
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
141
+ ),
142
+ # 二值化数据指标的特殊情况二:汇总所有样本的tp、fp、tn、fn后整体计算指标
143
+ "overall_bifm" : py_sod_metrics .FmeasureHandler (
144
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True , beta = 0.3
145
+ ),
146
+ "overall_bif1" : py_sod_metrics .FmeasureHandler (
147
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True , beta = 1
148
+ ),
149
+ "overall_bipre" : py_sod_metrics .PrecisionHandler (
150
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
151
+ ),
152
+ "overall_birec" : py_sod_metrics .RecallHandler (
153
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
154
+ ),
155
+ "overall_biiou" : py_sod_metrics .IOUHandler (
156
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
157
+ ),
158
+ "overall_bidice" : py_sod_metrics .DICEHandler (
159
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
160
+ ),
161
+ "overall_bispec" : py_sod_metrics .SpecificityHandler (
162
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
163
+ ),
164
+ "overall_biber" : py_sod_metrics .BERHandler (
165
+ with_adaptive = False , with_dynamic = False , with_binary = True , sample_based = True
166
+ ),
134
167
}
135
168
136
169
137
- class CalTotalMetricV2 :
138
- # 'fm' is replaced by 'fmeasure' in BINARY_CLASSIFICATION_METRIC_MAPPING
170
+ class MetricRecorderV2 :
139
171
suppoted_metrics = ["mae" , "em" , "sm" , "wfm" ] + sorted (
140
- BINARY_CLASSIFICATION_METRIC_MAPPING .keys ()
172
+ [ k for k in BINARY_CLASSIFICATION_METRIC_MAPPING .keys () if not k . startswith (( 'sample_' , 'overall_' ))]
141
173
)
142
174
143
- def __init__ (self , metric_names = None ):
175
+ def __init__ (self , metric_names = ( "sm" , "wfm" , "mae" , "fmeasure" , "em" ) ):
144
176
"""
145
- 用于统计各种指标的类
177
+ 用于统计各种指标的类,支持更多的指标,更好的兼容性。
146
178
"""
147
179
if not metric_names :
148
180
metric_names = self .suppoted_metrics
@@ -161,24 +193,18 @@ def __init__(self, metric_names=None):
161
193
has_existed = True
162
194
metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING [metric_name ]
163
195
self .metric_objs ["fmeasurev2" ].add_handler (
164
- # instantiate inside the class instead of outside the class
165
- metric_handler ["handler" ](** metric_handler ["kwargs" ])
196
+ handler_name = metric_name ,
197
+ metric_handler = metric_handler ["handler" ](** metric_handler ["kwargs" ]),
166
198
)
167
199
168
- def update (self , pre : np .ndarray , gt : np .ndarray ):
200
+ def step (self , pre : np .ndarray , gt : np .ndarray ):
169
201
assert pre .shape == gt .shape , (pre .shape , gt .shape )
170
202
assert pre .dtype == gt .dtype == np .uint8 , (pre .dtype , gt .dtype )
171
203
172
204
for m_obj in self .metric_objs .values ():
173
205
m_obj .step (pre , gt )
174
206
175
- def show (self , num_bits : int = 3 , return_ndarray : bool = False ) -> dict :
176
- """
177
- 返回指标计算结果:
178
-
179
- - 曲线数据(sequential)
180
- - 数值指标(numerical)
181
- """
207
+ def get_all_results (self , num_bits : int = 3 , return_ndarray : bool = False ) -> dict :
182
208
sequential_results = {}
183
209
numerical_results = {}
184
210
for m_name , m_obj in self .metric_objs .items ():
@@ -187,15 +213,12 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
187
213
for _name , results in info .items ():
188
214
dynamic_results = results .get ("dynamic" )
189
215
adaptive_results = results .get ("adaptive" )
190
- binary_results = results .get ('binary' )
191
216
if dynamic_results is not None :
192
217
sequential_results [_name ] = np .flip (dynamic_results )
193
218
numerical_results [f"max{ _name } " ] = dynamic_results .max ()
194
219
numerical_results [f"avg{ _name } " ] = dynamic_results .mean ()
195
220
if adaptive_results is not None :
196
221
numerical_results [f"adp{ _name } " ] = adaptive_results
197
- if binary_results is not None :
198
- numerical_results [f"bi{ _name } " ] = binary_results
199
222
else :
200
223
results = info [m_name ]
201
224
if m_name in ("wfm" , "sm" , "mae" ):
@@ -204,9 +227,9 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
204
227
sequential_results [m_name ] = np .flip (results ["curve" ])
205
228
numerical_results .update (
206
229
{
207
- "maxe " : results ["curve" ].max (),
208
- "avge " : results ["curve" ].mean (),
209
- "adpe " : results ["adp" ],
230
+ "maxem " : results ["curve" ].max (),
231
+ "avgem " : results ["curve" ].mean (),
232
+ "adpem " : results ["adp" ],
210
233
}
211
234
)
212
235
else :
@@ -219,15 +242,81 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
219
242
numerical_results = ndarray_to_basetype (numerical_results )
220
243
return {"sequential" : sequential_results , "numerical" : numerical_results }
221
244
245
+ def show (self , num_bits : int = 3 , return_ndarray : bool = False ) -> dict :
246
+ return self .get_all_results (num_bits = num_bits , return_ndarray = return_ndarray )["numerical" ]
247
+
248
+
249
+ class BinaryMetricRecorder :
250
+ suppoted_metrics = ["mae" , "sm" , "wfm" ] + sorted (
251
+ [k for k in BINARY_CLASSIFICATION_METRIC_MAPPING .keys () if k .startswith (('sample_' , 'overall_' ))]
252
+ )
253
+
254
+ def __init__ (self , metric_names = ("bif1" , "biprecision" , "birecall" , "biiou" )):
255
+ """
256
+ 用于统计各种指标的类,主要适用于对单通道灰度图计算二值图像的指标。
257
+ """
258
+ if not metric_names :
259
+ metric_names = self .suppoted_metrics
260
+ assert all (
261
+ [m in self .suppoted_metrics for m in metric_names ]
262
+ ), f"Only support: { self .suppoted_metrics } "
263
+
264
+ self .metric_objs = {}
265
+ has_existed = False
266
+ for metric_name in metric_names :
267
+ if metric_name in INDIVADUAL_METRIC_MAPPING :
268
+ self .metric_objs [metric_name ] = INDIVADUAL_METRIC_MAPPING [metric_name ]()
269
+ else : # metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING
270
+ if not has_existed : # only init once
271
+ self .metric_objs ["fmeasurev2" ] = py_sod_metrics .FmeasureV2 ()
272
+ has_existed = True
273
+ metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING [metric_name ]
274
+ self .metric_objs ["fmeasurev2" ].add_handler (
275
+ handler_name = metric_name ,
276
+ metric_handler = metric_handler ["handler" ](** metric_handler ["kwargs" ]),
277
+ )
278
+
279
+ def step (self , pre : np .ndarray , gt : np .ndarray ):
280
+ assert pre .shape == gt .shape , (pre .shape , gt .shape )
281
+ assert pre .dtype == gt .dtype == np .uint8 , (pre .dtype , gt .dtype )
282
+
283
+ for m_obj in self .metric_objs .values ():
284
+ m_obj .step (pre , gt )
285
+
286
+ def get_all_results (self , num_bits : int = 3 , return_ndarray : bool = False ) -> dict :
287
+ numerical_results = {}
288
+ for m_name , m_obj in self .metric_objs .items ():
289
+ info = m_obj .get_results ()
290
+ if m_name == "fmeasurev2" :
291
+ for _name , results in info .items ():
292
+ binary_results = results .get ("binary" )
293
+ if binary_results is not None :
294
+ numerical_results [_name ] = binary_results
295
+ else :
296
+ results = info [m_name ]
297
+ if m_name in ("mae" , "sm" , "wfm" ):
298
+ numerical_results [m_name ] = results
299
+ else :
300
+ raise NotImplementedError (m_name )
301
+
302
+ if num_bits is not None and isinstance (num_bits , int ):
303
+ numerical_results = {k : v .round (num_bits ) for k , v in numerical_results .items ()}
304
+ if not return_ndarray :
305
+ numerical_results = ndarray_to_basetype (numerical_results )
306
+ return {"numerical" : numerical_results }
307
+
308
+ def show (self , num_bits : int = 3 , return_ndarray : bool = False ) -> dict :
309
+ return self .get_all_results (num_bits = num_bits , return_ndarray = return_ndarray )["numerical" ]
310
+
222
311
223
312
if __name__ == "__main__" :
224
313
data_loader = ...
225
314
model = ...
226
315
227
- cal_total_seg_metrics = CalTotalMetricV1 ()
316
+ cal_total_seg_metrics = MetricRecorderV2 ()
228
317
for batch in data_loader :
229
318
seg_preds = model (batch )
230
319
for seg_pred in seg_preds :
231
320
mask_array = ...
232
321
cal_total_seg_metrics .step (seg_pred , mask_array )
233
- fixed_seg_results = cal_total_seg_metrics .get_results ()
322
+ fixed_seg_results = cal_total_seg_metrics .show ()
0 commit comments