diff --git a/CHANGELOG.md b/CHANGELOG.md index 257daafa76a..5d0bfad4116 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,8 @@ All notable changes to this project will be documented in this file. () - Upgrade OpenVINO to 2024.5 and NNCF to 2.14.0 () +- Improve FMetric computation + () ### Bug fixes diff --git a/src/otx/algo/instance_segmentation/segmentors/maskrcnn_tv.py b/src/otx/algo/instance_segmentation/segmentors/maskrcnn_tv.py index 4613e66736d..5721f0d8793 100644 --- a/src/otx/algo/instance_segmentation/segmentors/maskrcnn_tv.py +++ b/src/otx/algo/instance_segmentation/segmentors/maskrcnn_tv.py @@ -91,6 +91,7 @@ def postprocess( for i, (pred, scale_factor, ori_shape) in enumerate(zip(result, scale_factors, ori_shapes)): boxes = pred["boxes"] labels = pred["labels"] + scores = pred["scores"] _scale_factor = [1 / s for s in scale_factor] # (H, W) boxes = boxes * boxes.new_tensor(_scale_factor[::-1]).repeat((1, int(boxes.size(-1) / 2))) h, w = ori_shape @@ -99,8 +100,10 @@ def postprocess( keep_indices = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) > 0 boxes = boxes[keep_indices > 0] labels = labels[keep_indices > 0] + scores = scores[keep_indices > 0] result[i]["boxes"] = boxes result[i]["labels"] = labels - 1 # Convert back to 0-indexed labels + result[i]["scores"] = scores if "masks" in pred: masks = pred["masks"][keep_indices] masks = paste_masks_in_image(masks, boxes, ori_shape) diff --git a/src/otx/core/metrics/fmeasure.py b/src/otx/core/metrics/fmeasure.py index d3c71285f94..3c03c02dde6 100644 --- a/src/otx/core/metrics/fmeasure.py +++ b/src/otx/core/metrics/fmeasure.py @@ -10,98 +10,22 @@ from typing import Any import numpy as np +import torch from torch import Tensor from torchmetrics import Metric, MetricCollection from torchmetrics.detection.mean_ap import MeanAveragePrecision +from torchvision import tv_tensors +from torchvision.ops import box_iou +from otx.core.data.entity.base import ImageInfo +from otx.core.data.entity.detection import DetDataEntity, DetPredEntity from otx.core.types.label import LabelInfo logger = logging.getLogger() ALL_CLASSES_NAME = "All Classes" -def intersection_box( - box1: tuple, - box2: tuple, -) -> tuple[float, float, float, float]: - """Calculate the intersection box of two bounding boxes. - - Args: - box1 (tuple): (x1, y1, x2, y2, class, score) - box2 (tuple): (x1, y1, x2, y2, class, score) - - Returns: - tuple[float, float, float, float]: (x_left, x_right, y_bottom, y_top) - """ - x_left = max(box1[0], box2[0]) - y_top = max(box1[1], box2[1]) - x_right = min(box1[2], box2[2]) - y_bottom = min(box1[3], box2[3]) - return (x_left, x_right, y_bottom, y_top) - - -def bounding_box_intersection_over_union( - box1: tuple, - box2: tuple, -) -> float: - """Calculate the Intersection over Union (IoU) of two bounding boxes. - - Args: - box1 (tuple): (x1, y1, x2, y2, class, score) - box2 (tuple): (x1, y1, x2, y2, class, score) - - Raises: - ValueError: In case the IoU is outside of [0.0, 1.0] - - Returns: - float: Intersection-over-union of box1 and box2. - """ - x_left, x_right, y_bottom, y_top = intersection_box(box1, box2) - - if x_right <= x_left or y_bottom <= y_top: - iou = 0.0 - else: - intersection_area = (x_right - x_left) * (y_bottom - y_top) - bb1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) - bb2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) - union_area = float(bb1_area + bb2_area - intersection_area) - iou = 0.0 if union_area == 0 else intersection_area / union_area - if iou < 0.0 or iou > 1.0: - msg = f"intersection over union should be in range [0,1], actual={iou}" - raise ValueError(msg) - return iou - - -def get_iou_matrix( - ground_truth: list[tuple], - predicted: list[tuple], -) -> np.ndarray: - """Constructs an iou matrix of shape [num_ground_truth_boxes, num_predicted_boxes]. - - Each cell(x,y) in the iou matrix contains the intersection over union of ground truth box(x) and predicted box(y) - An iou matrix corresponds to a single image - - Args: - ground_truth (list[tuple]): list of ground truth boxes. - Each box is a list of (x,y) coordinates and a label. - a box: [x1: float, y1, x2, y2, class: str, score: float] - boxes_per_image: [box1, box2, …] - boxes1: [boxes_per_image_1, boxes_per_image_2, boxes_per_image_3, …] - predicted (list[tuple]): list of predicted boxes. - Each box is a list of (x,y) coordinates and a label. - a box: [x1: float, y1, x2, y2, class: str, score: float] - boxes_per_image: [box1, box2, …] - boxes2: [boxes_per_image_1, boxes_per_image_2, boxes_per_image_3, …] - - Returns: - np.ndarray: IoU matrix of shape [ground_truth_boxes, predicted_boxes] - """ - return np.array( - [[bounding_box_intersection_over_union(gts, preds) for preds in predicted] for gts in ground_truth], - ) - - -def get_n_false_negatives(iou_matrix: np.ndarray, iou_threshold: float) -> int: +def get_n_false_negatives(iou_matrix: Tensor, iou_threshold: float) -> Tensor: """Get the number of false negatives inside the IoU matrix for a given threshold. The first loop accounts for all the ground truth boxes which do not have a high enough iou with any predicted @@ -110,19 +34,20 @@ def get_n_false_negatives(iou_matrix: np.ndarray, iou_threshold: float) -> int: box. The principle is that each ground truth box requires a unique prediction box Args: - iou_matrix (np.ndarray): IoU matrix of shape [ground_truth_boxes, predicted_boxes] + iou_matrix (torch.Tensor): IoU matrix of shape [ground_truth_boxes, predicted_boxes] iou_threshold (float): IoU threshold to use for the false negatives. Returns: - int: Number of false negatives + Tensor: Number of false negatives """ + # First loop n_false_negatives = 0 - for row in iou_matrix: - if max(row) < iou_threshold: - n_false_negatives += 1 - for column in np.rot90(iou_matrix): - indices = np.where(column > iou_threshold) - n_false_negatives += max(len(indices[0]) - 1, 0) + values = torch.max(iou_matrix, 1)[0] < iou_threshold - 1e-6 # 1e-6 is to avoid numerical instability + n_false_negatives += sum(values) + + # Second loop + matrix = torch.sum(iou_threshold < iou_matrix.T, 1) + n_false_negatives += sum(torch.clamp(matrix - 1, min=0)) return n_false_negatives @@ -206,7 +131,6 @@ class _OverallResults: Args: per_confidence (_AggregatedResults): _AggregatedResults object for each confidence level. - per_nms (_AggregatedResults | None): _AggregatedResults object for each NMS threshold. best_f_measure_per_class (dict[str, float]): Best f-measure per class. best_f_measure (float): Best f-measure. """ @@ -214,12 +138,10 @@ class _OverallResults: def __init__( self, per_confidence: _AggregatedResults, - per_nms: _AggregatedResults | None, best_f_measure_per_class: dict[str, float], best_f_measure: float, ): self.per_confidence = per_confidence - self.per_nms = per_nms self.best_f_measure_per_class = best_f_measure_per_class self.best_f_measure = best_f_measure @@ -228,33 +150,20 @@ class _FMeasureCalculator: """This class contains the functions to calculate FMeasure. Args: - ground_truth_boxes_per_image (list[list[tuple]]): - a box: [x1: float, y1, x2, y2, class: str, score: float] - boxes_per_image: [box1, box2, …] - ground_truth_boxes_per_image: [boxes_per_image_1, boxes_per_image_2, boxes_per_image_3, …] - prediction_boxes_per_image (list[list[tuple]]): - a box: [x1: float, y1, x2, y2, class: str, score: float] - boxes_per_image: [box1, box2, …] - predicted_boxes_per_image: [boxes_per_image_1, boxes_per_image_2, boxes_per_image_3, …] + classes (list[str]): List of classes. """ - def __init__( - self, - ground_truth_boxes_per_image: list[list[tuple]], - prediction_boxes_per_image: list[list[tuple]], - ): - self.ground_truth_boxes_per_image = ground_truth_boxes_per_image - self.prediction_boxes_per_image = prediction_boxes_per_image + def __init__(self, classes: list[str]): + self.classes = classes self.confidence_range = [0.025, 1.0, 0.025] self.nms_range = [0.1, 1, 0.05] self.default_confidence_threshold = 0.35 def evaluate_detections( self, - classes: list[str], + gt_entities: list[DetDataEntity], + pred_entities: list[DetPredEntity], iou_threshold: float = 0.5, - result_based_nms_threshold: bool = False, - cross_class_nms: bool = False, ) -> _OverallResults: """Evaluates detections by computing f_measures across multiple confidence thresholds and iou thresholds. @@ -266,51 +175,39 @@ def evaluate_detections( used to achieve them. Args: - classes (list[str]): Names of classes to be evaluated. + gt_entities (list[DetDataEntity]): List of ground truth entities. + pred_entities (list[DetPredEntity]): List of predicted entities. iou_threshold (float): IOU threshold. Defaults to 0.5. - result_based_nms_threshold (bool): Boolean that determines whether multiple nms threshold are examined. - Defaults to False. - cross_class_nms (bool): Set to True to perform NMS between boxes with different classes. Defaults to False. Returns: _OverallResults: _OverallResults object with the result statistics (e.g F-measure). """ best_f_measure_per_class = {} - results_per_confidence = self.get_results_per_confidence( - classes=classes, + results_per_confidence = self._get_results_per_confidence( + classes=self.classes.copy(), + gt_entities=gt_entities, + pred_entities=pred_entities, confidence_range=self.confidence_range, iou_threshold=iou_threshold, ) best_f_measure = results_per_confidence.best_f_measure - for class_name in classes: + for class_name in self.classes: best_f_measure_per_class[class_name] = max(results_per_confidence.f_measure_curve[class_name]) - results_per_nms: _AggregatedResults | None = None - - if result_based_nms_threshold: - results_per_nms = self.get_results_per_nms( - classes=classes, - iou_threshold=iou_threshold, - min_f_measure=results_per_confidence.best_f_measure, - cross_class_nms=cross_class_nms, - ) - - for class_name in classes: - best_f_measure_per_class[class_name] = max(results_per_nms.f_measure_curve[class_name]) - return _OverallResults( results_per_confidence, - results_per_nms, best_f_measure_per_class, best_f_measure, ) - def get_results_per_confidence( + def _get_results_per_confidence( self, classes: list[str], + gt_entities: list[DetDataEntity], + pred_entities: list[DetPredEntity], confidence_range: list[float], iou_threshold: float, ) -> _AggregatedResults: @@ -321,6 +218,8 @@ def get_results_per_confidence( Args: classes (list[str]): Names of classes to be evaluated. + gt_entities (list[DetDataEntity]): List of ground truth entities. + pred_entities (list[DetPredEntity]): List of predicted entities. confidence_range (list[float]): list of confidence thresholds to be evaluated. iou_threshold (float): IoU threshold to use for false negatives. @@ -332,7 +231,9 @@ def get_results_per_confidence( for confidence_threshold in np.arange(*confidence_range): result_point = self.evaluate_classes( - classes=classes.copy(), + gt_entities=gt_entities, + pred_entities=pred_entities, + classes=classes, iou_threshold=iou_threshold, confidence_threshold=confidence_threshold, ) @@ -348,65 +249,10 @@ def get_results_per_confidence( result.best_threshold = confidence_threshold return result - def get_results_per_nms( - self, - classes: list[str], - iou_threshold: float, - min_f_measure: float, - cross_class_nms: bool = False, - ) -> _AggregatedResults: - """Returns results for nms threshold in range nms_range. - - First, we calculate the critical nms of each box, meaning the nms_threshold - that would cause it to be disappear - This is an expensive O(n**2) operation, however, doing this makes filtering for every single nms_threshold much - faster at O(n) - - Args: - classes (list[str]): list of classes - iou_threshold (float): IoU threshold - min_f_measure (float): the minimum F-measure required to select a NMS threshold - cross_class_nms (bool): set to True to perform NMS between boxes with different classes. Defaults to False. - - Returns: - _AggregatedResults: Object containing the results for each NMS threshold value - """ - result = _AggregatedResults(classes) - result.best_f_measure = min_f_measure - result.best_threshold = 0.5 - - critical_nms_per_image = self.__get_critical_nms(self.prediction_boxes_per_image, cross_class_nms) - - for nms_threshold in np.arange(*self.nms_range): - predicted_boxes_per_image_per_nms = self.__filter_nms( - self.prediction_boxes_per_image, - critical_nms_per_image, - nms_threshold, - ) - boxes_pair_for_nms = _FMeasureCalculator( - self.ground_truth_boxes_per_image, - predicted_boxes_per_image_per_nms, - ) - result_point = boxes_pair_for_nms.evaluate_classes( - classes=classes.copy(), - iou_threshold=iou_threshold, - confidence_threshold=self.default_confidence_threshold, - ) - all_classes_f_measure = result_point[ALL_CLASSES_NAME].f_measure - result.all_classes_f_measure_curve.append(all_classes_f_measure) - - for class_name in classes: - result.f_measure_curve[class_name].append(result_point[class_name].f_measure) - result.precision_curve[class_name].append(result_point[class_name].precision) - result.recall_curve[class_name].append(result_point[class_name].recall) - - if all_classes_f_measure > 0.0 and all_classes_f_measure >= result.best_f_measure: - result.best_f_measure = all_classes_f_measure - result.best_threshold = nms_threshold - return result - def evaluate_classes( self, + gt_entities: list[DetDataEntity], + pred_entities: list[DetPredEntity], classes: list[str], iou_threshold: float, confidence_threshold: float, @@ -414,6 +260,8 @@ def evaluate_classes( """Returns dict of f_measure, precision and recall for each class. Args: + gt_entites (list[DetDataEntity]): List of ground truth entities. + pred_entities (list[DetPredEntity]): List of predicted entities. classes (list[str]): list of classes to be evaluated. iou_threshold (float): IoU threshold to use for false negatives. confidence_threshold (float): Confidence threshold to use for false negatives. @@ -427,9 +275,11 @@ def evaluate_classes( if ALL_CLASSES_NAME in classes: classes.remove(ALL_CLASSES_NAME) - for class_name in classes: + for label_idx, class_name in enumerate(classes): metrics, counters = self.get_f_measure_for_class( - class_name=class_name, + gt_entities=gt_entities, + pred_entities=pred_entities, + label_idx=label_idx, iou_threshold=iou_threshold, confidence_threshold=confidence_threshold, ) @@ -444,7 +294,9 @@ def evaluate_classes( def get_f_measure_for_class( self, - class_name: str, + gt_entities: list[DetDataEntity], + pred_entities: list[DetPredEntity], + label_idx: int, iou_threshold: float, confidence_threshold: float, ) -> tuple[_Metrics, _ResultCounters]: @@ -454,7 +306,9 @@ def get_f_measure_for_class( all boxes are filtered at this stage by class and predicted boxes are filtered by confidence threshold Args: - class_name (str): Name of the class for which the F measure is computed + gt_entities (list[DetDataEntity]): List of ground truth entities. + pred_entities (list[DetPredEntity]): List of predicted entities. + label_idx (int): Index of the class for which the boxes are filtered. iou_threshold (float): IoU threshold confidence_threshold (float): Confidence threshold @@ -462,18 +316,22 @@ def get_f_measure_for_class( tuple[_Metrics, _ResultCounters]: a structure containing the statistics (e.g. f_measure) and a structure containing the intermediated counters used to derive the stats (e.g. num. false positives) """ - class_ground_truth_boxes_per_image = self.__filter_class(self.ground_truth_boxes_per_image, class_name) - confidence_predicted_boxes_per_image = self.__filter_confidence( - self.prediction_boxes_per_image, - confidence_threshold, + batch_gt_bboxes = self.__filter_gt( + gt_entities, + label_idx=label_idx, + ) + batch_pred_bboxes = self.__filter_pred( + pred_entities, + label_idx=label_idx, + confidence_threshold=confidence_threshold, ) - class_predicted_boxes_per_image = self.__filter_class(confidence_predicted_boxes_per_image, class_name) - if len(class_ground_truth_boxes_per_image) > 0: - boxes_pair_per_class = _FMeasureCalculator( - ground_truth_boxes_per_image=class_ground_truth_boxes_per_image, - prediction_boxes_per_image=class_predicted_boxes_per_image, + + if len(batch_gt_bboxes) > 0: + result_counters = self.get_counters( + batch_gt=batch_gt_bboxes, + batch_pred=batch_pred_bboxes, + iou_threshold=iou_threshold, ) - result_counters = boxes_pair_per_class.get_counters(iou_threshold=iou_threshold) result_metrics = result_counters.calculate_f_measure() results = (result_metrics, result_counters) else: @@ -483,122 +341,61 @@ def get_f_measure_for_class( return results @staticmethod - def __get_critical_nms( - boxes_per_image: list[list[tuple]], - cross_class_nms: bool = False, - ) -> list[list[float]]: - """Return list of critical NMS values for each box in each image. - - Maps each predicted box to the highest nms-threshold which would suppress that box, aka the smallest - nms_threshold before the box disappears. - Having these values allows us to later filter by nms-threshold in O(n) rather than O(n**2) - Highest losing iou, holds the value of the highest iou that a box has with any - other box of the same class and higher confidence score. - - Args: - boxes_per_image (list[list[tuple]]): list of predicted boxes per - image. - a box: [x1: float, y1, x2, y2, class: str, score: float] - boxes_per_image: [box1, box2, …] - cross_class_nms (bool): Whether to use cross class NMS. - - Returns: - list[list[float]]: list of critical NMS values for each box in each image. - """ - critical_nms_per_image = [] - for boxes in boxes_per_image: - critical_nms_per_box = [] - for box1 in boxes: - highest_losing_iou = 0.0 - for box2 in boxes: - iou = bounding_box_intersection_over_union(box1, box2) - if ( - (cross_class_nms or box1[4] == box2[4]) - and box1[5] < box2[5] # type: ignore[operator] - and iou > highest_losing_iou - ): - highest_losing_iou = iou - critical_nms_per_box.append(highest_losing_iou) - critical_nms_per_image.append(critical_nms_per_box) - return critical_nms_per_image - - @staticmethod - def __filter_nms( - boxes_per_image: list[list[tuple]], - critical_nms: list[list[float]], - nms_threshold: float, - ) -> list[list[tuple]]: - """Filters out predicted boxes whose critical nms is higher than the given nms_threshold. - - Args: - boxes_per_image (list[list[tuple]]): list of boxes per image. - a box: [x1: float, y1, x2, y2, class: str, score: float] - boxes_per_image: [box1, box2, …] - critical_nms (list[list[float]]): list of list of critical nms for each box in each image - nms_threshold (float): NMS threshold used for filtering - - Returns: - list[list[tuple]]: list of list of filtered boxes in each image - """ - new_boxes_per_image = [] - for boxes, boxes_nms in zip(boxes_per_image, critical_nms): - new_boxes = [] - for box, nms in zip(boxes, boxes_nms): - if nms < nms_threshold: - new_boxes.append(box) - new_boxes_per_image.append(new_boxes) - return new_boxes_per_image - - @staticmethod - def __filter_class( - boxes_per_image: list[list[tuple]], - class_name: str, - ) -> list[list[tuple]]: + def __filter_gt( + entities: list[DetDataEntity], + label_idx: int, + ) -> list[Tensor]: """Filters boxes to only keep members of one class. Args: - boxes_per_image (list[list[tuple]]): a list of lists of boxes - class_name (str): Name of the class for which the boxes are filtered + entities (list[DetDataEntity]): a list of DetDataEntity objects containing the ground truth annotations. + label_idx (int): Index of the class for which the boxes are filtered. Returns: - list[list[tuple]]: a list of lists of boxes + list[Tensor]: a list of bounding boxes for label_idx """ - filtered_boxes_per_image = [] - for boxes in boxes_per_image: - filtered_boxes = [box for box in boxes if box[4].lower() == class_name.lower()] - filtered_boxes_per_image.append(filtered_boxes) - return filtered_boxes_per_image + batch_bboxes = [] + for entity in entities: + keep = entity.labels == label_idx + batch_bboxes.append(entity.bboxes[keep]) + return batch_bboxes @staticmethod - def __filter_confidence( - boxes_per_image: list[list[tuple]], + def __filter_pred( + entities: list[DetPredEntity], + label_idx: int, confidence_threshold: float, - ) -> list[list[tuple]]: - """Filters boxes to only keep ones with higher confidence than a given confidence threshold. + ) -> list[Tensor]: + """Filters boxes to only keep members of one class. Args: - boxes_per_image (list[list[tuple]]): - a box: [x1: float, y1, x2, y2, class: str, score: float] - boxes_per_image: [box1, box2, …] + entities (list[DetPredEntity]): a list of DetPredEntity objects containing the predicted boxes. + label_idx (int): Index of the class for which the boxes are filtered. confidence_threshold (float): Confidence threshold Returns: - list[list[tuple]]: Boxes with higher confidence than the given - threshold. + list[list[tuple]]: a list of lists of boxes """ - filtered_boxes_per_image = [] - for boxes in boxes_per_image: - filtered_boxes = [box for box in boxes if float(box[5]) > confidence_threshold] - filtered_boxes_per_image.append(filtered_boxes) - return filtered_boxes_per_image + batch_bboxes = [] + for entity in entities: + keep = (entity.labels == label_idx) & (entity.score > confidence_threshold) + batch_bboxes.append(entity.bboxes[keep]) + return batch_bboxes - def get_counters(self, iou_threshold: float) -> _ResultCounters: + @staticmethod + def get_counters( + batch_gt: list[Tensor], + batch_pred: list[Tensor], + iou_threshold: float, + ) -> _ResultCounters: """Return counts of true positives, false positives and false negatives for a given iou threshold. For each image (the loop), compute the number of false negatives, the number of predicted boxes, and the number of ground truth boxes, then add each value to its corresponding counter Args: + batch_gt (list[Tensor]): List of ground truth boxes + batch_pred (list[Tensor]): List of predicted boxes iou_threshold (float): IoU threshold Returns: @@ -607,18 +404,15 @@ def get_counters(self, iou_threshold: float) -> _ResultCounters: n_false_negatives = 0 n_true = 0 n_predicted = 0 - for ground_truth_boxes, predicted_boxes in zip( - self.ground_truth_boxes_per_image, - self.prediction_boxes_per_image, - ): - n_true += len(ground_truth_boxes) - n_predicted += len(predicted_boxes) - if len(predicted_boxes) > 0: - if len(ground_truth_boxes) > 0: - iou_matrix = get_iou_matrix(ground_truth_boxes, predicted_boxes) + for gt_bboxes, pred_bboxes in zip(batch_gt, batch_pred, strict=True): + n_true += len(gt_bboxes) + n_predicted += len(pred_bboxes) + if len(pred_bboxes) > 0: + if len(gt_bboxes) > 0: + iou_matrix = box_iou(gt_bboxes, pred_bboxes) n_false_negatives += get_n_false_negatives(iou_matrix, iou_threshold) else: - n_false_negatives += len(ground_truth_boxes) + n_false_negatives += len(gt_bboxes) return _ResultCounters(n_false_negatives, n_true, n_predicted) @@ -632,36 +426,20 @@ class FMeasure(Metric): is used based on a minimum intersection-over-union (IoU), by default a value of 0.5 is used. - In addition spurious results are eliminated by applying non-max suppression (NMS) so that two predicted boxes with - IoU > threshold are reduced to one. This threshold can be determined automatically by setting `vary_nms_threshold` - to True. - # TODO(someone): need to update for distriubted training. refer https://lightning.ai/docs/torchmetrics/stable/pages/implement.html Args: label_info (int): Dataclass including label information. - vary_nms_threshold (bool): if True the maximal F-measure is determined by optimizing for different NMS threshold - values. Defaults to False. - cross_class_nms (bool): Whether non-max suppression should be applied cross-class. If True this will eliminate - boxes with sufficient overlap even if they are from different classes. Defaults to False. """ def __init__( self, label_info: LabelInfo, - *, - vary_nms_threshold: bool = False, - cross_class_nms: bool = False, ): super().__init__() - self.vary_nms_threshold = vary_nms_threshold - self.cross_class_nms = cross_class_nms self.label_info: LabelInfo = label_info - self._f_measure_per_confidence: dict | None = None - self._f_measure_per_nms: dict | None = None self._best_confidence_threshold: float | None = None - self._best_nms_threshold: float | None = None self._f_measure = float("-inf") self.reset() @@ -672,27 +450,28 @@ def reset(self) -> None: Please be careful that some variables should not be reset for each epoch. """ super().reset() - self.preds: list[list[tuple]] = [] - self.targets: list[list[tuple]] = [] + self.preds: list[DetPredEntity] = [] + self.targets: list[DetDataEntity] = [] def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None: """Update total predictions and targets from given batch predicitons and targets.""" - for pred, tget in zip(preds, target): + for i, (pred, tget) in enumerate(zip(preds, target)): self.preds.append( - [ - (*box, self.classes[label], score) - for box, label, score in zip( - pred["boxes"].tolist(), - pred["labels"].tolist(), - pred["scores"].tolist(), - ) - ], + DetPredEntity( + image=tv_tensors.Image(torch.empty(0, 0)), + img_info=ImageInfo(img_idx=i, img_shape=(0, 0), ori_shape=(0, 0)), + bboxes=pred["boxes"], + score=pred["scores"], + labels=pred["labels"], + ), ) self.targets.append( - [ - (*box, self.classes[label], 0.0) - for box, label in zip(tget["boxes"].tolist(), tget["labels"].tolist()) - ], + DetDataEntity( + image=tv_tensors.Image(torch.empty(0, 0)), + img_info=ImageInfo(img_idx=i, img_shape=(0, 0), ori_shape=(0, 0)), + bboxes=tget["boxes"], + labels=tget["labels"], + ), ) def compute(self, best_confidence_threshold: float | None = None) -> dict: @@ -703,12 +482,8 @@ def compute(self, best_confidence_threshold: float | None = None) -> dict: If this value is None, then FMeasure will find best confidence threshold and store it as member variable. Defaults to None. """ - boxes_pair = _FMeasureCalculator(self.targets, self.preds) - result = boxes_pair.evaluate_detections( - result_based_nms_threshold=self.vary_nms_threshold, - classes=self.classes, - cross_class_nms=self.cross_class_nms, - ) + boxes_pair = _FMeasureCalculator(classes=self.classes) + result = boxes_pair.evaluate_detections(self.targets, self.preds) self._f_measure_per_label = {label: result.best_f_measure_per_class[label] for label in self.classes} if best_confidence_threshold is not None: @@ -733,14 +508,6 @@ def compute(self, best_confidence_threshold: float | None = None) -> dict: if self._f_measure < computed_f_measure: self._f_measure = result.best_f_measure self._best_confidence_threshold = best_confidence_threshold - - if self.vary_nms_threshold and result.per_nms is not None: - self._f_measure_per_nms = { - "xs": list(np.arange(*boxes_pair.nms_range)), - "ys": result.per_nms.all_classes_f_measure_curve, - } - self._best_nms_threshold = result.per_nms.best_threshold - return {"f1-score": Tensor([computed_f_measure])} @property @@ -769,16 +536,6 @@ def best_confidence_threshold(self) -> float: raise RuntimeError(msg) return self._best_confidence_threshold - @property - def f_measure_per_nms(self) -> None | dict: - """Returns the curve for f-measure per nms threshold as CurveMetric if exists.""" - return self._f_measure_per_nms - - @property - def best_nms_threshold(self) -> None | float: - """Returns the best NMS threshold as ScoreMetric if exists.""" - return self._best_nms_threshold - @property def classes(self) -> list[str]: """Class information of dataset.""" @@ -794,7 +551,7 @@ class MeanAveragePrecisionFMeasure(MetricCollection): doing line search on confidence threshold axis. The correct way to obtain the test set F1 score is to use the best confidence threshold obtained from the validation set. - You should use `--metric otx.core.metrics.fmeasure.FMeasureCallable`override + You should use `--metric otx.core.metrics.fmeasure.FMeasureCallable` override to correctly obtain F1 score from a test set. """ diff --git a/tests/unit/core/metrics/test_fmeasure.py b/tests/unit/core/metrics/test_fmeasure.py index b934275b827..f03b347bd72 100644 --- a/tests/unit/core/metrics/test_fmeasure.py +++ b/tests/unit/core/metrics/test_fmeasure.py @@ -5,9 +5,10 @@ from __future__ import annotations +import numpy as np import pytest import torch -from otx.core.metrics.fmeasure import FMeasure +from otx.core.metrics.fmeasure import FMeasure, get_n_false_negatives from otx.core.types.label import LabelInfo @@ -65,3 +66,22 @@ def test_fmeasure_with_fixed_threshold(self, fxt_preds, fxt_targets) -> None: metric.update(fxt_preds, fxt_targets) result = metric.compute(best_confidence_threshold=0.85) assert result["f1-score"] == 0.3333333432674408 + + def test_get_fn(self): + def _get_n_false_negatives_numpy(iou_matrix: np.ndarray, iou_threshold: float) -> int: + n_false_negatives = 0 + for row in iou_matrix: + if max(row) < iou_threshold: + n_false_negatives += 1 + for column in np.rot90(iou_matrix): + indices = np.where(column > iou_threshold) + n_false_negatives += max(len(indices[0]) - 1, 0) + return n_false_negatives + + iou_matrix = torch.rand((10, 20)) + iou_threshold = np.random.rand() + + assert get_n_false_negatives(iou_matrix, iou_threshold) == _get_n_false_negatives_numpy( + iou_matrix.numpy(), + iou_threshold, + )