diff --git a/detect.py b/detect.py new file mode 100644 index 0000000..3f521e4 --- /dev/null +++ b/detect.py @@ -0,0 +1,22 @@ +from sahi.model import Yolov6DetectionModel +from sahi.predict import get_prediction, get_sliced_prediction, predict + +detection_model = Yolov6DetectionModel( + model_path='yolov6s.pt', + confidence_threshold=0.3, + device="cpu", # or 'cuda:0' + image_size=640, +) + +#result = get_prediction("demo/demo_data/highway.jpg", detection_model) + +result = get_sliced_prediction( + 'demo/demo_data/highway.jpg', + detection_model, + slice_height = 1280, + slice_width = 1280, + overlap_height_ratio = 0.6, + overlap_width_ratio = 0.6, +) + +result.export_visuals(export_dir="demo_data/") diff --git a/sahi/__init__.py b/sahi/__init__.py new file mode 100644 index 0000000..6a7690f --- /dev/null +++ b/sahi/__init__.py @@ -0,0 +1,3 @@ +__version__ = "0.10.1" + +from sahi.auto_model import AutoDetectionModel diff --git a/sahi/annotation.py b/sahi/annotation.py new file mode 100644 index 0000000..e17520a --- /dev/null +++ b/sahi/annotation.py @@ -0,0 +1,677 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. + +import copy +from typing import Dict, List, Optional + +import numpy as np + +from sahi.utils.coco import CocoAnnotation, CocoPrediction +from sahi.utils.cv import ( + get_bbox_from_bool_mask, + get_bool_mask_from_coco_segmentation, + get_coco_segmentation_from_bool_mask, +) +from sahi.utils.shapely import ShapelyAnnotation + + +class BoundingBox: + """ + Bounding box of the annotation. + """ + + def __init__(self, box: List[int], shift_amount: List[int] = [0, 0]): + """ + Args: + box: List[int] + [minx, miny, maxx, maxy] + shift_amount: List[int] + To shift the box and mask predictions from sliced image + to full sized image, should be in the form of [shift_x, shift_y] + """ + if box[0] < 0 or box[1] < 0 or box[2] < 0 or box[3] < 0: + raise Exception("Box coords [minx, miny, maxx, maxy] cannot be negative") + self.minx = int(box[0]) + self.miny = int(box[1]) + self.maxx = int(box[2]) + self.maxy = int(box[3]) + + self.shift_x = shift_amount[0] + self.shift_y = shift_amount[1] + + @property + def shift_amount(self): + """ + Returns the shift amount of the bbox slice as [shift_x, shift_y] + """ + return [self.shift_x, self.shift_y] + + def get_expanded_box(self, ratio=0.1, max_x=None, max_y=None): + w = self.maxx - self.minx + h = self.maxy - self.miny + y_mar = int(w * ratio) + x_mar = int(h * ratio) + maxx = min(max_x, self.maxx + x_mar) if max_x else self.maxx + x_mar + minx = max(0, self.minx - x_mar) + maxy = min(max_y, self.maxy + y_mar) if max_y else self.maxy + y_mar + miny = max(0, self.miny - y_mar) + box = [minx, miny, maxx, maxy] + return BoundingBox(box) + + def to_coco_bbox(self): + """ + Returns: [xmin, ymin, width, height] + """ + return [self.minx, self.miny, self.maxx - self.minx, self.maxy - self.miny] + + def to_voc_bbox(self): + """ + Returns: [xmin, ymin, xmax, ymax] + """ + return [self.minx, self.miny, self.maxx, self.maxy] + + def get_shifted_box(self): + """ + Returns: shifted BoundingBox + """ + box = [ + self.minx + self.shift_x, + self.miny + self.shift_y, + self.maxx + self.shift_x, + self.maxy + self.shift_y, + ] + return BoundingBox(box) + + def __repr__(self): + return f"BoundingBox: <{(self.minx, self.miny, self.maxx, self.maxy)}, w: {self.maxx - self.minx}, h: {self.maxy - self.miny}>" + + +class Category: + """ + Category of the annotation. 
+ """ + + def __init__(self, id=None, name=None): + """ + Args: + id: int + ID of the object category + name: str + Name of the object category + """ + if not isinstance(id, int): + raise TypeError("id should be integer") + if not isinstance(name, str): + raise TypeError("name should be string") + self.id = id + self.name = name + + def __repr__(self): + return f"Category: " + + +class Mask: + @classmethod + def from_float_mask( + cls, + mask, + full_shape=None, + mask_threshold: float = 0.5, + shift_amount: list = [0, 0], + ): + """ + Args: + mask: np.ndarray of np.float elements + Mask values between 0 and 1 (should have a shape of height*width) + mask_threshold: float + Value to threshold mask pixels between 0 and 1 + shift_amount: List + To shift the box and mask predictions from sliced image + to full sized image, should be in the form of [shift_x, shift_y] + full_shape: List + Size of the full image after shifting, should be in the form of [height, width] + """ + bool_mask = mask > mask_threshold + return cls( + bool_mask=bool_mask, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + @classmethod + def from_coco_segmentation( + cls, + segmentation, + full_shape=None, + shift_amount: list = [0, 0], + ): + """ + Init Mask from coco segmentation representation. + + Args: + segmentation : List[List] + [ + [x1, y1, x2, y2, x3, y3, ...], + [x1, y1, x2, y2, x3, y3, ...], + ... + ] + full_shape: List + Size of the full image, should be in the form of [height, width] + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + """ + # confirm full_shape is given + if full_shape is None: + raise ValueError("full_shape must be provided") + bool_mask = get_bool_mask_from_coco_segmentation(segmentation, height=full_shape[0], width=full_shape[1]) + return cls( + bool_mask=bool_mask, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + def __init__( + self, + bool_mask=None, + full_shape=None, + shift_amount: list = [0, 0], + ): + """ + Args: + bool_mask: np.ndarray with bool elements + 2D mask of object, should have a shape of height*width + full_shape: List + Size of the full image, should be in the form of [height, width] + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + """ + + if len(bool_mask) > 0: + has_bool_mask = True + else: + has_bool_mask = False + + if has_bool_mask: + self.bool_mask = bool_mask.astype(bool) + else: + self.bool_mask = None + + self.shift_x = shift_amount[0] + self.shift_y = shift_amount[1] + + if full_shape: + self.full_shape_height = full_shape[0] + self.full_shape_width = full_shape[1] + elif has_bool_mask: + self.full_shape_height = self.bool_mask.shape[0] + self.full_shape_width = self.bool_mask.shape[1] + else: + self.full_shape_height = None + self.full_shape_width = None + + @property + def shape(self): + """ + Returns mask shape as [height, width] + """ + return [self.bool_mask.shape[0], self.bool_mask.shape[1]] + + @property + def full_shape(self): + """ + Returns full mask shape after shifting as [height, width] + """ + return [self.full_shape_height, self.full_shape_width] + + @property + def shift_amount(self): + """ + Returns the shift amount of the mask slice as [shift_x, shift_y] + """ + return [self.shift_x, self.shift_y] + + def get_shifted_mask(self): + # Confirm full_shape is specified + if (self.full_shape_height is None) or (self.full_shape_width 
is None): + raise ValueError("full_shape is None") + # init full mask + mask_fullsized = np.full( + ( + self.full_shape_height, + self.full_shape_width, + ), + 0, + dtype="float32", + ) + + # arrange starting ending indexes + starting_pixel = [self.shift_x, self.shift_y] + ending_pixel = [ + min(starting_pixel[0] + self.bool_mask.shape[1], self.full_shape_width), + min(starting_pixel[1] + self.bool_mask.shape[0], self.full_shape_height), + ] + + # convert sliced mask to full mask + mask_fullsized[starting_pixel[1] : ending_pixel[1], starting_pixel[0] : ending_pixel[0]] = self.bool_mask[ + : ending_pixel[1] - starting_pixel[1], : ending_pixel[0] - starting_pixel[0] + ] + + return Mask( + mask_fullsized, + shift_amount=[0, 0], + full_shape=self.full_shape, + ) + + def to_coco_segmentation(self): + """ + Returns boolean mask as coco segmentation: + [ + [x1, y1, x2, y2, x3, y3, ...], + [x1, y1, x2, y2, x3, y3, ...], + ... + ] + """ + coco_segmentation = get_coco_segmentation_from_bool_mask(self.bool_mask) + return coco_segmentation + + +class ObjectAnnotation: + """ + All about an annotation such as Mask, Category, BoundingBox. + """ + + @classmethod + def from_bool_mask( + cls, + bool_mask, + category_id: Optional[int] = None, + category_name: Optional[str] = None, + shift_amount: Optional[List[int]] = [0, 0], + full_shape: Optional[List[int]] = None, + ): + """ + Creates ObjectAnnotation from bool_mask (2D np.ndarray) + + Args: + bool_mask: np.ndarray with bool elements + 2D mask of object, should have a shape of height*width + category_id: int + ID of the object category + category_name: str + Name of the object category + full_shape: List + Size of the full image, should be in the form of [height, width] + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + """ + return cls( + category_id=category_id, + bool_mask=bool_mask, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + @classmethod + def from_coco_segmentation( + cls, + segmentation, + full_shape: List[int], + category_id: Optional[int] = None, + category_name: Optional[str] = None, + shift_amount: Optional[List[int]] = [0, 0], + ): + """ + Creates ObjectAnnotation from coco segmentation: + [ + [x1, y1, x2, y2, x3, y3, ...], + [x1, y1, x2, y2, x3, y3, ...], + ... + ] + + Args: + segmentation: List[List] + [ + [x1, y1, x2, y2, x3, y3, ...], + [x1, y1, x2, y2, x3, y3, ...], + ... 
+ ] + category_id: int + ID of the object category + category_name: str + Name of the object category + full_shape: List + Size of the full image, should be in the form of [height, width] + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + """ + bool_mask = get_bool_mask_from_coco_segmentation(segmentation, width=full_shape[1], height=full_shape[0]) + return cls( + category_id=category_id, + bool_mask=bool_mask, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + @classmethod + def from_coco_bbox( + cls, + bbox: List[int], + category_id: Optional[int] = None, + category_name: Optional[str] = None, + shift_amount: Optional[List[int]] = [0, 0], + full_shape: Optional[List[int]] = None, + ): + """ + Creates ObjectAnnotation from coco bbox [minx, miny, width, height] + + Args: + bbox: List + [minx, miny, width, height] + category_id: int + ID of the object category + category_name: str + Name of the object category + full_shape: List + Size of the full image, should be in the form of [height, width] + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + """ + xmin = bbox[0] + ymin = bbox[1] + xmax = bbox[0] + bbox[2] + ymax = bbox[1] + bbox[3] + bbox = [xmin, ymin, xmax, ymax] + return cls( + category_id=category_id, + bbox=bbox, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + @classmethod + def from_coco_annotation_dict( + cls, + annotation_dict: Dict, + full_shape: List[int], + category_name: str = None, + shift_amount: Optional[List[int]] = [0, 0], + ): + """ + Creates ObjectAnnotation object from category name and COCO formatted + annotation dict (with fields "bbox", "segmentation", "category_id"). 
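+        If the "segmentation" field is non-empty, the annotation is built from the mask
+        polygons via from_coco_segmentation; otherwise it falls back to the "bbox" field via
+        from_coco_bbox. For example (values illustrative only), {"bbox": [10, 20, 50, 60],
+        "segmentation": [], "category_id": 3} with full_shape=[1080, 1920] yields a box-only
+        annotation.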
+ + Args: + annotation_dict: dict + COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id") + category_name: str + Category name of the annotation + full_shape: List + Size of the full image, should be in the form of [height, width] + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + """ + if annotation_dict["segmentation"]: + return cls.from_coco_segmentation( + segmentation=annotation_dict["segmentation"], + category_id=annotation_dict["category_id"], + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + else: + return cls.from_coco_bbox( + bbox=annotation_dict["bbox"], + category_id=annotation_dict["category_id"], + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + @classmethod + def from_shapely_annotation( + cls, + annotation, + full_shape: List[int], + category_id: Optional[int] = None, + category_name: Optional[str] = None, + shift_amount: Optional[List[int]] = [0, 0], + ): + """ + Creates ObjectAnnotation from shapely_utils.ShapelyAnnotation + + Args: + annotation: shapely_utils.ShapelyAnnotation + category_id: int + ID of the object category + category_name: str + Name of the object category + full_shape: List + Size of the full image, should be in the form of [height, width] + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + """ + bool_mask = get_bool_mask_from_coco_segmentation( + annotation.to_coco_segmentation(), width=full_shape[1], height=full_shape[0] + ) + return cls( + category_id=category_id, + bool_mask=bool_mask, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + @classmethod + def from_imantics_annotation( + cls, + annotation, + shift_amount: Optional[List[int]] = [0, 0], + full_shape: Optional[List[int]] = None, + ): + """ + Creates ObjectAnnotation from imantics.annotation.Annotation + + Args: + annotation: imantics.annotation.Annotation + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + full_shape: List + Size of the full image, should be in the form of [height, width] + """ + return cls( + category_id=annotation.category.id, + bool_mask=annotation.mask.array, + category_name=annotation.category.name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + def __init__( + self, + bbox: Optional[List[int]] = None, + bool_mask: Optional[np.ndarray] = None, + category_id: Optional[int] = None, + category_name: Optional[str] = None, + shift_amount: Optional[List[int]] = [0, 0], + full_shape: Optional[List[int]] = None, + ): + """ + Args: + bbox: List + [minx, miny, maxx, maxy] + bool_mask: np.ndarray with bool elements + 2D mask of object, should have a shape of height*width + category_id: int + ID of the object category + category_name: str + Name of the object category + shift_amount: List + To shift the box and mask predictions from sliced image + to full sized image, should be in the form of [shift_x, shift_y] + full_shape: List + Size of the full image after shifting, should be in + the form of [height, width] + """ + if not isinstance(category_id, int): + raise ValueError("category_id must be an integer") + if (bbox is None) and (bool_mask is None): + raise ValueError("you must provide a bbox or bool_mask") + + if bool_mask is 
not None: + self.mask = Mask( + bool_mask=bool_mask, + shift_amount=shift_amount, + full_shape=full_shape, + ) + bbox_from_bool_mask = get_bbox_from_bool_mask(bool_mask) + # https://github.com/obss/sahi/issues/235 + if bbox_from_bool_mask is not None: + bbox = bbox_from_bool_mask + else: + raise ValueError("Invalid boolean mask.") + else: + self.mask = None + + # make sure bbox coords lie inside [0, image_size] + xmin = max(bbox[0], 0) + ymin = max(bbox[1], 0) + if full_shape: + xmax = min(bbox[2], full_shape[1]) + ymax = min(bbox[3], full_shape[0]) + else: + xmax = bbox[2] + ymax = bbox[3] + bbox = [xmin, ymin, xmax, ymax] + # set bbox + self.bbox = BoundingBox(bbox, shift_amount) + + category_name = category_name if category_name else str(category_id) + self.category = Category( + id=category_id, + name=category_name, + ) + + self.merged = None + + def to_coco_annotation(self): + """ + Returns sahi.utils.coco.CocoAnnotation representation of ObjectAnnotation. + """ + if self.mask: + coco_annotation = CocoAnnotation.from_coco_segmentation( + segmentation=self.mask.to_coco_segmentation(), + category_id=self.category.id, + category_name=self.category.name, + ) + else: + coco_annotation = CocoAnnotation.from_coco_bbox( + bbox=self.bbox.to_coco_bbox(), + category_id=self.category.id, + category_name=self.category.name, + ) + return coco_annotation + + def to_coco_prediction(self): + """ + Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation. + """ + if self.mask: + coco_prediction = CocoPrediction.from_coco_segmentation( + segmentation=self.mask.to_coco_segmentation(), + category_id=self.category.id, + category_name=self.category.name, + score=1, + ) + else: + coco_prediction = CocoPrediction.from_coco_bbox( + bbox=self.bbox.to_coco_bbox(), + category_id=self.category.id, + category_name=self.category.name, + score=1, + ) + return coco_prediction + + def to_shapely_annotation(self): + """ + Returns sahi.utils.shapely.ShapelyAnnotation representation of ObjectAnnotation. + """ + if self.mask: + shapely_annotation = ShapelyAnnotation.from_coco_segmentation( + segmentation=self.mask.to_coco_segmentation(), + ) + else: + shapely_annotation = ShapelyAnnotation.from_coco_bbox( + bbox=self.bbox.to_coco_bbox(), + ) + return shapely_annotation + + def to_imantics_annotation(self): + """ + Returns imantics.annotation.Annotation representation of ObjectAnnotation. + """ + try: + import imantics + except ImportError: + raise ImportError( + 'Please run "pip install -U imantics" ' "to install imantics first for imantics conversion." 
+ ) + + imantics_category = imantics.Category(id=self.category.id, name=self.category.name) + if self.mask is not None: + imantics_mask = imantics.Mask.create(self.mask.bool_mask) + imantics_annotation = imantics.annotation.Annotation.from_mask( + mask=imantics_mask, category=imantics_category + ) + else: + imantics_bbox = imantics.BBox.create(self.bbox.to_voc_bbox()) + imantics_annotation = imantics.annotation.Annotation.from_bbox( + bbox=imantics_bbox, category=imantics_category + ) + return imantics_annotation + + def deepcopy(self): + """ + Returns: deepcopy of current ObjectAnnotation instance + """ + return copy.deepcopy(self) + + @classmethod + def get_empty_mask(cls): + return Mask(bool_mask=None) + + def get_shifted_object_annotation(self): + if self.mask: + return ObjectAnnotation( + bbox=self.bbox.get_shifted_box().to_voc_bbox(), + category_id=self.category.id, + bool_mask=self.mask.get_shifted_mask().bool_mask, + category_name=self.category.name, + shift_amount=[0, 0], + full_shape=self.mask.get_shifted_mask().full_shape, + ) + else: + return ObjectAnnotation( + bbox=self.bbox.get_shifted_box().to_voc_bbox(), + category_id=self.category.id, + bool_mask=None, + category_name=self.category.name, + shift_amount=[0, 0], + full_shape=None, + ) + + def __repr__(self): + return f"""ObjectAnnotation< + bbox: {self.bbox}, + mask: {self.mask}, + category: {self.category}>""" diff --git a/sahi/auto_model.py b/sahi/auto_model.py new file mode 100644 index 0000000..edd6b27 --- /dev/null +++ b/sahi/auto_model.py @@ -0,0 +1,134 @@ +from typing import Dict, Optional + +from sahi.utils.file import import_model_class +from sahi.utils.import_utils import check_requirements + +MODEL_TYPE_TO_MODEL_CLASS_NAME = { + "mmdet": "MmdetDetectionModel", + "yolov5": "Yolov5DetectionModel", + "detectron2": "Detectron2DetectionModel", + "huggingface": "HuggingfaceDetectionModel", + "torchvision": "TorchVisionDetectionModel", +} + + +class AutoDetectionModel: + @staticmethod + def from_pretrained( + model_type: str, + model_path: str, + config_path: Optional[str] = None, + device: Optional[str] = None, + mask_threshold: float = 0.5, + confidence_threshold: float = 0.3, + category_mapping: Optional[Dict] = None, + category_remapping: Optional[Dict] = None, + load_at_init: bool = True, + image_size: int = None, + **kwargs, + ): + """ + Loads a DetectionModel from given path. + + Args: + model_type: str + Name of the detection framework (example: "yolov5", "mmdet", "detectron2") + model_path: str + Path of the Layer model (ex. '/sahi/yolo/models/yolov5') + config_path: str + Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py') + device: str + Device, "cpu" or "cuda:0" + mask_threshold: float + Value to threshold mask pixels, should be between 0 and 1 + confidence_threshold: float + All predictions with score < confidence_threshold will be discarded + category_mapping: dict: str to str + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + category_remapping: dict: str to int + Remap category ids based on category names, after performing inference e.g. {"car": 3} + load_at_init: bool + If True, automatically loads the model at initalization + image_size: int + Inference input size. 
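+        Example (a minimal sketch; the weight path and thresholds are illustrative):
+            detection_model = AutoDetectionModel.from_pretrained(
+                model_type="yolov5",
+                model_path="yolov5s.pt",
+                confidence_threshold=0.4,
+                device="cpu",  # or "cuda:0"
+            )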
+ Returns: + Returns an instance of a DetectionModel + Raises: + ImportError: If given {model_type} framework is not installed + """ + + model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type] + DetectionModel = import_model_class(model_class_name) + + return DetectionModel( + model_path=model_path, + config_path=config_path, + device=device, + mask_threshold=mask_threshold, + confidence_threshold=confidence_threshold, + category_mapping=category_mapping, + category_remapping=category_remapping, + load_at_init=load_at_init, + image_size=image_size, + **kwargs, + ) + + @staticmethod + @check_requirements(["layer"]) + def from_layer( + model_path: str, + no_cache: bool = False, + device: Optional[str] = None, + mask_threshold: float = 0.5, + confidence_threshold: float = 0.3, + category_mapping: Optional[Dict] = None, + category_remapping: Optional[Dict] = None, + image_size: int = None, + ): + """ + Loads a DetectionModel from Layer. You can pass additional parameters in the name to retrieve a specific version + of the model with format: ``model_path:major_version.minor_version`` + By default, this function caches models locally when possible. + Args: + model_path: str + Path of the Layer model (ex. '/sahi/yolo/models/yolov5') + no_cache: bool + If True, force model fetch from the remote location. + device: str + Device, "cpu" or "cuda:0" + mask_threshold: float + Value to threshold mask pixels, should be between 0 and 1 + confidence_threshold: float + All predictions with score < confidence_threshold will be discarded + category_mapping: dict: str to str + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + category_remapping: dict: str to int + Remap category ids based on category names, after performing inference e.g. {"car": 3} + image_size: int + Inference input size. + Returns: + Returns an instance of a DetectionModel + Raises: + ImportError: If Layer is not installed in your environment + ValueError: If model path does not match expected pattern: organization_name/project_name/models/model_name + """ + import layer + + layer_model = layer.get_model(name=model_path, no_cache=no_cache).get_train() + if layer_model.__class__.__module__ in ["yolov5.models.common", "models.common"]: + model_type = "yolov5" + else: + raise Exception(f"Unsupported model: {type(layer_model)}. 
Only YOLOv5 models are supported.") + + model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type] + DetectionModel = import_model_class(model_class_name) + + return DetectionModel( + model=layer_model, + device=device, + mask_threshold=mask_threshold, + confidence_threshold=confidence_threshold, + category_mapping=category_mapping, + category_remapping=category_remapping, + image_size=image_size, + ) diff --git a/sahi/cli.py b/sahi/cli.py new file mode 100644 index 0000000..ea711f2 --- /dev/null +++ b/sahi/cli.py @@ -0,0 +1,35 @@ +import fire + +from sahi import __version__ as sahi_version +from sahi.predict import predict, predict_fiftyone +from sahi.scripts.coco2fiftyone import main as coco2fiftyone +from sahi.scripts.coco2yolov5 import main as coco2yolov5 +from sahi.scripts.coco_error_analysis import analyse +from sahi.scripts.coco_evaluation import evaluate +from sahi.scripts.slice_coco import slice +from sahi.utils.import_utils import print_enviroment_info + +coco_app = { + "evaluate": evaluate, + "analyse": analyse, + "fiftyone": coco2fiftyone, + "slice": slice, + "yolov5": coco2yolov5, +} + +sahi_app = { + "predict": predict, + "predict-fiftyone": predict_fiftyone, + "coco": coco_app, + "version": sahi_version, + "env": print_enviroment_info, +} + + +def app() -> None: + """Cli app.""" + fire.Fire(sahi_app) + + +if __name__ == "__main__": + app() diff --git a/sahi/model.py b/sahi/model.py new file mode 100644 index 0000000..936644b --- /dev/null +++ b/sahi/model.py @@ -0,0 +1,1135 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. + +import logging +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from sahi.prediction import ObjectPrediction +from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list +from sahi.utils.cv import get_bbox_from_bool_mask +from sahi.utils.import_utils import check_requirements, is_available +from sahi.utils.torch import is_torch_cuda_available +import os +import sys +from pathlib import Path + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[0] + +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +if str(ROOT / 'YOLOv6') not in sys.path: + sys.path.append(str(ROOT / 'YOLOv6')) # add YOLOv6 ROOT to PATH + +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +from YOLOv6.yolov6.core.inferer import Inferer +import torch + +from YOLOv6.yolov6.layers.common import DetectBackend +from YOLOv6.yolov6.utils.nms import non_max_suppression +logger = logging.getLogger(__name__) +from yolo6 import precess_image, COCO_CLASSES, check_img_size + +class DetectionModel: + def __init__( + self, + model_path: Optional[str] = None, + model: Optional[Any] = None, + config_path: Optional[str] = None, + device: Optional[str] = None, + mask_threshold: float = 0.5, + confidence_threshold: float = 0.3, + category_mapping: Optional[Dict] = None, + category_remapping: Optional[Dict] = None, + load_at_init: bool = True, + image_size: int = None, + ): + """ + Init object detection/instance segmentation model. 
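+        Concrete subclasses (e.g. Yolov5DetectionModel, Yolov6DetectionModel) implement
+        load_model(), perform_inference() and
+        _create_object_prediction_list_from_original_predictions(). Typical usage: call
+        perform_inference(image), then convert_original_predictions(), then read
+        object_prediction_list.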
+ Args: + model_path: str + Path for the instance segmentation model weight + config_path: str + Path for the mmdetection instance segmentation model config file + device: str + Torch device, "cpu" or "cuda" + mask_threshold: float + Value to threshold mask pixels, should be between 0 and 1 + confidence_threshold: float + All predictions with score < confidence_threshold will be discarded + category_mapping: dict: str to str + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + category_remapping: dict: str to int + Remap category ids based on category names, after performing inference e.g. {"car": 3} + load_at_init: bool + If True, automatically loads the model at initalization + image_size: int + Inference input size. + """ + self.model_path = model_path + self.config_path = config_path + self.model = None + self.device = device + self.mask_threshold = mask_threshold + self.confidence_threshold = confidence_threshold + self.category_mapping = category_mapping + self.category_remapping = category_remapping + self.image_size = image_size + self._original_predictions = None + self._object_prediction_list_per_image = None + self.image_size = check_img_size(img_size=self.image_size) + + + # automatically set device if its None + if not (self.device): + self.device = "cuda:0" if is_torch_cuda_available() else "cpu" + + # automatically load model if load_at_init is True + if load_at_init: + if model: + self.set_model(model) + else: + self.load_model() + + def load_model(self): + """ + This function should be implemented in a way that detection model + should be initialized and set to self.model. + (self.model_path, self.config_path, and self.device should be utilized) + """ + raise NotImplementedError() + + def set_model(self, model: Any, **kwargs): + """ + This function should be implemented to instantiate a DetectionModel out of an already loaded model + Args: + model: Any + Loaded model + """ + raise NotImplementedError() + + def unload_model(self): + """ + Unloads the model from CPU/GPU. + """ + self.model = None + if is_available("torch"): + from sahi.utils.torch import empty_cuda_cache + + empty_cuda_cache() + + def perform_inference(self, image: np.ndarray): + """ + This function should be implemented in a way that prediction should be + performed using self.model and the prediction result should be set to self._original_predictions. + Args: + image: np.ndarray + A numpy array that contains the image to be predicted. + """ + raise NotImplementedError() + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + """ + This function should be implemented in a way that self._original_predictions should + be converted to a list of prediction.ObjectPrediction and set to + self._object_prediction_list. self.mask_threshold can also be utilized. + Args: + shift_amount_list: list of list + To shift the box and mask predictions from sliced image to full sized image, should + be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...] + full_shape_list: list of list + Size of the full image after shifting, should be in the form of + List[[height, width],[height, width],...] 
+ """ + raise NotImplementedError() + + def _apply_category_remapping(self): + """ + Applies category remapping based on mapping given in self.category_remapping + """ + # confirm self.category_remapping is not None + if self.category_remapping is None: + raise ValueError("self.category_remapping cannot be None") + # remap categories + for object_prediction_list in self._object_prediction_list_per_image: + for object_prediction in object_prediction_list: + old_category_id_str = str(object_prediction.category.id) + new_category_id_int = self.category_remapping[old_category_id_str] + object_prediction.category.id = new_category_id_int + + def convert_original_predictions( + self, + shift_amount: Optional[List[int]] = [0, 0], + full_shape: Optional[List[int]] = None, + ): + """ + Converts original predictions of the detection model to a list of + prediction.ObjectPrediction object. Should be called after perform_inference(). + Args: + shift_amount: list + To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y] + full_shape: list + Size of the full image after shifting, should be in the form of [height, width] + """ + self._create_object_prediction_list_from_original_predictions( + shift_amount_list=shift_amount, + full_shape_list=full_shape, + ) + if self.category_remapping: + self._apply_category_remapping() + + @property + def object_prediction_list(self): + return self._object_prediction_list_per_image[0] + + @property + def object_prediction_list_per_image(self): + return self._object_prediction_list_per_image + + @property + def original_predictions(self): + return self._original_predictions + + +@check_requirements(["torch", "mmdet", "mmcv"]) +class MmdetDetectionModel(DetectionModel): + def load_model(self): + """ + Detection model is initialized and set to self.model. + """ + try: + import mmdet + except ImportError: + raise ImportError( + 'Please run "pip install -U mmcv mmdet" ' "to install MMDetection first for MMDetection inference." + ) + + from mmdet.apis import init_detector + + # create model + model = init_detector( + config=self.config_path, + checkpoint=self.model_path, + device=self.device, + ) + + # update model image size + if self.image_size is not None: + model.cfg.data.test.pipeline[1]["img_scale"] = (self.image_size, self.image_size) + + # set self.model + self.model = model + + # set category_mapping + if not self.category_mapping: + category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)} + self.category_mapping = category_mapping + + def perform_inference(self, image: np.ndarray): + """ + Prediction is performed using self.model and the prediction result is set to self._original_predictions. + Args: + image: np.ndarray + A numpy array that contains the image to be predicted. 3 channel image should be in RGB order. + """ + try: + import mmdet + except ImportError: + raise ImportError( + 'Please run "pip install -U mmcv mmdet" ' "to install MMDetection first for MMDetection inference." 
+ ) + + # Confirm model is loaded + if self.model is None: + raise ValueError("Model is not loaded, load it by calling .load_model()") + # Supports only batch of 1 + from mmdet.apis import inference_detector + + # perform inference + if isinstance(image, np.ndarray): + # https://github.com/obss/sahi/issues/265 + image = image[:, :, ::-1] + # compatibility with sahi v0.8.15 + if not isinstance(image, list): + image = [image] + prediction_result = inference_detector(self.model, image) + + self._original_predictions = prediction_result + + @property + def num_categories(self): + """ + Returns number of categories + """ + if isinstance(self.model.CLASSES, str): + num_categories = 1 + else: + num_categories = len(self.model.CLASSES) + return num_categories + + @property + def has_mask(self): + """ + Returns if model output contains segmentation mask + """ + has_mask = self.model.with_mask + return has_mask + + @property + def category_names(self): + if type(self.model.CLASSES) == str: + # https://github.com/open-mmlab/mmdetection/pull/4973 + return (self.model.CLASSES,) + else: + return self.model.CLASSES + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + """ + self._original_predictions is converted to a list of prediction.ObjectPrediction and set to + self._object_prediction_list_per_image. + Args: + shift_amount_list: list of list + To shift the box and mask predictions from sliced image to full sized image, should + be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...] + full_shape_list: list of list + Size of the full image after shifting, should be in the form of + List[[height, width],[height, width],...] 
+ """ + original_predictions = self._original_predictions + category_mapping = self.category_mapping + + # compatilibty for sahi v0.8.15 + shift_amount_list = fix_shift_amount_list(shift_amount_list) + full_shape_list = fix_full_shape_list(full_shape_list) + + # parse boxes and masks from predictions + num_categories = self.num_categories + object_prediction_list_per_image = [] + for image_ind, original_prediction in enumerate(original_predictions): + shift_amount = shift_amount_list[image_ind] + full_shape = None if full_shape_list is None else full_shape_list[image_ind] + + if self.has_mask: + boxes = original_prediction[0] + masks = original_prediction[1] + else: + boxes = original_prediction + + object_prediction_list = [] + + # process predictions + for category_id in range(num_categories): + category_boxes = boxes[category_id] + if self.has_mask: + category_masks = masks[category_id] + num_category_predictions = len(category_boxes) + + for category_predictions_ind in range(num_category_predictions): + bbox = category_boxes[category_predictions_ind][:4] + score = category_boxes[category_predictions_ind][4] + category_name = category_mapping[str(category_id)] + + # ignore low scored predictions + if score < self.confidence_threshold: + continue + + # parse prediction mask + if self.has_mask: + bool_mask = category_masks[category_predictions_ind] + # check if mask is valid + # https://github.com/obss/sahi/issues/389 + if get_bbox_from_bool_mask(bool_mask) is None: + continue + else: + bool_mask = None + + # fix negative box coords + bbox[0] = max(0, bbox[0]) + bbox[1] = max(0, bbox[1]) + bbox[2] = max(0, bbox[2]) + bbox[3] = max(0, bbox[3]) + + # fix out of image box coords + if full_shape is not None: + bbox[0] = min(full_shape[1], bbox[0]) + bbox[1] = min(full_shape[0], bbox[1]) + bbox[2] = min(full_shape[1], bbox[2]) + bbox[3] = min(full_shape[0], bbox[3]) + + # ignore invalid predictions + if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]): + logger.warning(f"ignoring invalid prediction with bbox: {bbox}") + continue + + object_prediction = ObjectPrediction( + bbox=bbox, + category_id=category_id, + score=score, + bool_mask=bool_mask, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + object_prediction_list.append(object_prediction) + object_prediction_list_per_image.append(object_prediction_list) + self._object_prediction_list_per_image = object_prediction_list_per_image + + +@check_requirements(["torch", "yolov5"]) +class Yolov5DetectionModel(DetectionModel): + def load_model(self): + """ + Detection model is initialized and set to self.model. + """ + import yolov5 + + try: + model = yolov5.load(self.model_path, device=self.device) + self.set_model(model) + except Exception as e: + raise TypeError("model_path is not a valid yolov5 model path: ", e) + + def set_model(self, model: Any): + """ + Sets the underlying YOLOv5 model. + Args: + model: Any + A YOLOv5 model + """ + + if model.__class__.__module__ not in ["yolov5.models.common", "models.common"]: + raise Exception(f"Not a yolov5 model: {type(model)}") + + model.conf = self.confidence_threshold + self.model = model + + # set category_mapping + if not self.category_mapping: + category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)} + self.category_mapping = category_mapping + + def perform_inference(self, image: np.ndarray): + """ + Prediction is performed using self.model and the prediction result is set to self._original_predictions. 
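+        If self.image_size is set, it is forwarded to the model as the inference size;
+        otherwise the model's default input size is used.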
+ Args: + image: np.ndarray + A numpy array that contains the image to be predicted. 3 channel image should be in RGB order. + """ + + # Confirm model is loaded + if self.model is None: + raise ValueError("Model is not loaded, load it by calling .load_model()") + if self.image_size is not None: + prediction_result = self.model(image, size=self.image_size) + else: + prediction_result = self.model(image) + + self._original_predictions = prediction_result + + @property + def num_categories(self): + """ + Returns number of categories + """ + return len(self.model.names) + + @property + def has_mask(self): + """ + Returns if model output contains segmentation mask + """ + has_mask = self.model.with_mask + return has_mask + + @property + def category_names(self): + return self.model.names + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + """ + self._original_predictions is converted to a list of prediction.ObjectPrediction and set to + self._object_prediction_list_per_image. + Args: + shift_amount_list: list of list + To shift the box and mask predictions from sliced image to full sized image, should + be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...] + full_shape_list: list of list + Size of the full image after shifting, should be in the form of + List[[height, width],[height, width],...] + """ + original_predictions = self._original_predictions + + # compatilibty for sahi v0.8.15 + shift_amount_list = fix_shift_amount_list(shift_amount_list) + full_shape_list = fix_full_shape_list(full_shape_list) + + # handle all predictions + object_prediction_list_per_image = [] + for image_ind, image_predictions_in_xyxy_format in enumerate(original_predictions.xyxy): + shift_amount = shift_amount_list[image_ind] + full_shape = None if full_shape_list is None else full_shape_list[image_ind] + object_prediction_list = [] + + # process predictions + for prediction in image_predictions_in_xyxy_format.cpu().detach().numpy(): + x1 = int(prediction[0]) + y1 = int(prediction[1]) + x2 = int(prediction[2]) + y2 = int(prediction[3]) + bbox = [x1, y1, x2, y2] + score = prediction[4] + category_id = int(prediction[5]) + category_name = self.category_mapping[str(category_id)] + + # fix negative box coords + bbox[0] = max(0, bbox[0]) + bbox[1] = max(0, bbox[1]) + bbox[2] = max(0, bbox[2]) + bbox[3] = max(0, bbox[3]) + + # fix out of image box coords + if full_shape is not None: + bbox[0] = min(full_shape[1], bbox[0]) + bbox[1] = min(full_shape[0], bbox[1]) + bbox[2] = min(full_shape[1], bbox[2]) + bbox[3] = min(full_shape[0], bbox[3]) + + # ignore invalid predictions + if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]): + logger.warning(f"ignoring invalid prediction with bbox: {bbox}") + continue + + object_prediction = ObjectPrediction( + bbox=bbox, + category_id=category_id, + score=score, + bool_mask=None, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + object_prediction_list.append(object_prediction) + object_prediction_list_per_image.append(object_prediction_list) + + self._object_prediction_list_per_image = object_prediction_list_per_image + + +@check_requirements(["torch", "detectron2"]) +class Detectron2DetectionModel(DetectionModel): + def load_model(self): + from detectron2.config import get_cfg + from detectron2.data import MetadataCatalog + from detectron2.engine import DefaultPredictor + from 
detectron2.model_zoo import model_zoo + + cfg = get_cfg() + + try: # try to load from model zoo + config_file = model_zoo.get_config_file(self.config_path) + cfg.merge_from_file(config_file) + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.config_path) + except Exception as e: # try to load from local + print(e) + if self.config_path is not None: + cfg.merge_from_file(self.config_path) + cfg.MODEL.WEIGHTS = self.model_path + + # set model device + cfg.MODEL.DEVICE = self.device + # set input image size + if self.image_size is not None: + cfg.INPUT.MIN_SIZE_TEST = self.image_size + cfg.INPUT.MAX_SIZE_TEST = self.image_size + # init predictor + model = DefaultPredictor(cfg) + + self.model = model + + # detectron2 category mapping + if self.category_mapping is None: + try: # try to parse category names from metadata + metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]) + category_names = metadata.thing_classes + self.category_names = category_names + self.category_mapping = { + str(ind): category_name for ind, category_name in enumerate(self.category_names) + } + except Exception as e: + logger.warning(e) + # https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html#update-the-config-for-new-datasets + if cfg.MODEL.META_ARCHITECTURE == "RetinaNet": + num_categories = cfg.MODEL.RETINANET.NUM_CLASSES + else: # fasterrcnn/maskrcnn etc + num_categories = cfg.MODEL.ROI_HEADS.NUM_CLASSES + self.category_names = [str(category_id) for category_id in range(num_categories)] + self.category_mapping = { + str(ind): category_name for ind, category_name in enumerate(self.category_names) + } + else: + self.category_names = list(self.category_mapping.values()) + + def perform_inference(self, image: np.ndarray): + """ + Prediction is performed using self.model and the prediction result is set to self._original_predictions. + Args: + image: np.ndarray + A numpy array that contains the image to be predicted. 3 channel image should be in RGB order. + """ + + # Confirm model is loaded + if self.model is None: + raise RuntimeError("Model is not loaded, load it by calling .load_model()") + + if isinstance(image, np.ndarray) and self.model.input_format == "BGR": + # convert RGB image to BGR format + image = image[:, :, ::-1] + + prediction_result = self.model(image) + + self._original_predictions = prediction_result + + @property + def num_categories(self): + """ + Returns number of categories + """ + num_categories = len(self.category_mapping) + return num_categories + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + """ + self._original_predictions is converted to a list of prediction.ObjectPrediction and set to + self._object_prediction_list_per_image. + Args: + shift_amount_list: list of list + To shift the box and mask predictions from sliced image to full sized image, should + be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...] + full_shape_list: list of list + Size of the full image after shifting, should be in the form of + List[[height, width],[height, width],...] 
+ """ + original_predictions = self._original_predictions + + # compatilibty for sahi v0.8.15 + if isinstance(shift_amount_list[0], int): + shift_amount_list = [shift_amount_list] + if full_shape_list is not None and isinstance(full_shape_list[0], int): + full_shape_list = [full_shape_list] + + # parse boxes, masks, scores, category_ids from predictions + boxes = original_predictions["instances"].pred_boxes.tensor.tolist() + scores = original_predictions["instances"].scores.tolist() + category_ids = original_predictions["instances"].pred_classes.tolist() + + # check if predictions contain mask + try: + masks = original_predictions["instances"].pred_masks.tolist() + except AttributeError: + masks = None + + # create object_prediction_list + object_prediction_list_per_image = [] + object_prediction_list = [] + + # detectron2 DefaultPredictor supports single image + shift_amount = shift_amount_list[0] + full_shape = None if full_shape_list is None else full_shape_list[0] + + for ind in range(len(boxes)): + score = scores[ind] + if score < self.confidence_threshold: + continue + + category_id = category_ids[ind] + + if masks is None: + bbox = boxes[ind] + mask = None + else: + mask = np.array(masks[ind]) + + # check if mask is valid + # https://github.com/obss/sahi/issues/389 + if get_bbox_from_bool_mask(mask) is None: + continue + else: + bbox = None + + object_prediction = ObjectPrediction( + bbox=bbox, + bool_mask=mask, + category_id=category_id, + category_name=self.category_mapping[str(category_id)], + shift_amount=shift_amount, + score=score, + full_shape=full_shape, + ) + object_prediction_list.append(object_prediction) + + # detectron2 DefaultPredictor supports single image + object_prediction_list_per_image = [object_prediction_list] + + self._object_prediction_list_per_image = object_prediction_list_per_image + + +@check_requirements(["torch", "transformers"]) +class HuggingfaceDetectionModel(DetectionModel): + import torch + + def __init__( + self, + model_path: Optional[str] = None, + model: Optional[Any] = None, + feature_extractor: Optional[Any] = None, + config_path: Optional[str] = None, + device: Optional[str] = None, + mask_threshold: float = 0.5, + confidence_threshold: float = 0.3, + category_mapping: Optional[Dict] = None, + category_remapping: Optional[Dict] = None, + load_at_init: bool = True, + image_size: int = None, + ): + self._feature_extractor = feature_extractor + self._image_shapes = [] + super().__init__( + model_path, + model, + config_path, + device, + mask_threshold, + confidence_threshold, + category_mapping, + category_remapping, + load_at_init, + image_size, + ) + + @property + def feature_extractor(self): + return self._feature_extractor + + @property + def image_shapes(self): + return self._image_shapes + + @property + def num_categories(self) -> int: + """ + Returns number of categories + """ + return self.model.config.num_labels + + def load_model(self): + from transformers import AutoFeatureExtractor, AutoModelForObjectDetection + + model = AutoModelForObjectDetection.from_pretrained(self.model_path) + if self.image_size is not None: + feature_extractor = AutoFeatureExtractor.from_pretrained( + self.model_path, size=self.image_size, do_resize=True + ) + else: + feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path) + self.set_model(model, feature_extractor) + + def set_model(self, model: Any, feature_extractor: Any = None): + feature_extractor = feature_extractor or self.feature_extractor + if feature_extractor is None: + raise 
ValueError(f"'feature_extractor' is required to be set, got {feature_extractor}.") + elif ( + "ObjectDetection" not in model.__class__.__name__ + or "FeatureExtractor" not in feature_extractor.__class__.__name__ + ): + raise ValueError( + f"Given 'model' is not an ObjectDetectionModel or 'feature_extractor' is not a valid FeatureExtractor." + ) + self.model = model + self.model.to(self.device) + self._feature_extractor = feature_extractor + self.category_mapping = self.model.config.id2label + + def perform_inference(self, image: Union[List, np.ndarray]): + """ + Prediction is performed using self.model and the prediction result is set to self._original_predictions. + Args: + image: np.ndarray + A numpy array that contains the image to be predicted. 3 channel image should be in RGB order. + """ + import torch + + # Confirm model is loaded + if self.model is None: + raise RuntimeError("Model is not loaded, load it by calling .load_model()") + + with torch.no_grad(): + inputs = self.feature_extractor(images=image, return_tensors="pt") + inputs["pixel_values"] = inputs.pixel_values.to(self.device) + if hasattr(inputs, "pixel_mask"): + inputs["pixel_mask"] = inputs.pixel_mask.to(self.device) + outputs = self.model(**inputs) + + if isinstance(image, list): + self._image_shapes = [img.shape for img in image] + else: + self._image_shapes = [image.shape] + self._original_predictions = outputs + + def get_valid_predictions( + self, logits: torch.Tensor, pred_boxes: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + import torch + + probs = logits.softmax(-1) + scores = probs.max(-1).values + cat_ids = probs.argmax(-1) + valid_detections = torch.where(cat_ids < self.num_categories, 1, 0) + valid_confidences = torch.where(scores >= self.confidence_threshold, 1, 0) + valid_mask = valid_detections.logical_and(valid_confidences) + scores = scores[valid_mask] + cat_ids = cat_ids[valid_mask] + boxes = pred_boxes[valid_mask] + return scores, cat_ids, boxes + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + """ + self._original_predictions is converted to a list of prediction.ObjectPrediction and set to + self._object_prediction_list_per_image. + Args: + shift_amount_list: list of list + To shift the box and mask predictions from sliced image to full sized image, should + be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...] + full_shape_list: list of list + Size of the full image after shifting, should be in the form of + List[[height, width],[height, width],...] 
+ """ + original_predictions = self._original_predictions + + # compatilibty for sahi v0.8.15 + shift_amount_list = fix_shift_amount_list(shift_amount_list) + full_shape_list = fix_full_shape_list(full_shape_list) + + n_image = original_predictions.logits.shape[0] + object_prediction_list_per_image = [] + for image_ind in range(n_image): + image_height, image_width, _ = self.image_shapes[image_ind] + scores, cat_ids, boxes = self.get_valid_predictions( + logits=original_predictions.logits[image_ind], pred_boxes=original_predictions.pred_boxes[image_ind] + ) + + # create object_prediction_list + object_prediction_list = [] + + shift_amount = shift_amount_list[image_ind] + full_shape = None if full_shape_list is None else full_shape_list[image_ind] + + for ind in range(len(boxes)): + category_id = cat_ids[ind].item() + yolo_bbox = boxes[ind].tolist() + bbox = list( + pbf.convert_bbox( + yolo_bbox, + from_type="yolo", + to_type="voc", + image_size=(image_width, image_height), + return_values=True, + ) + ) + + # fix negative box coords + bbox[0] = max(0, int(bbox[0])) + bbox[1] = max(0, int(bbox[1])) + bbox[2] = min(bbox[2], image_width) + bbox[3] = min(bbox[3], image_height) + + object_prediction = ObjectPrediction( + bbox=bbox, + bool_mask=None, + category_id=category_id, + category_name=self.category_mapping[category_id], + shift_amount=shift_amount, + score=scores[ind].item(), + full_shape=full_shape, + ) + object_prediction_list.append(object_prediction) + object_prediction_list_per_image.append(object_prediction_list) + + self._object_prediction_list_per_image = object_prediction_list_per_image + + +@check_requirements(["torch", "torchvision"]) +class TorchVisionDetectionModel(DetectionModel): + def __init__( + self, + model_path: Optional[str] = None, + model: Optional[Any] = None, + config_path: Optional[str] = None, + device: Optional[str] = None, + mask_threshold: float = 0.5, + confidence_threshold: float = 0.3, + category_mapping: Optional[Dict] = None, + category_remapping: Optional[Dict] = None, + load_at_init: bool = True, + image_size: int = None, + ): + + super().__init__( + model_path=model_path, + model=model, + config_path=config_path, + device=device, + mask_threshold=mask_threshold, + confidence_threshold=confidence_threshold, + category_mapping=category_mapping, + category_remapping=category_remapping, + load_at_init=load_at_init, + image_size=image_size, + ) + + def load_model(self): + import torch + + from sahi.utils.torchvision import MODEL_NAME_TO_CONSTRUCTOR + + # read config params + model_name = None + num_classes = None + if self.config_path is not None: + import yaml + + with open(self.config_path, "r") as stream: + try: + config = yaml.safe_load(stream) + except yaml.YAMLError as exc: + raise RuntimeError(exc) + + model_name = config.get("model_name", None) + num_classes = config.get("num_classes", None) + + # complete params if not provided in config + if not model_name: + model_name = "fasterrcnn_resnet50_fpn" + logger.warning(f"model_name not provided in config, using default model_type: {model_name}'") + if num_classes is None: + logger.warning("num_classes not provided in config, using default num_classes: 91") + num_classes = 91 + if self.model_path is None: + logger.warning("model_path not provided in config, using pretrained weights and default num_classes: 91.") + pretrained = True + num_classes = 91 + else: + pretrained = False + + # load model + model = MODEL_NAME_TO_CONSTRUCTOR[model_name](num_classes=num_classes, pretrained=pretrained) + try: + 
model.load_state_dict(torch.load(self.model_path)) + except Exception as e: + TypeError("model_path is not a valid torchvision model path: ", e) + + self.set_model(model) + + def set_model(self, model: Any): + """ + Sets the underlying TorchVision model. + Args: + model: Any + A TorchVision model + """ + + model.eval() + self.model = model.to(self.device) + + # set category_mapping + from sahi.utils.torchvision import COCO_CLASSES + + if self.category_mapping is None: + category_names = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))} + self.category_mapping = category_names + + def perform_inference(self, image: np.ndarray, image_size: int = None): + """ + Prediction is performed using self.model and the prediction result is set to self._original_predictions. + Args: + image: np.ndarray + A numpy array that contains the image to be predicted. 3 channel image should be in RGB order. + image_size: int + Inference input size. + """ + from sahi.utils.torch import to_float_tensor + + # arrange model input size + if self.image_size is not None: + # get min and max of image height and width + min_shape, max_shape = min(image.shape[:2]), max(image.shape[:2]) + # torchvision resize transform scales the shorter dimension to the target size + # we want to scale the longer dimension to the target size + image_size = self.image_size * min_shape / max_shape + self.model.transform.min_size = (image_size,) # default is (800,) + self.model.transform.max_size = image_size # default is 1333 + + image = to_float_tensor(image) + image = image.to(self.device) + prediction_result = self.model([image]) + + self._original_predictions = prediction_result + + @property + def num_categories(self): + """ + Returns number of categories + """ + return len(self.category_mapping) + + @property + def has_mask(self): + """ + Returns if model output contains segmentation mask + """ + return self.model.with_mask + + @property + def category_names(self): + return list(self.category_mapping.values()) + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + """ + self._original_predictions is converted to a list of prediction.ObjectPrediction and set to + self._object_prediction_list_per_image. + Args: + shift_amount_list: list of list + To shift the box and mask predictions from sliced image to full sized image, should + be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...] + full_shape_list: list of list + Size of the full image after shifting, should be in the form of + List[[height, width],[height, width],...] 
+ """ + original_predictions = self._original_predictions + + # compatilibty for sahi v0.8.20 + if isinstance(shift_amount_list[0], int): + shift_amount_list = [shift_amount_list] + if full_shape_list is not None and isinstance(full_shape_list[0], int): + full_shape_list = [full_shape_list] + + for image_predictions in original_predictions: + object_prediction_list_per_image = [] + + # get indices of boxes with score > confidence_threshold + scores = image_predictions["scores"].cpu().detach().numpy() + selected_indices = np.where(scores > self.confidence_threshold)[0] + + # parse boxes, masks, scores, category_ids from predictions + category_ids = list(image_predictions["labels"][selected_indices].cpu().detach().numpy()) + boxes = list(image_predictions["boxes"][selected_indices].cpu().detach().numpy()) + scores = scores[selected_indices] + + # check if predictions contain mask + masks = image_predictions.get("masks", None) + if masks is not None: + masks = list(image_predictions["masks"][selected_indices].cpu().detach().numpy()) + else: + masks = None + + # create object_prediction_list + object_prediction_list = [] + + shift_amount = shift_amount_list[0] + full_shape = None if full_shape_list is None else full_shape_list[0] + + for ind in range(len(boxes)): + + if masks is not None: + mask = np.array(masks[ind]) + else: + mask = None + + object_prediction = ObjectPrediction( + bbox=boxes[ind], + bool_mask=mask, + category_id=int(category_ids[ind]), + category_name=self.category_mapping[str(int(category_ids[ind]))], + shift_amount=shift_amount, + score=scores[ind], + full_shape=full_shape, + ) + object_prediction_list.append(object_prediction) + object_prediction_list_per_image.append(object_prediction_list) + + self._object_prediction_list_per_image = object_prediction_list_per_image + + +class Yolov6DetectionModel(DetectionModel): + def load_model(self): + self.model = DetectBackend(self.model_path, self.device) + self.model.model.float() + self.stride = self.model.stride + + if self.device != 'cpu': + self.model(torch.zeros(1, 3, *self.image_size).to(self.device).type_as(next(self.model.model.parameters()))) # warmup + + def perform_inference(self, image: np.ndarray): + img, self.img_shape, self.src_shape = precess_image(image, img_size=self.image_size, stride=self.stride) + self._original_predictions = self.model(img) + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + if isinstance(shift_amount_list[0], int): + shift_amount_list = [shift_amount_list] + if full_shape_list is not None and isinstance(full_shape_list[0], int): + full_shape_list = [full_shape_list] + + shift_amount = shift_amount_list[0] + full_shape = None if full_shape_list is None else full_shape_list[0] + + object_prediction_list_per_image = [] + object_prediction_list = [] + det = non_max_suppression(self._original_predictions, conf_thres=self.confidence_threshold, max_det=1000)[0] + if len(det): + det[:, :4] = Inferer.rescale(self.img_shape, det[:, :4], self.src_shape).round() + for *xyxy, conf, cls in reversed(det): + category_id = int(cls) + category_name = COCO_CLASSES[category_id] + score = float(conf.numpy()) + bbox = [int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])] + object_prediction = ObjectPrediction( + bbox=bbox, + category_id=category_id, + score=score, + bool_mask=None, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) 
+ object_prediction_list.append(object_prediction) + object_prediction_list_per_image.append(object_prediction_list) + self._object_prediction_list_per_image = object_prediction_list_per_image + \ No newline at end of file diff --git a/sahi/postprocess/__init__.py b/sahi/postprocess/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sahi/postprocess/combine.py b/sahi/postprocess/combine.py new file mode 100644 index 0000000..a59c067 --- /dev/null +++ b/sahi/postprocess/combine.py @@ -0,0 +1,604 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2021. + +import logging +from typing import List + +import torch + +from sahi.postprocess.utils import ObjectPredictionList, has_match, merge_object_prediction_pair +from sahi.prediction import ObjectPrediction +from sahi.utils.import_utils import check_requirements + +logger = logging.getLogger(__name__) + + +@check_requirements(["torch"]) +def batched_nms(predictions: torch.tensor, match_metric: str = "IOU", match_threshold: float = 0.5): + """ + Apply non-maximum suppression to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + predictions: (tensor) The location preds for the image + along with the class predscores, Shape: [num_boxes,5]. + match_metric: (str) IOU or IOS + match_threshold: (float) The overlap thresh for + match metric. + Returns: + A list of filtered indexes, Shape: [ ,] + """ + scores = predictions[:, 4].squeeze() + category_ids = predictions[:, 5].squeeze() + keep_mask = torch.zeros_like(category_ids, dtype=torch.bool) + for category_id in torch.unique(category_ids): + curr_indices = torch.where(category_ids == category_id)[0] + curr_keep_indices = nms(predictions[curr_indices], match_metric, match_threshold) + keep_mask[curr_indices[curr_keep_indices]] = True + keep_indices = torch.where(keep_mask)[0] + # sort selected indices by their scores + keep_indices = keep_indices[scores[keep_indices].sort(descending=True)[1]].tolist() + return keep_indices + + +@check_requirements(["torch"]) +def nms( + predictions: torch.tensor, + match_metric: str = "IOU", + match_threshold: float = 0.5, +): + """ + Apply non-maximum suppression to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + predictions: (tensor) The location preds for the image + along with the class predscores, Shape: [num_boxes,5]. + match_metric: (str) IOU or IOS + match_threshold: (float) The overlap thresh for + match metric. 
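[Editor's note] The `predictions` tensor consumed by `nms` / `batched_nms` packs one detection per row as `[x1, y1, x2, y2, score, category_id]`. A usage sketch, assuming this patch is installed so `sahi.postprocess.combine` is importable; the box values are invented:

```python
import torch
from sahi.postprocess.combine import batched_nms, nms

predictions = torch.tensor(
    [
        [10.0, 10.0, 110.0, 110.0, 0.90, 0.0],    # kept: highest score of its overlapping pair
        [12.0, 12.0, 112.0, 112.0, 0.60, 0.0],    # suppressed: IOU with the row above > 0.5
        [300.0, 300.0, 360.0, 380.0, 0.75, 1.0],  # kept: overlaps nothing
    ]
)

keep = nms(predictions, match_metric="IOU", match_threshold=0.5)                       # class-agnostic
keep_per_category = batched_nms(predictions, match_metric="IOU", match_threshold=0.5)  # grouped by category id
print(keep, keep_per_category)  # both keep indices 0 and 2 for this input
```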
+ Returns: + A list of filtered indexes, Shape: [ ,] + """ + # we extract coordinates for every + # prediction box present in P + x1 = predictions[:, 0] + y1 = predictions[:, 1] + x2 = predictions[:, 2] + y2 = predictions[:, 3] + + # we extract the confidence scores as well + scores = predictions[:, 4] + + # calculate area of every block in P + areas = (x2 - x1) * (y2 - y1) + + # sort the prediction boxes in P + # according to their confidence scores + order = scores.argsort() + + # initialise an empty list for + # filtered prediction boxes + keep = [] + + while len(order) > 0: + # extract the index of the + # prediction with highest score + # we call this prediction S + idx = order[-1] + + # push S in filtered predictions list + keep.append(idx.tolist()) + + # remove S from P + order = order[:-1] + + # sanity check + if len(order) == 0: + break + + # select coordinates of BBoxes according to + # the indices in order + xx1 = torch.index_select(x1, dim=0, index=order) + xx2 = torch.index_select(x2, dim=0, index=order) + yy1 = torch.index_select(y1, dim=0, index=order) + yy2 = torch.index_select(y2, dim=0, index=order) + + # find the coordinates of the intersection boxes + xx1 = torch.max(xx1, x1[idx]) + yy1 = torch.max(yy1, y1[idx]) + xx2 = torch.min(xx2, x2[idx]) + yy2 = torch.min(yy2, y2[idx]) + + # find height and width of the intersection boxes + w = xx2 - xx1 + h = yy2 - yy1 + + # take max with 0.0 to avoid negative w and h + # due to non-overlapping boxes + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + + # find the intersection area + inter = w * h + + # find the areas of BBoxes according the indices in order + rem_areas = torch.index_select(areas, dim=0, index=order) + + if match_metric == "IOU": + # find the union of every prediction T in P + # with the prediction S + # Note that areas[idx] represents area of S + union = (rem_areas - inter) + areas[idx] + # find the IoU of every prediction in P with S + match_metric_value = inter / union + + elif match_metric == "IOS": + # find the smaller area of every prediction T in P + # with the prediction S + # Note that areas[idx] represents area of S + smaller = torch.min(rem_areas, areas[idx]) + # find the IoU of every prediction in P with S + match_metric_value = inter / smaller + else: + raise ValueError() + + # keep the boxes with IoU less than thresh_iou + mask = match_metric_value < match_threshold + order = order[mask] + return keep + + +@check_requirements(["torch"]) +def batched_greedy_nmm( + object_predictions_as_tensor: torch.tensor, + match_metric: str = "IOU", + match_threshold: float = 0.5, +): + """ + Apply greedy version of non-maximum merging per category to avoid detecting + too many overlapping bounding boxes for a given object. + Args: + object_predictions_as_tensor: (tensor) The location preds for the image + along with the class predscores, Shape: [num_boxes,5]. + match_metric: (str) IOU or IOS + match_threshold: (float) The overlap thresh for + match metric. + Returns: + keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices + to keep to a list of prediction indices to be merged. 
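[Editor's note] Since both `'IOU'` and `'IOS'` appear as match metrics throughout this module, a quick numeric comparison may help; this is an editor's illustration, not code from the patch:

```python
# IOU divides the intersection by the union of the two boxes, IOS by the smaller of
# the two areas, so a small box fully nested inside a large one matches strongly
# under IOS while barely registering under IOU.
def area(b):  # b = [x1, y1, x2, y2]
    return (b[2] - b[0]) * (b[3] - b[1])

def intersection(b1, b2):
    w = max(0, min(b1[2], b2[2]) - max(b1[0], b2[0]))
    h = max(0, min(b1[3], b2[3]) - max(b1[1], b2[1]))
    return w * h

big, small = [0, 0, 100, 100], [10, 10, 40, 40]
inter = intersection(big, small)
print(inter / (area(big) + area(small) - inter))  # IOU = 900 / 10000 = 0.09
print(inter / min(area(big), area(small)))        # IOS = 900 / 900   = 1.0
```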
+ """ + category_ids = object_predictions_as_tensor[:, 5].squeeze() + keep_to_merge_list = {} + for category_id in torch.unique(category_ids): + curr_indices = torch.where(category_ids == category_id)[0] + curr_keep_to_merge_list = greedy_nmm(object_predictions_as_tensor[curr_indices], match_metric, match_threshold) + curr_indices_list = curr_indices.tolist() + for curr_keep, curr_merge_list in curr_keep_to_merge_list.items(): + keep = curr_indices_list[curr_keep] + merge_list = [curr_indices_list[curr_merge_ind] for curr_merge_ind in curr_merge_list] + keep_to_merge_list[keep] = merge_list + return keep_to_merge_list + + +@check_requirements(["torch"]) +def greedy_nmm( + object_predictions_as_tensor: torch.tensor, + match_metric: str = "IOU", + match_threshold: float = 0.5, +): + """ + Apply greedy version of non-maximum merging to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + object_predictions_as_tensor: (tensor) The location preds for the image + along with the class predscores, Shape: [num_boxes,5]. + object_predictions_as_list: ObjectPredictionList Object prediction objects + to be merged. + match_metric: (str) IOU or IOS + match_threshold: (float) The overlap thresh for + match metric. + Returns: + keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices + to keep to a list of prediction indices to be merged. + """ + keep_to_merge_list = {} + + # we extract coordinates for every + # prediction box present in P + x1 = object_predictions_as_tensor[:, 0] + y1 = object_predictions_as_tensor[:, 1] + x2 = object_predictions_as_tensor[:, 2] + y2 = object_predictions_as_tensor[:, 3] + + # we extract the confidence scores as well + scores = object_predictions_as_tensor[:, 4] + + # calculate area of every block in P + areas = (x2 - x1) * (y2 - y1) + + # sort the prediction boxes in P + # according to their confidence scores + order = scores.argsort() + + # initialise an empty list for + # filtered prediction boxes + keep = [] + + while len(order) > 0: + # extract the index of the + # prediction with highest score + # we call this prediction S + idx = order[-1] + + # push S in filtered predictions list + keep.append(idx.tolist()) + + # remove S from P + order = order[:-1] + + # sanity check + if len(order) == 0: + break + + # select coordinates of BBoxes according to + # the indices in order + xx1 = torch.index_select(x1, dim=0, index=order) + xx2 = torch.index_select(x2, dim=0, index=order) + yy1 = torch.index_select(y1, dim=0, index=order) + yy2 = torch.index_select(y2, dim=0, index=order) + + # find the coordinates of the intersection boxes + xx1 = torch.max(xx1, x1[idx]) + yy1 = torch.max(yy1, y1[idx]) + xx2 = torch.min(xx2, x2[idx]) + yy2 = torch.min(yy2, y2[idx]) + + # find height and width of the intersection boxes + w = xx2 - xx1 + h = yy2 - yy1 + + # take max with 0.0 to avoid negative w and h + # due to non-overlapping boxes + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + + # find the intersection area + inter = w * h + + # find the areas of BBoxes according the indices in order + rem_areas = torch.index_select(areas, dim=0, index=order) + + if match_metric == "IOU": + # find the union of every prediction T in P + # with the prediction S + # Note that areas[idx] represents area of S + union = (rem_areas - inter) + areas[idx] + # find the IoU of every prediction in P with S + match_metric_value = inter / union + + elif match_metric == "IOS": + # find the smaller area of every prediction T in P + # with the prediction 
S + # Note that areas[idx] represents area of S + smaller = torch.min(rem_areas, areas[idx]) + # find the IoS of every prediction in P with S + match_metric_value = inter / smaller + else: + raise ValueError() + + # keep the boxes with IoU/IoS less than thresh_iou + mask = match_metric_value < match_threshold + matched_box_indices = order[(mask == False).nonzero().flatten()].flip(dims=(0,)) + unmatched_indices = order[(mask == True).nonzero().flatten()] + + # update box pool + order = unmatched_indices[scores[unmatched_indices].argsort()] + + # create keep_ind to merge_ind_list mapping + keep_to_merge_list[idx.tolist()] = [] + + for matched_box_ind in matched_box_indices.tolist(): + keep_to_merge_list[idx.tolist()].append(matched_box_ind) + + return keep_to_merge_list + + +@check_requirements(["torch"]) +def batched_nmm( + object_predictions_as_tensor: torch.tensor, + match_metric: str = "IOU", + match_threshold: float = 0.5, +): + """ + Apply non-maximum merging per category to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + object_predictions_as_tensor: (tensor) The location preds for the image + along with the class predscores, Shape: [num_boxes,5]. + match_metric: (str) IOU or IOS + match_threshold: (float) The overlap thresh for + match metric. + Returns: + keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices + to keep to a list of prediction indices to be merged. + """ + category_ids = object_predictions_as_tensor[:, 5].squeeze() + keep_to_merge_list = {} + for category_id in torch.unique(category_ids): + curr_indices = torch.where(category_ids == category_id)[0] + curr_keep_to_merge_list = nmm(object_predictions_as_tensor[curr_indices], match_metric, match_threshold) + curr_indices_list = curr_indices.tolist() + for curr_keep, curr_merge_list in curr_keep_to_merge_list.items(): + keep = curr_indices_list[curr_keep] + merge_list = [curr_indices_list[curr_merge_ind] for curr_merge_ind in curr_merge_list] + keep_to_merge_list[keep] = merge_list + return keep_to_merge_list + + +@check_requirements(["torch"]) +def nmm( + object_predictions_as_tensor: torch.tensor, + match_metric: str = "IOU", + match_threshold: float = 0.5, +): + """ + Apply non-maximum merging to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + object_predictions_as_tensor: (tensor) The location preds for the image + along with the class predscores, Shape: [num_boxes,5]. + object_predictions_as_list: ObjectPredictionList Object prediction objects + to be merged. + match_metric: (str) IOU or IOS + match_threshold: (float) The overlap thresh for + match metric. + Returns: + keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices + to keep to a list of prediction indices to be merged. 
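[Editor's note] Unlike NMS, the NMM variants do not discard overlapping predictions; they return this `keep_to_merge_list` mapping, and the actual merging (box union, max score) happens later via `merge_object_prediction_pair`. A sketch of how such a mapping is consumed; the dict below only shows the structure and is not a computed `nmm()` output:

```python
# Consume a keep_to_merge_list of the shape returned by nmm/greedy_nmm: merge each
# kept box with its matched boxes by taking the box union and the maximum score.
boxes = {0: [10, 10, 110, 110], 1: [12, 15, 118, 105], 2: [300, 300, 360, 380]}
scores = {0: 0.90, 1: 0.60, 2: 0.75}
keep_to_merge_list = {0: [1], 2: []}  # illustrative structure, not an actual nmm() result

merged = {}
for keep_ind, merge_inds in keep_to_merge_list.items():
    box, score = list(boxes[keep_ind]), scores[keep_ind]
    for merge_ind in merge_inds:
        other = boxes[merge_ind]
        box = [min(box[0], other[0]), min(box[1], other[1]),
               max(box[2], other[2]), max(box[3], other[3])]  # box union
        score = max(score, scores[merge_ind])                 # keep the higher score
    merged[keep_ind] = (box, score)

print(merged)  # {0: ([10, 10, 118, 110], 0.9), 2: ([300, 300, 360, 380], 0.75)}
```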
+ """ + keep_to_merge_list = {} + merge_to_keep = {} + + # we extract coordinates for every + # prediction box present in P + x1 = object_predictions_as_tensor[:, 0] + y1 = object_predictions_as_tensor[:, 1] + x2 = object_predictions_as_tensor[:, 2] + y2 = object_predictions_as_tensor[:, 3] + + # we extract the confidence scores as well + scores = object_predictions_as_tensor[:, 4] + + # calculate area of every block in P + areas = (x2 - x1) * (y2 - y1) + + # sort the prediction boxes in P + # according to their confidence scores + order = scores.argsort(descending=True) + + for ind in range(len(object_predictions_as_tensor)): + # extract the index of the + # prediction with highest score + # we call this prediction S + pred_ind = order[ind] + pred_ind = pred_ind.tolist() + + # remove selected pred + other_pred_inds = order[order != pred_ind] + + # select coordinates of BBoxes according to + # the indices in order + xx1 = torch.index_select(x1, dim=0, index=other_pred_inds) + xx2 = torch.index_select(x2, dim=0, index=other_pred_inds) + yy1 = torch.index_select(y1, dim=0, index=other_pred_inds) + yy2 = torch.index_select(y2, dim=0, index=other_pred_inds) + + # find the coordinates of the intersection boxes + xx1 = torch.max(xx1, x1[pred_ind]) + yy1 = torch.max(yy1, y1[pred_ind]) + xx2 = torch.min(xx2, x2[pred_ind]) + yy2 = torch.min(yy2, y2[pred_ind]) + + # find height and width of the intersection boxes + w = xx2 - xx1 + h = yy2 - yy1 + + # take max with 0.0 to avoid negative w and h + # due to non-overlapping boxes + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + + # find the intersection area + inter = w * h + + # find the areas of BBoxes according the indices in order + rem_areas = torch.index_select(areas, dim=0, index=other_pred_inds) + + if match_metric == "IOU": + # find the union of every prediction T in P + # with the prediction S + # Note that areas[idx] represents area of S + union = (rem_areas - inter) + areas[pred_ind] + # find the IoU of every prediction in P with S + match_metric_value = inter / union + + elif match_metric == "IOS": + # find the smaller area of every prediction T in P + # with the prediction S + # Note that areas[idx] represents area of S + smaller = torch.min(rem_areas, areas[pred_ind]) + # find the IoS of every prediction in P with S + match_metric_value = inter / smaller + else: + raise ValueError() + + # keep the boxes with IoU/IoS less than thresh_iou + mask = match_metric_value < match_threshold + matched_box_indices = other_pred_inds[(mask == False).nonzero().flatten()].flip(dims=(0,)) + + # create keep_ind to merge_ind_list mapping + if pred_ind not in merge_to_keep: + keep_to_merge_list[pred_ind] = [] + + for matched_box_ind in matched_box_indices.tolist(): + if matched_box_ind not in merge_to_keep: + keep_to_merge_list[pred_ind].append(matched_box_ind) + merge_to_keep[matched_box_ind] = pred_ind + + else: + keep = merge_to_keep[pred_ind] + for matched_box_ind in matched_box_indices.tolist(): + if matched_box_ind not in keep_to_merge_list and matched_box_ind not in merge_to_keep: + keep_to_merge_list[keep].append(matched_box_ind) + merge_to_keep[matched_box_ind] = keep + + return keep_to_merge_list + + +class PostprocessPredictions: + """Utilities for calculating IOU/IOS based match for given ObjectPredictions""" + + def __init__( + self, + match_threshold: float = 0.5, + match_metric: str = "IOU", + class_agnostic: bool = True, + ): + self.match_threshold = match_threshold + self.class_agnostic = class_agnostic + self.match_metric = 
match_metric + + def __call__(self): + raise NotImplementedError() + + +class NMSPostprocess(PostprocessPredictions): + def __call__( + self, + object_predictions: List[ObjectPrediction], + ): + object_prediction_list = ObjectPredictionList(object_predictions) + object_predictions_as_torch = object_prediction_list.totensor() + if self.class_agnostic: + keep = nms( + object_predictions_as_torch, match_threshold=self.match_threshold, match_metric=self.match_metric + ) + else: + keep = batched_nms( + object_predictions_as_torch, match_threshold=self.match_threshold, match_metric=self.match_metric + ) + + selected_object_predictions = object_prediction_list[keep].tolist() + if not isinstance(selected_object_predictions, list): + selected_object_predictions = [selected_object_predictions] + + return selected_object_predictions + + +class NMMPostprocess(PostprocessPredictions): + def __call__( + self, + object_predictions: List[ObjectPrediction], + ): + object_prediction_list = ObjectPredictionList(object_predictions) + object_predictions_as_torch = object_prediction_list.totensor() + if self.class_agnostic: + keep_to_merge_list = nmm( + object_predictions_as_torch, + match_threshold=self.match_threshold, + match_metric=self.match_metric, + ) + else: + keep_to_merge_list = batched_nmm( + object_predictions_as_torch, + match_threshold=self.match_threshold, + match_metric=self.match_metric, + ) + + selected_object_predictions = [] + for keep_ind, merge_ind_list in keep_to_merge_list.items(): + for merge_ind in merge_ind_list: + if has_match( + object_prediction_list[keep_ind].tolist(), + object_prediction_list[merge_ind].tolist(), + self.match_metric, + self.match_threshold, + ): + object_prediction_list[keep_ind] = merge_object_prediction_pair( + object_prediction_list[keep_ind].tolist(), object_prediction_list[merge_ind].tolist() + ) + selected_object_predictions.append(object_prediction_list[keep_ind].tolist()) + + return selected_object_predictions + + +class GreedyNMMPostprocess(PostprocessPredictions): + def __call__( + self, + object_predictions: List[ObjectPrediction], + ): + object_prediction_list = ObjectPredictionList(object_predictions) + object_predictions_as_torch = object_prediction_list.totensor() + if self.class_agnostic: + keep_to_merge_list = greedy_nmm( + object_predictions_as_torch, + match_threshold=self.match_threshold, + match_metric=self.match_metric, + ) + else: + keep_to_merge_list = batched_greedy_nmm( + object_predictions_as_torch, + match_threshold=self.match_threshold, + match_metric=self.match_metric, + ) + + selected_object_predictions = [] + for keep_ind, merge_ind_list in keep_to_merge_list.items(): + for merge_ind in merge_ind_list: + if has_match( + object_prediction_list[keep_ind].tolist(), + object_prediction_list[merge_ind].tolist(), + self.match_metric, + self.match_threshold, + ): + object_prediction_list[keep_ind] = merge_object_prediction_pair( + object_prediction_list[keep_ind].tolist(), object_prediction_list[merge_ind].tolist() + ) + selected_object_predictions.append(object_prediction_list[keep_ind].tolist()) + + return selected_object_predictions + + +class LSNMSPostprocess(PostprocessPredictions): + # https://github.com/remydubois/lsnms/blob/10b8165893db5bfea4a7cb23e268a502b35883cf/lsnms/nms.py#L62 + def __call__( + self, + object_predictions: List[ObjectPrediction], + ): + try: + from lsnms import nms + except ModuleNotFoundError: + raise ModuleNotFoundError( + 'Please run "pip install lsnms>0.3.1" to install lsnms first for lsnms utilities.' 
+ ) + + if self.match_metric == "IOS": + NotImplementedError(f"match_metric={self.match_metric} is not supported for LSNMSPostprocess") + + logger.warning("LSNMSPostprocess is experimental and not recommended to use.") + + object_prediction_list = ObjectPredictionList(object_predictions) + object_predictions_as_numpy = object_prediction_list.tonumpy() + + boxes = object_predictions_as_numpy[:, :4] + scores = object_predictions_as_numpy[:, 4] + class_ids = object_predictions_as_numpy[:, 5].astype("uint8") + + keep = nms( + boxes, scores, iou_threshold=self.match_threshold, class_ids=None if self.class_agnostic else class_ids + ) + + selected_object_predictions = object_prediction_list[keep].tolist() + if not isinstance(selected_object_predictions, list): + selected_object_predictions = [selected_object_predictions] + + return selected_object_predictions diff --git a/sahi/postprocess/legacy/__init__.py b/sahi/postprocess/legacy/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/sahi/postprocess/legacy/__init__.py @@ -0,0 +1 @@ + diff --git a/sahi/postprocess/legacy/combine.py b/sahi/postprocess/legacy/combine.py new file mode 100644 index 0000000..628106f --- /dev/null +++ b/sahi/postprocess/legacy/combine.py @@ -0,0 +1,181 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2021. + +import copy +from typing import List + +import numpy as np + +from sahi.annotation import BoundingBox, Category, Mask +from sahi.postprocess.utils import calculate_area, calculate_box_union, calculate_intersection_area +from sahi.prediction import ObjectPrediction + + +class PostprocessPredictions: + """Utilities for calculating IOU/IOS based match for given ObjectPredictions""" + + def __init__( + self, + match_threshold: float = 0.5, + match_metric: str = "IOU", + class_agnostic: bool = True, + ): + self.match_threshold = match_threshold + self.class_agnostic = class_agnostic + if match_metric == "IOU": + self.calculate_match = self.calculate_bbox_iou + elif match_metric == "IOS": + self.calculate_match = self.calculate_bbox_ios + else: + raise ValueError(f"'match_metric' should be one of ['IOU', 'IOS'] but given as {match_metric}") + + def _has_match(self, pred1: ObjectPrediction, pred2: ObjectPrediction) -> bool: + threshold_condition = self.calculate_match(pred1, pred2) > self.match_threshold + category_condition = self.has_same_category_id(pred1, pred2) or self.class_agnostic + return threshold_condition and category_condition + + @staticmethod + def get_score_func(object_prediction: ObjectPrediction): + """Used for sorting predictions""" + return object_prediction.score.value + + @staticmethod + def has_same_category_id(pred1: ObjectPrediction, pred2: ObjectPrediction) -> bool: + return pred1.category.id == pred2.category.id + + @staticmethod + def calculate_bbox_iou(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float: + """Returns the ratio of intersection area to the union""" + box1 = np.array(pred1.bbox.to_voc_bbox()) + box2 = np.array(pred2.bbox.to_voc_bbox()) + area1 = calculate_area(box1) + area2 = calculate_area(box2) + intersect = calculate_intersection_area(box1, box2) + return intersect / (area1 + area2 - intersect) + + @staticmethod + def calculate_bbox_ios(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float: + """Returns the ratio of intersection area to the smaller box's area""" + box1 = np.array(pred1.bbox.to_voc_bbox()) + box2 = np.array(pred2.bbox.to_voc_bbox()) + area1 = calculate_area(box1) + area2 = calculate_area(box2) + intersect = 
calculate_intersection_area(box1, box2) + smaller_area = np.minimum(area1, area2) + return intersect / smaller_area + + def __call__(self): + raise NotImplementedError() + + +class NMSPostprocess(PostprocessPredictions): + def __call__( + self, + object_predictions: List[ObjectPrediction], + ): + source_object_predictions: List[ObjectPrediction] = copy.deepcopy(object_predictions) + selected_object_predictions: List[ObjectPrediction] = [] + while len(source_object_predictions) > 0: + # select object prediction with highest score + source_object_predictions.sort(reverse=True, key=self.get_score_func) + selected_object_prediction = source_object_predictions[0] + # remove selected prediction from source list + del source_object_predictions[0] + # if any element from remaining source prediction list matches, remove it + new_source_object_predictions: List[ObjectPrediction] = [] + for candidate_object_prediction in source_object_predictions: + if self._has_match(selected_object_prediction, candidate_object_prediction): + pass + else: + new_source_object_predictions.append(candidate_object_prediction) + source_object_predictions = new_source_object_predictions + # append selected prediction to selected list + selected_object_predictions.append(selected_object_prediction) + return selected_object_predictions + + +class UnionMergePostprocess(PostprocessPredictions): + def __call__( + self, + object_predictions: List[ObjectPrediction], + ): + source_object_predictions: List[ObjectPrediction] = copy.deepcopy(object_predictions) + selected_object_predictions: List[ObjectPrediction] = [] + while len(source_object_predictions) > 0: + # select object prediction with highest score + source_object_predictions.sort(reverse=True, key=self.get_score_func) + selected_object_prediction = source_object_predictions[0] + # remove selected prediction from source list + del source_object_predictions[0] + # if any element from remaining source prediction list matches, remove it and merge with selected prediction + new_source_object_predictions: List[ObjectPrediction] = [] + for ind, candidate_object_prediction in enumerate(source_object_predictions): + if self._has_match(selected_object_prediction, candidate_object_prediction): + selected_object_prediction = self._merge_object_prediction_pair( + selected_object_prediction, candidate_object_prediction + ) + else: + new_source_object_predictions.append(candidate_object_prediction) + source_object_predictions = new_source_object_predictions + # append selected prediction to selected list + selected_object_predictions.append(selected_object_prediction) + return selected_object_predictions + + def _merge_object_prediction_pair( + self, + pred1: ObjectPrediction, + pred2: ObjectPrediction, + ) -> ObjectPrediction: + shift_amount = pred1.bbox.shift_amount + merged_bbox: BoundingBox = self._get_merged_bbox(pred1, pred2) + merged_score: float = self._get_merged_score(pred1, pred2) + merged_category: Category = self._get_merged_category(pred1, pred2) + if pred1.mask and pred2.mask: + merged_mask: Mask = self._get_merged_mask(pred1, pred2) + bool_mask = merged_mask.bool_mask + full_shape = merged_mask.full_shape + else: + bool_mask = None + full_shape = None + return ObjectPrediction( + bbox=merged_bbox.to_voc_bbox(), + score=merged_score, + category_id=merged_category.id, + category_name=merged_category.name, + bool_mask=bool_mask, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + @staticmethod + def _get_merged_category(pred1: ObjectPrediction, pred2: 
ObjectPrediction) -> Category: + if pred1.score.value > pred2.score.value: + return pred1.category + else: + return pred2.category + + @staticmethod + def _get_merged_bbox(pred1: ObjectPrediction, pred2: ObjectPrediction) -> BoundingBox: + box1: List[int] = pred1.bbox.to_voc_bbox() + box2: List[int] = pred2.bbox.to_voc_bbox() + bbox = BoundingBox(box=calculate_box_union(box1, box2)) + return bbox + + @staticmethod + def _get_merged_score( + pred1: ObjectPrediction, + pred2: ObjectPrediction, + ) -> float: + scores: List[float] = [pred.score.value for pred in (pred1, pred2)] + return max(scores) + + @staticmethod + def _get_merged_mask(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Mask: + mask1 = pred1.mask + mask2 = pred2.mask + union_mask = np.logical_or(mask1.bool_mask, mask2.bool_mask) + return Mask( + bool_mask=union_mask, + full_shape=mask1.full_shape, + shift_amount=mask1.shift_amount, + ) diff --git a/sahi/postprocess/utils.py b/sahi/postprocess/utils.py new file mode 100644 index 0000000..2b8906c --- /dev/null +++ b/sahi/postprocess/utils.py @@ -0,0 +1,216 @@ +from collections.abc import Sequence +from typing import List, Union + +import numpy as np +import torch + +from sahi.annotation import BoundingBox, Category, Mask +from sahi.prediction import ObjectPrediction + + +class ObjectPredictionList(Sequence): + def __init__(self, list): + self.list = list + super().__init__() + + def __getitem__(self, i): + if torch.is_tensor(i) or isinstance(i, np.ndarray): + i = i.tolist() + if isinstance(i, int): + return ObjectPredictionList([self.list[i]]) + elif isinstance(i, (tuple, list)): + accessed_mapping = map(self.list.__getitem__, i) + return ObjectPredictionList(list(accessed_mapping)) + else: + raise NotImplementedError(f"{type(i)}") + + def __setitem__(self, i, elem): + if torch.is_tensor(i) or isinstance(i, np.ndarray): + i = i.tolist() + if isinstance(i, int): + self.list[i] = elem + elif isinstance(i, (tuple, list)): + if len(i) != len(elem): + raise ValueError() + if isinstance(elem, ObjectPredictionList): + for ind, el in enumerate(elem.list): + self.list[i[ind]] = el + else: + for ind, el in enumerate(elem): + self.list[i[ind]] = el + else: + raise NotImplementedError(f"{type(i)}") + + def __len__(self): + return len(self.list) + + def __str__(self): + return str(self.list) + + def extend(self, object_prediction_list): + self.list.extend(object_prediction_list.list) + + def totensor(self): + return object_prediction_list_to_torch(self) + + def tonumpy(self): + return object_prediction_list_to_numpy(self) + + def tolist(self): + if len(self.list) == 1: + return self.list[0] + else: + return self.list + + +def object_prediction_list_to_torch(object_prediction_list: ObjectPredictionList) -> torch.tensor: + """ + Returns: + torch.tensor of size N x [x1, y1, x2, y2, score, category_id] + """ + num_predictions = len(object_prediction_list) + torch_predictions = torch.zeros([num_predictions, 6], dtype=torch.float32) + for ind, object_prediction in enumerate(object_prediction_list): + torch_predictions[ind, :4] = torch.tensor(object_prediction.tolist().bbox.to_voc_bbox(), dtype=torch.int32) + torch_predictions[ind, 4] = object_prediction.tolist().score.value + torch_predictions[ind, 5] = object_prediction.tolist().category.id + return torch_predictions + + +def object_prediction_list_to_numpy(object_prediction_list: ObjectPredictionList) -> np.ndarray: + """ + Returns: + np.ndarray of size N x [x1, y1, x2, y2, score, category_id] + """ + num_predictions = 
len(object_prediction_list) + numpy_predictions = np.zeros([num_predictions, 6], dtype=np.float32) + for ind, object_prediction in enumerate(object_prediction_list): + numpy_predictions[ind, :4] = np.array(object_prediction.tolist().bbox.to_voc_bbox(), dtype=np.int32) + numpy_predictions[ind, 4] = object_prediction.tolist().score.value + numpy_predictions[ind, 5] = object_prediction.tolist().category.id + return numpy_predictions + + +def calculate_box_union(box1: Union[List[int], np.ndarray], box2: Union[List[int], np.ndarray]) -> List[int]: + """ + Args: + box1 (List[int]): [x1, y1, x2, y2] + box2 (List[int]): [x1, y1, x2, y2] + """ + box1 = np.array(box1) + box2 = np.array(box2) + left_top = np.minimum(box1[:2], box2[:2]) + right_bottom = np.maximum(box1[2:], box2[2:]) + return list(np.concatenate((left_top, right_bottom))) + + +def calculate_area(box: Union[List[int], np.ndarray]) -> float: + """ + Args: + box (List[int]): [x1, y1, x2, y2] + """ + return (box[2] - box[0]) * (box[3] - box[1]) + + +def calculate_intersection_area(box1: np.ndarray, box2: np.ndarray) -> float: + """ + Args: + box1 (np.ndarray): np.array([x1, y1, x2, y2]) + box2 (np.ndarray): np.array([x1, y1, x2, y2]) + """ + left_top = np.maximum(box1[:2], box2[:2]) + right_bottom = np.minimum(box1[2:], box2[2:]) + width_height = (right_bottom - left_top).clip(min=0) + return width_height[0] * width_height[1] + + +def calculate_bbox_iou(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float: + """Returns the ratio of intersection area to the union""" + box1 = np.array(pred1.bbox.to_voc_bbox()) + box2 = np.array(pred2.bbox.to_voc_bbox()) + area1 = calculate_area(box1) + area2 = calculate_area(box2) + intersect = calculate_intersection_area(box1, box2) + return intersect / (area1 + area2 - intersect) + + +def calculate_bbox_ios(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float: + """Returns the ratio of intersection area to the smaller box's area""" + box1 = np.array(pred1.bbox.to_voc_bbox()) + box2 = np.array(pred2.bbox.to_voc_bbox()) + area1 = calculate_area(box1) + area2 = calculate_area(box2) + intersect = calculate_intersection_area(box1, box2) + smaller_area = np.minimum(area1, area2) + return intersect / smaller_area + + +def has_match( + pred1: ObjectPrediction, pred2: ObjectPrediction, match_type: str = "IOU", match_threshold: float = 0.5 +) -> bool: + if match_type == "IOU": + threshold_condition = calculate_bbox_iou(pred1, pred2) > match_threshold + elif match_type == "IOS": + threshold_condition = calculate_bbox_ios(pred1, pred2) > match_threshold + else: + raise ValueError() + return threshold_condition + + +def get_merged_mask(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Mask: + mask1 = pred1.mask + mask2 = pred2.mask + union_mask = np.logical_or(mask1.bool_mask, mask2.bool_mask) + return Mask( + bool_mask=union_mask, + full_shape=mask1.full_shape, + shift_amount=mask1.shift_amount, + ) + + +def get_merged_score( + pred1: ObjectPrediction, + pred2: ObjectPrediction, +) -> float: + scores: List[float] = [pred.score.value for pred in (pred1, pred2)] + return max(scores) + + +def get_merged_bbox(pred1: ObjectPrediction, pred2: ObjectPrediction) -> BoundingBox: + box1: List[int] = pred1.bbox.to_voc_bbox() + box2: List[int] = pred2.bbox.to_voc_bbox() + bbox = BoundingBox(box=calculate_box_union(box1, box2)) + return bbox + + +def get_merged_category(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Category: + if pred1.score.value > pred2.score.value: + return pred1.category + else: + return 
pred2.category + + +def merge_object_prediction_pair( + pred1: ObjectPrediction, + pred2: ObjectPrediction, +) -> ObjectPrediction: + shift_amount = pred1.bbox.shift_amount + merged_bbox: BoundingBox = get_merged_bbox(pred1, pred2) + merged_score: float = get_merged_score(pred1, pred2) + merged_category: Category = get_merged_category(pred1, pred2) + if pred1.mask and pred2.mask: + merged_mask: Mask = get_merged_mask(pred1, pred2) + bool_mask = merged_mask.bool_mask + full_shape = merged_mask.full_shape + else: + bool_mask = None + full_shape = None + return ObjectPrediction( + bbox=merged_bbox.to_voc_bbox(), + score=merged_score, + category_id=merged_category.id, + category_name=merged_category.name, + bool_mask=bool_mask, + shift_amount=shift_amount, + full_shape=full_shape, + ) diff --git a/sahi/predict.py b/sahi/predict.py new file mode 100644 index 0000000..765e17c --- /dev/null +++ b/sahi/predict.py @@ -0,0 +1,842 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. + +import logging +import os +import time +from typing import List, Optional + +import numpy as np +from PIL import Image +from tqdm import tqdm + +from sahi.auto_model import AutoDetectionModel +from sahi.model import DetectionModel +from sahi.postprocess.combine import ( + GreedyNMMPostprocess, + LSNMSPostprocess, + NMMPostprocess, + NMSPostprocess, + PostprocessPredictions, +) +from sahi.prediction import ObjectPrediction, PredictionResult +from sahi.slicing import slice_image +from sahi.utils.coco import Coco, CocoImage +from sahi.utils.cv import ( + IMAGE_EXTENSIONS, + VIDEO_EXTENSIONS, + crop_object_predictions, + cv2, + get_video_reader, + read_image_as_pil, + visualize_object_predictions, +) +from sahi.utils.file import Path, increment_path, list_files, save_json, save_pickle +from sahi.utils.import_utils import check_requirements + +POSTPROCESS_NAME_TO_CLASS = { + "GREEDYNMM": GreedyNMMPostprocess, + "NMM": NMMPostprocess, + "NMS": NMSPostprocess, + "LSNMS": LSNMSPostprocess, +} + +LOW_MODEL_CONFIDENCE = 0.1 + + +logger = logging.getLogger(__name__) + + +def get_prediction( + image, + detection_model, + shift_amount: list = [0, 0], + full_shape=None, + postprocess: Optional[PostprocessPredictions] = None, + verbose: int = 0, +) -> PredictionResult: + """ + Function for performing prediction for given image using given detection_model. 
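[Editor's note] A usage sketch tying the pieces together: build a detection model through `AutoDetectionModel.from_pretrained`, construct one of the postprocess classes defined above, and pass it to `get_prediction`. This assumes the patch is installed along with a YOLOv5 backend; the checkpoint and image paths are placeholders:

```python
# Assumed setup: sahi (this patch) plus the yolov5 package are installed and a local
# checkpoint exists at the placeholder path below.
from sahi.auto_model import AutoDetectionModel
from sahi.postprocess.combine import GreedyNMMPostprocess
from sahi.predict import get_prediction

detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov5",
    model_path="path/to/yolov5s.pt",   # placeholder checkpoint path
    confidence_threshold=0.3,
    device="cpu",                      # or "cuda:0"
)
postprocess = GreedyNMMPostprocess(match_threshold=0.5, match_metric="IOS", class_agnostic=False)

result = get_prediction("path/to/image.jpg", detection_model, postprocess=postprocess, verbose=1)
print(len(result.object_prediction_list), result.durations_in_seconds)
```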
+ + Arguments: + image: str or np.ndarray + Location of image or numpy image matrix to slice + detection_model: model.DetectionMode + shift_amount: List + To shift the box and mask predictions from sliced image to full + sized image, should be in the form of [shift_x, shift_y] + full_shape: List + Size of the full image, should be in the form of [height, width] + postprocess: sahi.postprocess.combine.PostprocessPredictions + verbose: int + 0: no print (default) + 1: print prediction duration + + Returns: + A dict with fields: + object_prediction_list: a list of ObjectPrediction + durations_in_seconds: a dict containing elapsed times for profiling + """ + durations_in_seconds = dict() + + # read image as pil + image_as_pil = read_image_as_pil(image) + # get prediction + time_start = time.time() + detection_model.perform_inference(np.ascontiguousarray(image_as_pil)) + time_end = time.time() - time_start + durations_in_seconds["prediction"] = time_end + + # process prediction + time_start = time.time() + # works only with 1 batch + detection_model.convert_original_predictions( + shift_amount=shift_amount, + full_shape=full_shape, + ) + object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list + + # postprocess matching predictions + if postprocess is not None: + object_prediction_list = postprocess(object_prediction_list) + + time_end = time.time() - time_start + durations_in_seconds["postprocess"] = time_end + + if verbose == 1: + print( + "Prediction performed in", + durations_in_seconds["prediction"], + "seconds.", + ) + + return PredictionResult( + image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds + ) + + +def get_sliced_prediction( + image, + detection_model=None, + slice_height: int = 512, + slice_width: int = 512, + overlap_height_ratio: float = 0.2, + overlap_width_ratio: float = 0.2, + perform_standard_pred: bool = True, + postprocess_type: str = "GREEDYNMM", + postprocess_match_metric: str = "IOS", + postprocess_match_threshold: float = 0.5, + postprocess_class_agnostic: bool = False, + verbose: int = 1, + merge_buffer_length: int = None, +) -> PredictionResult: + """ + Function for slice image + get predicion for each slice + combine predictions in full image. + + Args: + image: str or np.ndarray + Location of image or numpy image matrix to slice + detection_model: model.DetectionModel + slice_height: int + Height of each slice. Defaults to ``512``. + slice_width: int + Width of each slice. Defaults to ``512``. + overlap_height_ratio: float + Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window + of size 512 yields an overlap of 102 pixels). + Default to ``0.2``. + overlap_width_ratio: float + Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window + of size 512 yields an overlap of 102 pixels). + Default to ``0.2``. + perform_standard_pred: bool + Perform a standard prediction on top of sliced predictions to increase large object + detection accuracy. Default: True. + postprocess_type: str + Type of the postprocess to be used after sliced inference while merging/eliminating predictions. + Options are 'NMM', 'GRREDYNMM' or 'NMS'. Default is 'GRREDYNMM'. + postprocess_match_metric: str + Metric to be used during object prediction matching after sliced prediction. + 'IOU' for intersection over union, 'IOS' for intersection over smaller area. 
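[Editor's note] For intuition about the slice geometry described in this docstring, here is a conceptual re-derivation of the window-start arithmetic; it is an editor's sketch, not the actual `sahi.slicing.slice_image` implementation:

```python
# With slice size 512 and overlap ratio 0.2, consecutive windows step by
# 512 - int(512 * 0.2) = 410 px, i.e. neighbouring slices overlap by 102 px.
def slice_starts(full_size: int, slice_size: int, overlap_ratio: float):
    step = slice_size - int(slice_size * overlap_ratio)
    starts, pos = [], 0
    while True:
        starts.append(min(pos, max(full_size - slice_size, 0)))  # clamp the last window to the border
        if pos + slice_size >= full_size:
            break
        pos += step
    return sorted(set(starts))

print(slice_starts(1920, 512, 0.2))  # x starting pixels for a 1920 px wide image
print(slice_starts(1080, 512, 0.2))  # y starting pixels for a 1080 px high image
```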
+ postprocess_match_threshold: float + Sliced predictions having higher iou than postprocess_match_threshold will be + postprocessed after sliced prediction. + postprocess_class_agnostic: bool + If True, postprocess will ignore category ids. + verbose: int + 0: no print + 1: print number of slices (default) + 2: print number of slices and slice/prediction durations + merge_buffer_length: int + The length of buffer for slices to be used during sliced prediction, which is suitable for low memory. + It may affect the AP if it is specified. The higher the amount, the closer results to the non-buffered. + scenario. See [the discussion](https://github.com/obss/sahi/pull/445). + + Returns: + A Dict with fields: + object_prediction_list: a list of sahi.prediction.ObjectPrediction + durations_in_seconds: a dict containing elapsed times for profiling + """ + + # for profiling + durations_in_seconds = dict() + + # currently only 1 batch supported + num_batch = 1 + + # create slices from full image + time_start = time.time() + slice_image_result = slice_image( + image=image, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + ) + num_slices = len(slice_image_result) + time_end = time.time() - time_start + durations_in_seconds["slice"] = time_end + + # init match postprocess instance + if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys(): + raise ValueError( + f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}" + ) + elif postprocess_type == "UNIONMERGE": + # deprecated in v0.9.3 + raise ValueError("'UNIONMERGE' postprocess_type is deprecated, use 'GREEDYNMM' instead.") + postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type] + postprocess = postprocess_constructor( + match_threshold=postprocess_match_threshold, + match_metric=postprocess_match_metric, + class_agnostic=postprocess_class_agnostic, + ) + + # create prediction input + num_group = int(num_slices / num_batch) + if verbose == 1 or verbose == 2: + tqdm.write(f"Performing prediction on {num_slices} number of slices.") + object_prediction_list = [] + # perform sliced prediction + for group_ind in range(num_group): + # prepare batch (currently supports only 1 batch) + image_list = [] + shift_amount_list = [] + for image_ind in range(num_batch): + image_list.append(slice_image_result.images[group_ind * num_batch + image_ind]) + shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind]) + # perform batch prediction + prediction_result = get_prediction( + image=image_list[0], + detection_model=detection_model, + shift_amount=shift_amount_list[0], + full_shape=[ + slice_image_result.original_image_height, + slice_image_result.original_image_width, + ], + ) + # convert sliced predictions to full predictions + for object_prediction in prediction_result.object_prediction_list: + if object_prediction: # if not empty + object_prediction_list.append(object_prediction.get_shifted_object_prediction()) + + # merge matching predictions during sliced prediction + if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length: + object_prediction_list = postprocess(object_prediction_list) + + # perform standard prediction + if num_slices > 1 and perform_standard_pred: + prediction_result = get_prediction( + image=image, + detection_model=detection_model, + shift_amount=[0, 0], + full_shape=None, + postprocess=None, + ) + 
object_prediction_list.extend(prediction_result.object_prediction_list) + + if verbose == 2: + print( + "Slicing performed in", + durations_in_seconds["slice"], + "seconds.", + ) + print( + "Prediction performed in", + durations_in_seconds["prediction"], + "seconds.", + ) + + # merge matching predictions + if len(object_prediction_list) > 1: + object_prediction_list = postprocess(object_prediction_list) + + time_end = time.time() - time_start + durations_in_seconds["prediction"] = time_end + + return PredictionResult( + image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds + ) + + +def predict( + detection_model: DetectionModel = None, + model_type: str = "mmdet", + model_path: str = None, + model_config_path: str = None, + model_confidence_threshold: float = 0.25, + model_device: str = None, + model_category_mapping: dict = None, + model_category_remapping: dict = None, + source: str = None, + no_standard_prediction: bool = False, + no_sliced_prediction: bool = False, + image_size: int = None, + slice_height: int = 512, + slice_width: int = 512, + overlap_height_ratio: float = 0.2, + overlap_width_ratio: float = 0.2, + postprocess_type: str = "GREEDYNMM", + postprocess_match_metric: str = "IOS", + postprocess_match_threshold: float = 0.5, + postprocess_class_agnostic: bool = False, + novisual: bool = False, + view_video: bool = False, + frame_skip_interval: int = 0, + export_pickle: bool = False, + export_crop: bool = False, + dataset_json_path: bool = None, + project: str = "runs/predict", + name: str = "exp", + visual_bbox_thickness: int = None, + visual_text_size: float = None, + visual_text_thickness: int = None, + visual_export_format: str = "png", + verbose: int = 1, + return_dict: bool = False, + force_postprocess_type: bool = False, +): + """ + Performs prediction for all present images in given folder. + + Args: + detection_model: sahi.model.DetectionModel + Optionally provide custom DetectionModel to be used for inference. When provided, + model_type, model_path, config_path, model_device, model_category_mapping, image_size + params will be ignored + model_type: str + mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'. + model_path: str + Path for the model weight + model_config_path: str + Path for the detection model config file + model_confidence_threshold: float + All predictions with score < model_confidence_threshold will be discarded. + model_device: str + Torch device, "cpu" or "cuda" + model_category_mapping: dict + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + model_category_remapping: dict: str to int + Remap category ids after performing inference + source: str + Folder directory that contains images or path of the image to be predicted. Also video to be predicted. + no_standard_prediction: bool + Dont perform standard prediction. Default: False. + no_sliced_prediction: bool + Dont perform sliced prediction. Default: False. + image_size: int + Input image size for each inference (image is scaled by preserving asp. rat.). + slice_height: int + Height of each slice. Defaults to ``512``. + slice_width: int + Width of each slice. Defaults to ``512``. + overlap_height_ratio: float + Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window + of size 512 yields an overlap of 102 pixels). + Default to ``0.2``. + overlap_width_ratio: float + Fractional overlap in width of each window (e.g. 
an overlap of 0.2 for a window + of size 512 yields an overlap of 102 pixels). + Default to ``0.2``. + postprocess_type: str + Type of the postprocess to be used after sliced inference while merging/eliminating predictions. + Options are 'NMM', 'GRREDYNMM' or 'NMS'. Default is 'GRREDYNMM'. + postprocess_match_metric: str + Metric to be used during object prediction matching after sliced prediction. + 'IOU' for intersection over union, 'IOS' for intersection over smaller area. + postprocess_match_threshold: float + Sliced predictions having higher iou than postprocess_match_threshold will be + postprocessed after sliced prediction. + postprocess_class_agnostic: bool + If True, postprocess will ignore category ids. + novisual: bool + Dont export predicted video/image visuals. + view_video: bool + View result of prediction during video inference. + frame_skip_interval: int + If view_video or export_visual is slow, you can process one frames of 3(for exp: --frame_skip_interval=3). + export_pickle: bool + Export predictions as .pickle + export_crop: bool + Export predictions as cropped images. + dataset_json_path: str + If coco file path is provided, detection results will be exported in coco json format. + project: str + Save results to project/name. + name: str + Save results to project/name. + visual_bbox_thickness: int + visual_text_size: float + visual_text_thickness: int + visual_export_format: str + Can be specified as 'jpg' or 'png' + verbose: int + 0: no print + 1: print slice/prediction durations, number of slices + 2: print model loading/file exporting durations + return_dict: bool + If True, returns a dict with 'export_dir' field. + force_postprocess_type: bool + If True, auto postprocess check will e disabled + """ + # assert prediction type + if no_standard_prediction and no_sliced_prediction: + raise ValueError("'no_standard_prediction' and 'no_sliced_prediction' cannot be True at the same time.") + + # auto postprocess type + if not force_postprocess_type and model_confidence_threshold < LOW_MODEL_CONFIDENCE and postprocess_type != "NMS": + logger.warning( + f"Switching postprocess type/metric to NMS/IOU since confidence threshold is low ({model_confidence_threshold})." 
+ ) + postprocess_type = "NMS" + postprocess_match_metric = "IOU" + + # for profiling + durations_in_seconds = dict() + + # init export directories + save_dir = Path(increment_path(Path(project) / name, exist_ok=False)) # increment run + crop_dir = save_dir / "crops" + visual_dir = save_dir / "visuals" + visual_with_gt_dir = save_dir / "visuals_with_gt" + pickle_dir = save_dir / "pickles" + if not novisual or export_pickle or export_crop or dataset_json_path is not None: + save_dir.mkdir(parents=True, exist_ok=True) # make dir + + # init image iterator + # TODO: rewrite this as iterator class as in https://github.com/ultralytics/yolov5/blob/d059d1da03aee9a3c0059895aa4c7c14b7f25a9e/utils/datasets.py#L178 + source_is_video = False + num_frames = None + if dataset_json_path: + coco: Coco = Coco.from_coco_dict_or_path(dataset_json_path) + image_iterator = [str(Path(source) / Path(coco_image.file_name)) for coco_image in coco.images] + coco_json = [] + elif os.path.isdir(source): + image_iterator = list_files( + directory=source, + contains=IMAGE_EXTENSIONS, + verbose=verbose, + ) + elif Path(source).suffix in VIDEO_EXTENSIONS: + source_is_video = True + read_video_frame, output_video_writer, video_file_name, num_frames = get_video_reader( + source, save_dir, frame_skip_interval, not novisual, view_video + ) + image_iterator = read_video_frame + else: + image_iterator = [source] + + # init model instance + time_start = time.time() + if detection_model is None: + detection_model = AutoDetectionModel.from_pretrained( + model_type=model_type, + model_path=model_path, + config_path=model_config_path, + confidence_threshold=model_confidence_threshold, + device=model_device, + category_mapping=model_category_mapping, + category_remapping=model_category_remapping, + load_at_init=False, + image_size=image_size, + ) + detection_model.load_model() + time_end = time.time() - time_start + durations_in_seconds["model_load"] = time_end + + # iterate over source images + durations_in_seconds["prediction"] = 0 + durations_in_seconds["slice"] = 0 + + input_type_str = "video frames" if source_is_video else "images" + for ind, image_path in enumerate( + tqdm(image_iterator, f"Performing inference on {input_type_str}", total=num_frames) + ): + # get filename + if source_is_video: + video_name = Path(source).stem + relative_filepath = video_name + "_frame_" + str(ind) + elif os.path.isdir(source): # preserve source folder structure in export + relative_filepath = str(Path(image_path)).split(str(Path(source)))[-1] + relative_filepath = relative_filepath[1:] if relative_filepath[0] == os.sep else relative_filepath + else: # no process if source is single file + relative_filepath = Path(image_path).name + + filename_without_extension = Path(relative_filepath).stem + + # load image + image_as_pil = read_image_as_pil(image_path) + + # perform prediction + if not no_sliced_prediction: + # get sliced prediction + prediction_result = get_sliced_prediction( + image=image_as_pil, + detection_model=detection_model, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + perform_standard_pred=not no_standard_prediction, + postprocess_type=postprocess_type, + postprocess_match_metric=postprocess_match_metric, + postprocess_match_threshold=postprocess_match_threshold, + postprocess_class_agnostic=postprocess_class_agnostic, + verbose=1 if verbose else 0, + ) + object_prediction_list = prediction_result.object_prediction_list + 
durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"] + else: + # get standard prediction + prediction_result = get_prediction( + image=image_as_pil, + detection_model=detection_model, + shift_amount=[0, 0], + full_shape=None, + postprocess=None, + verbose=0, + ) + object_prediction_list = prediction_result.object_prediction_list + + durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"] + # Show prediction time + if verbose: + tqdm.write( + "Prediction time is: {:.2f} ms".format(prediction_result.durations_in_seconds["prediction"] * 1000) + ) + + if dataset_json_path: + if source_is_video is True: + raise NotImplementedError("Video input type not supported with coco formatted dataset json") + + # append predictions in coco format + for object_prediction in object_prediction_list: + coco_prediction = object_prediction.to_coco_prediction() + coco_prediction.image_id = coco.images[ind].id + coco_prediction_json = coco_prediction.json + if coco_prediction_json["bbox"]: + coco_json.append(coco_prediction_json) + if not novisual: + # convert ground truth annotations to object_prediction_list + coco_image: CocoImage = coco.images[ind] + object_prediction_gt_list: List[ObjectPrediction] = [] + for coco_annotation in coco_image.annotations: + coco_annotation_dict = coco_annotation.json + category_name = coco_annotation.category_name + full_shape = [coco_image.height, coco_image.width] + object_prediction_gt = ObjectPrediction.from_coco_annotation_dict( + annotation_dict=coco_annotation_dict, category_name=category_name, full_shape=full_shape + ) + object_prediction_gt_list.append(object_prediction_gt) + # export visualizations with ground truths + output_dir = str(visual_with_gt_dir / Path(relative_filepath).parent) + color = (0, 255, 0) # original annotations in green + result = visualize_object_predictions( + np.ascontiguousarray(image_as_pil), + object_prediction_list=object_prediction_gt_list, + rect_th=visual_bbox_thickness, + text_size=visual_text_size, + text_th=visual_text_thickness, + color=color, + output_dir=None, + file_name=None, + export_format=None, + ) + color = (255, 0, 0) # model predictions in red + _ = visualize_object_predictions( + result["image"], + object_prediction_list=object_prediction_list, + rect_th=visual_bbox_thickness, + text_size=visual_text_size, + text_th=visual_text_thickness, + color=color, + output_dir=output_dir, + file_name=filename_without_extension, + export_format=visual_export_format, + ) + + time_start = time.time() + # export prediction boxes + if export_crop: + output_dir = str(crop_dir / Path(relative_filepath).parent) + crop_object_predictions( + image=np.ascontiguousarray(image_as_pil), + object_prediction_list=object_prediction_list, + output_dir=output_dir, + file_name=filename_without_extension, + export_format=visual_export_format, + ) + # export prediction list as pickle + if export_pickle: + save_path = str(pickle_dir / Path(relative_filepath).parent / (filename_without_extension + ".pickle")) + save_pickle(data=object_prediction_list, save_path=save_path) + + # export visualization + if not novisual or view_video: + output_dir = str(visual_dir / Path(relative_filepath).parent) + result = visualize_object_predictions( + np.ascontiguousarray(image_as_pil), + object_prediction_list=object_prediction_list, + rect_th=visual_bbox_thickness, + text_size=visual_text_size, + text_th=visual_text_thickness, + output_dir=output_dir if not source_is_video else None, + 
file_name=filename_without_extension, + export_format=visual_export_format, + ) + if not novisual and source_is_video: # export video + output_video_writer.write(result["image"]) + + # render video inference + if view_video: + cv2.imshow("Prediction of {}".format(str(video_file_name)), result["image"]) + cv2.waitKey(1) + + time_end = time.time() - time_start + durations_in_seconds["export_files"] = time_end + + # export coco results + if dataset_json_path: + save_path = str(save_dir / "result.json") + save_json(coco_json, save_path) + + if not novisual or export_pickle or export_crop or dataset_json_path is not None: + print(f"Prediction results are successfully exported to {save_dir}") + + # print prediction duration + if verbose == 2: + print( + "Model loaded in", + durations_in_seconds["model_load"], + "seconds.", + ) + print( + "Slicing performed in", + durations_in_seconds["slice"], + "seconds.", + ) + print( + "Prediction performed in", + durations_in_seconds["prediction"], + "seconds.", + ) + if not novisual: + print( + "Exporting performed in", + durations_in_seconds["export_files"], + "seconds.", + ) + + if return_dict: + return {"export_dir": save_dir} + + +@check_requirements(["fiftyone"]) +def predict_fiftyone( + model_type: str = "mmdet", + model_path: str = None, + model_config_path: str = None, + model_confidence_threshold: float = 0.25, + model_device: str = None, + model_category_mapping: dict = None, + model_category_remapping: dict = None, + dataset_json_path: str = None, + image_dir: str = None, + no_standard_prediction: bool = False, + no_sliced_prediction: bool = False, + image_size: int = None, + slice_height: int = 256, + slice_width: int = 256, + overlap_height_ratio: float = 0.2, + overlap_width_ratio: float = 0.2, + postprocess_type: str = "GREEDYNMM", + postprocess_match_metric: str = "IOS", + postprocess_match_threshold: float = 0.5, + postprocess_class_agnostic: bool = False, + verbose: int = 1, +): + """ + Performs prediction for all present images in given folder. + + Args: + model_type: str + mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'. + model_path: str + Path for the model weight + model_config_path: str + Path for the detection model config file + model_confidence_threshold: float + All predictions with score < model_confidence_threshold will be discarded. + model_device: str + Torch device, "cpu" or "cuda" + model_category_mapping: dict + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + model_category_remapping: dict: str to int + Remap category ids after performing inference + dataset_json_path: str + If coco file path is provided, detection results will be exported in coco json format. + image_dir: str + Folder directory that contains images or path of the image to be predicted. + no_standard_prediction: bool + Dont perform standard prediction. Default: False. + no_sliced_prediction: bool + Dont perform sliced prediction. Default: False. + image_size: int + Input image size for each inference (image is scaled by preserving asp. rat.). + slice_height: int + Height of each slice. Defaults to ``256``. + slice_width: int + Width of each slice. Defaults to ``256``. + overlap_height_ratio: float + Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window + of size 256 yields an overlap of 51 pixels). + Default to ``0.2``. + overlap_width_ratio: float + Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window + of size 256 yields an overlap of 51 pixels). 
+            Defaults to ``0.2``.
+        postprocess_type: str
+            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
+            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
+        postprocess_match_metric: str
+            Metric to be used during object prediction matching after sliced prediction.
+            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
+        postprocess_match_threshold: float
+            Sliced predictions having higher iou than postprocess_match_threshold will be
+            postprocessed after sliced prediction.
+        postprocess_class_agnostic: bool
+            If True, postprocess will ignore category ids.
+        verbose: int
+            0: no print
+            1: print slice/prediction durations, number of slices, model loading/file exporting durations
+    """
+    from sahi.utils.fiftyone import create_fiftyone_dataset_from_coco_file, fo
+
+    # assert prediction type
+    if no_standard_prediction and no_sliced_prediction:
+        raise ValueError("'no_standard_prediction' and 'no_sliced_prediction' cannot be True at the same time.")
+    # for profiling
+    durations_in_seconds = dict()
+
+    dataset = create_fiftyone_dataset_from_coco_file(image_dir, dataset_json_path)
+
+    # init model instance
+    time_start = time.time()
+    detection_model = AutoDetectionModel.from_pretrained(
+        model_type=model_type,
+        model_path=model_path,
+        config_path=model_config_path,
+        confidence_threshold=model_confidence_threshold,
+        device=model_device,
+        category_mapping=model_category_mapping,
+        category_remapping=model_category_remapping,
+        load_at_init=False,
+        image_size=image_size,
+    )
+    detection_model.load_model()
+    time_end = time.time() - time_start
+    durations_in_seconds["model_load"] = time_end
+
+    # iterate over source images
+    durations_in_seconds["prediction"] = 0
+    durations_in_seconds["slice"] = 0
+    # Add predictions to samples
+    with fo.ProgressBar() as pb:
+        for sample in pb(dataset):
+            # perform prediction
+            if not no_sliced_prediction:
+                # get sliced prediction
+                prediction_result = get_sliced_prediction(
+                    image=sample.filepath,
+                    detection_model=detection_model,
+                    slice_height=slice_height,
+                    slice_width=slice_width,
+                    overlap_height_ratio=overlap_height_ratio,
+                    overlap_width_ratio=overlap_width_ratio,
+                    perform_standard_pred=not no_standard_prediction,
+                    postprocess_type=postprocess_type,
+                    postprocess_match_threshold=postprocess_match_threshold,
+                    postprocess_match_metric=postprocess_match_metric,
+                    postprocess_class_agnostic=postprocess_class_agnostic,
+                    verbose=verbose,
+                )
+                durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
+            else:
+                # get standard prediction
+                prediction_result = get_prediction(
+                    image=sample.filepath,
+                    detection_model=detection_model,
+                    shift_amount=[0, 0],
+                    full_shape=None,
+                    postprocess=None,
+                    verbose=0,
+                )
+                durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"]
+
+            # Save predictions to dataset
+            sample[model_type] = fo.Detections(detections=prediction_result.to_fiftyone_detections())
+            sample.save()
+
+    # print prediction duration
+    if verbose == 1:
+        print(
+            "Model loaded in",
+            durations_in_seconds["model_load"],
+            "seconds.",
+        )
+        print(
+            "Slicing performed in",
+            durations_in_seconds["slice"],
+            "seconds.",
+        )
+        print(
+            "Prediction performed in",
+            durations_in_seconds["prediction"],
+            "seconds.",
+        )
+
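+    # launch the fiftyone app, evaluate the predictions against the "ground_truth" field,
+    # print a report for the 10 most common classes and order samples by false positives;
+    # 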
visualize results + session = fo.launch_app() + session.dataset = dataset + # Evaluate the predictions + results = dataset.evaluate_detections( + model_type, + gt_field="ground_truth", + eval_key="eval", + iou=postprocess_match_threshold, + compute_mAP=True, + ) + # Get the 10 most common classes in the dataset + counts = dataset.count_values("ground_truth.detections.label") + classes_top10 = sorted(counts, key=counts.get, reverse=True)[:10] + # Print a classification report for the top-10 classes + results.print_report(classes=classes_top10) + # Load the view on which we ran the `eval` evaluation + eval_view = dataset.load_evaluation_view("eval") + # Show samples with most false positives + session.view = eval_view.sort_by("eval_fp", reverse=True) + while 1: + time.sleep(3) diff --git a/sahi/prediction.py b/sahi/prediction.py new file mode 100644 index 0000000..5c55e09 --- /dev/null +++ b/sahi/prediction.py @@ -0,0 +1,222 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. + +import copy +from typing import Dict, List, Optional, Union + +import numpy as np +from PIL import Image + +from sahi.annotation import ObjectAnnotation +from sahi.utils.coco import CocoAnnotation, CocoPrediction +from sahi.utils.cv import read_image_as_pil, visualize_object_predictions +from sahi.utils.file import Path + + +class PredictionScore: + def __init__(self, value: float): + """ + Arguments: + score: prediction score between 0 and 1 + """ + # if score is a numpy object, convert it to python variable + if type(value).__module__ == "numpy": + value = copy.deepcopy(value).tolist() + # set score + self.value = value + + def is_greater_than_threshold(self, threshold): + """ + Check if score is greater than threshold + """ + return self.value > threshold + + def __repr__(self): + return f"PredictionScore: " + + +class ObjectPrediction(ObjectAnnotation): + """ + Class for handling detection model predictions. + """ + + def __init__( + self, + bbox: Optional[List[int]] = None, + category_id: Optional[int] = None, + category_name: Optional[str] = None, + bool_mask: Optional[np.ndarray] = None, + score: Optional[float] = 0, + shift_amount: Optional[List[int]] = [0, 0], + full_shape: Optional[List[int]] = None, + ): + """ + Creates ObjectPrediction from bbox, score, category_id, category_name, bool_mask. + + Arguments: + bbox: list + [minx, miny, maxx, maxy] + score: float + Prediction score between 0 and 1 + category_id: int + ID of the object category + category_name: str + Name of the object category + bool_mask: np.ndarray + 2D boolean mask array. Should be None if model doesn't output segmentation mask. + shift_amount: list + To shift the box and mask predictions from sliced image + to full sized image, should be in the form of [shift_x, shift_y] + full_shape: list + Size of the full image after shifting, should be in + the form of [height, width] + """ + self.score = PredictionScore(score) + super().__init__( + bbox=bbox, + category_id=category_id, + bool_mask=bool_mask, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + + def get_shifted_object_prediction(self): + """ + Returns shifted version ObjectPrediction. + Shifts bbox and mask coords. + Used for mapping sliced predictions over full image. 
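+        The returned prediction has shift_amount reset to [0, 0]; when a mask is
+        present, full_shape is taken from the shifted mask.
+
+        Example (illustrative; assumes `pred` is an ObjectPrediction produced from a slice):
+            full_image_pred = pred.get_shifted_object_prediction()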
+ """ + if self.mask: + return ObjectPrediction( + bbox=self.bbox.get_shifted_box().to_voc_bbox(), + category_id=self.category.id, + score=self.score.value, + bool_mask=self.mask.get_shifted_mask().bool_mask, + category_name=self.category.name, + shift_amount=[0, 0], + full_shape=self.mask.get_shifted_mask().full_shape, + ) + else: + return ObjectPrediction( + bbox=self.bbox.get_shifted_box().to_voc_bbox(), + category_id=self.category.id, + score=self.score.value, + bool_mask=None, + category_name=self.category.name, + shift_amount=[0, 0], + full_shape=None, + ) + + def to_coco_prediction(self, image_id=None): + """ + Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation. + """ + if self.mask: + coco_prediction = CocoPrediction.from_coco_segmentation( + segmentation=self.mask.to_coco_segmentation(), + category_id=self.category.id, + category_name=self.category.name, + score=self.score.value, + image_id=image_id, + ) + else: + coco_prediction = CocoPrediction.from_coco_bbox( + bbox=self.bbox.to_coco_bbox(), + category_id=self.category.id, + category_name=self.category.name, + score=self.score.value, + image_id=image_id, + ) + return coco_prediction + + def to_fiftyone_detection(self, image_height: int, image_width: int): + """ + Returns fiftyone.Detection representation of ObjectPrediction. + """ + try: + import fiftyone as fo + except ImportError: + raise ImportError('Please run "pip install -U fiftyone" to install fiftyone first for fiftyone conversion.') + + x1, y1, x2, y2 = self.bbox.to_voc_bbox() + rel_box = [x1 / image_width, y1 / image_height, (x2 - x1) / image_width, (y2 - y1) / image_height] + fiftyone_detection = fo.Detection(label=self.category.name, bounding_box=rel_box, confidence=self.score.value) + return fiftyone_detection + + def __repr__(self): + return f"""ObjectPrediction< + bbox: {self.bbox}, + mask: {self.mask}, + score: {self.score}, + category: {self.category}>""" + + +class PredictionResult: + def __init__( + self, + object_prediction_list: List[ObjectPrediction], + image: Union[Image.Image, str, np.ndarray], + durations_in_seconds: Optional[Dict] = None, + ): + self.image: Image.Image = read_image_as_pil(image) + self.image_width, self.image_height = self.image.size + self.object_prediction_list: List[ObjectPrediction] = object_prediction_list + self.durations_in_seconds = durations_in_seconds + + def export_visuals( + self, export_dir: str, text_size: float = None, rect_th: int = None, file_name: str = "prediction_visual" + ): + """ + + Args: + export_dir: directory for resulting visualization to be exported + text_size: size of the category name over box + rect_th: rectangle thickness + file_name: saving name + Returns: + + """ + Path(export_dir).mkdir(parents=True, exist_ok=True) + visualize_object_predictions( + image=np.ascontiguousarray(self.image), + object_prediction_list=self.object_prediction_list, + rect_th=rect_th, + text_size=text_size, + text_th=None, + color=None, + output_dir=export_dir, + file_name=file_name, + export_format="png", + ) + + def to_coco_annotations(self): + coco_annotation_list = [] + for object_prediction in self.object_prediction_list: + coco_annotation_list.append(object_prediction.to_coco_prediction().json) + return coco_annotation_list + + def to_coco_predictions(self, image_id: Optional[int] = None): + coco_prediction_list = [] + for object_prediction in self.object_prediction_list: + coco_prediction_list.append(object_prediction.to_coco_prediction(image_id=image_id).json) + return coco_prediction_list + + 
def to_imantics_annotations(self): + imantics_annotation_list = [] + for object_prediction in self.object_prediction_list: + imantics_annotation_list.append(object_prediction.to_imantics_annotation()) + return imantics_annotation_list + + def to_fiftyone_detections(self): + try: + import fiftyone as fo + except ImportError: + raise ImportError('Please run "pip install -U fiftyone" to install fiftyone first for fiftyone conversion.') + + fiftyone_detection_list: List[fo.Detection] = [] + for object_prediction in self.object_prediction_list: + fiftyone_detection_list.append( + object_prediction.to_fiftyone_detection(image_height=self.image_height, image_width=self.image_width) + ) + return fiftyone_detection_list diff --git a/sahi/scripts/__init__.py b/sahi/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sahi/scripts/coco2fiftyone.py b/sahi/scripts/coco2fiftyone.py new file mode 100644 index 0000000..56e8851 --- /dev/null +++ b/sahi/scripts/coco2fiftyone.py @@ -0,0 +1,80 @@ +import time +from pathlib import Path +from typing import List + +import fire + +from sahi.utils.file import load_json + + +def main( + image_dir: str, + dataset_json_path: str, + *result_json_paths, + iou_thresh: float = 0.5, +): + """ + Args: + image_dir (str): directory for coco images + dataset_json_path (str): file path for the coco dataset json file + result_json_paths (str): one or more paths for the coco result json file + iou_thresh (float): iou threshold for coco evaluation + """ + + from sahi.utils.fiftyone import add_coco_labels, create_fiftyone_dataset_from_coco_file, fo + + coco_result_list = [] + result_name_list = [] + if result_json_paths: + for result_json_path in result_json_paths: + coco_result = load_json(result_json_path) + coco_result_list.append(coco_result) + + # use file names as fiftyone name, create unique names if duplicate + result_name_temp = Path(result_json_path).stem + result_name = result_name_temp + name_increment = 2 + while result_name in result_name_list: + result_name = result_name_temp + "_" + str(name_increment) + name_increment += 1 + result_name_list.append(result_name) + + dataset = create_fiftyone_dataset_from_coco_file(image_dir, dataset_json_path) + + # submit detections if coco result is given + if result_json_paths: + for result_name, coco_result in zip(result_name_list, coco_result_list): + add_coco_labels(dataset, result_name, coco_result, coco_id_field="gt_coco_id") + + # visualize results + session = fo.launch_app() + session.dataset = dataset + + # order by false positives if any coco result is given + if result_json_paths: + # Evaluate the predictions + first_coco_result_name = result_name_list[0] + _ = dataset.evaluate_detections( + first_coco_result_name, + gt_field="gt_detections", + eval_key=f"{first_coco_result_name}_eval", + iou=iou_thresh, + compute_mAP=False, + ) + # Get the 10 most common classes in the dataset + # counts = dataset.count_values("gt_detections.detections.label") + # classes_top10 = sorted(counts, key=counts.get, reverse=True)[:10] + # Print a classification report for the top-10 classes + # results.print_report(classes=classes_top10) + # Load the view on which we ran the `eval` evaluation + eval_view = dataset.load_evaluation_view(f"{first_coco_result_name}_eval") + # Show samples with most false positives + session.view = eval_view.sort_by(f"{first_coco_result_name}_eval_fp", reverse=True) + + print("SAHI has successfully launched a Fiftyone app " f"at http://localhost:{fo.config.default_app_port}") + while 1: + 
time.sleep(3) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/sahi/scripts/coco2yolov5.py b/sahi/scripts/coco2yolov5.py new file mode 100644 index 0000000..5c410c7 --- /dev/null +++ b/sahi/scripts/coco2yolov5.py @@ -0,0 +1,43 @@ +import fire + +from sahi.utils.coco import Coco +from sahi.utils.file import Path, increment_path + + +def main( + image_dir: str, + dataset_json_path: str, + train_split: str = 0.9, + project: str = "runs/coco2yolov5", + name: str = "exp", + seed: str = 1, +): + """ + Args: + images_dir (str): directory for coco images + dataset_json_path (str): file path for the coco json file to be converted + train_split (str): set the training split ratio + project (str): save results to project/name + name (str): save results to project/name" + seed (int): fix the seed for reproducibility + """ + + # increment run + save_dir = Path(increment_path(Path(project) / name, exist_ok=False)) + # load coco dict + coco = Coco.from_coco_dict_or_path( + coco_dict_or_path=dataset_json_path, + image_dir=image_dir, + ) + # export as yolov5 + coco.export_as_yolov5( + output_dir=str(save_dir), + train_split_rate=train_split, + numpy_seed=seed, + ) + + print(f"COCO to YOLOv5 conversion results are successfully exported to {save_dir}") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/sahi/scripts/coco_error_analysis.py b/sahi/scripts/coco_error_analysis.py new file mode 100644 index 0000000..d71668c --- /dev/null +++ b/sahi/scripts/coco_error_analysis.py @@ -0,0 +1,468 @@ +import copy +import os +from multiprocessing import Pool +from pathlib import Path +from typing import List + +import fire +import numpy as np + +COLOR_PALETTE = np.vstack( + [ + np.array([0.8, 0.8, 0.8]), + np.array([0.6, 0.6, 0.6]), + np.array([0.31, 0.51, 0.74]), + np.array([0.75, 0.31, 0.30]), + np.array([0.36, 0.90, 0.38]), + np.array([0.50, 0.39, 0.64]), + np.array([1, 0.6, 0]), + ] +) + + +def _makeplot(rs, ps, outDir, class_name, iou_type): + import matplotlib.pyplot as plt + + export_path_list = [] + + areaNames = ["allarea", "small", "medium", "large"] + types = ["C75", "C50", "Loc", "Sim", "Oth", "BG", "FN"] + for i in range(len(areaNames)): + area_ps = ps[..., i, 0] + figure_title = iou_type + "-" + class_name + "-" + areaNames[i] + aps = [] + ps_curve = [] + for ps_ in area_ps: + # calculate precision recal curves + if ps_.ndim > 1: + ps_mean = np.zeros((ps_.shape[0],)) + for ind, ps_threshold in enumerate(ps_): + ps_mean[ind] = ps_threshold[ps_threshold > -1].mean() + ps_curve.append(ps_mean) + else: + ps_curve.append(ps_) + # calculate ap + if len(ps_[ps_ > -1]): + ap = ps_[ps_ > -1].mean() + else: + ap = np.array(0) + aps.append(ap) + ps_curve.insert(0, np.zeros(ps_curve[0].shape)) + fig = plt.figure() + ax = plt.subplot(111) + for k in range(len(types)): + ax.plot(rs, ps_curve[k + 1], color=[0, 0, 0], linewidth=0.5) + ax.fill_between( + rs, + ps_curve[k], + ps_curve[k + 1], + color=COLOR_PALETTE[k], + label=str(f"[{aps[k]:.3f}]" + types[k]), + ) + plt.xlabel("recall") + plt.ylabel("precision") + plt.xlim(0, 1.0) + plt.ylim(0, 1.0) + plt.title(figure_title) + plt.legend() + # plt.show() + export_path = str(Path(outDir) / f"{figure_title}.png") + fig.savefig(export_path) + plt.close(fig) + + export_path_list.append(export_path) + return export_path_list + + +def _autolabel(ax, rects, is_percent=True): + """Attach a text label above each bar in *rects*, displaying its height.""" + for rect in rects: + height = rect.get_height() + if is_percent and height > 0 and height <= 1: # for 
percent values + text_label = "{:2.0f}".format(height * 100) + else: + text_label = "{:2.0f}".format(height) + ax.annotate( + text_label, + xy=(rect.get_x() + rect.get_width() / 2, height), + xytext=(0, 3), # 3 points vertical offset + textcoords="offset points", + ha="center", + va="bottom", + fontsize="x-small", + ) + + +def _makebarplot(rs, ps, outDir, class_name, iou_type): + import matplotlib.pyplot as plt + + areaNames = ["allarea", "small", "medium", "large"] + types = ["C75", "C50", "Loc", "Sim", "Oth", "BG", "FN"] + fig, ax = plt.subplots() + x = np.arange(len(areaNames)) # the areaNames locations + width = 0.60 # the width of the bars + rects_list = [] + figure_title = iou_type + "-" + class_name + "-" + "ap bar plot" + for k in range(len(types) - 1): + type_ps = ps[k, ..., 0] + # calculate ap + aps = [] + for ps_ in type_ps.T: + if len(ps_[ps_ > -1]): + ap = ps_[ps_ > -1].mean() + else: + ap = np.array(0) + aps.append(ap) + # create bars + rects_list.append( + ax.bar( + x - width / 2 + (k + 1) * width / len(types), + aps, + width / len(types), + label=types[k], + color=COLOR_PALETTE[k], + ) + ) + + # Add some text for labels, title and custom x-axis tick labels, etc. + ax.set_ylabel("Mean Average Precision (mAP)") + ax.set_title(figure_title) + ax.set_xticks(x) + ax.set_xticklabels(areaNames) + ax.legend() + + # Add score texts over bars + for rects in rects_list: + _autolabel(ax, rects) + + # Save plot + export_path = str(Path(outDir) / f"{figure_title}.png") + fig.savefig(export_path) + plt.close(fig) + + return export_path + + +def _get_gt_area_group_numbers(cocoEval): + areaRng = cocoEval.params.areaRng + areaRngStr = [str(aRng) for aRng in areaRng] + areaRngLbl = cocoEval.params.areaRngLbl + areaRngStr2areaRngLbl = dict(zip(areaRngStr, areaRngLbl)) + areaRngLbl2Number = dict.fromkeys(areaRngLbl, 0) + for evalImg in cocoEval.evalImgs: + if evalImg: + for gtIgnore in evalImg["gtIgnore"]: + if not gtIgnore: + aRngLbl = areaRngStr2areaRngLbl[str(evalImg["aRng"])] + areaRngLbl2Number[aRngLbl] += 1 + return areaRngLbl2Number + + +def _make_gt_area_group_numbers_plot(cocoEval, outDir, verbose=True): + import matplotlib.pyplot as plt + + areaRngLbl2Number = _get_gt_area_group_numbers(cocoEval) + areaRngLbl = areaRngLbl2Number.keys() + if verbose: + print("number of annotations per area group:", areaRngLbl2Number) + + # Init figure + fig, ax = plt.subplots() + x = np.arange(len(areaRngLbl)) # the areaNames locations + width = 0.60 # the width of the bars + figure_title = "number of annotations per area group" + + rects = ax.bar(x, areaRngLbl2Number.values(), width) + + # Add some text for labels, title and custom x-axis tick labels, etc. + ax.set_ylabel("Number of annotations") + ax.set_title(figure_title) + ax.set_xticks(x) + ax.set_xticklabels(areaRngLbl) + + # Add score texts over bars + _autolabel(ax, rects, is_percent=False) + + # Save plot + export_path = str(Path(outDir) / f"{figure_title}.png") + fig.tight_layout() + fig.savefig(export_path) + plt.close(fig) + + return export_path + + +def _make_gt_area_histogram_plot(cocoEval, outDir): + import matplotlib.pyplot as plt + + n_bins = 100 + areas = [ann["area"] for ann in cocoEval.cocoGt.anns.values()] + + # init figure + figure_title = "gt annotation areas histogram plot" + fig, ax = plt.subplots() + + # Set the number of bins + ax.hist(np.sqrt(areas), bins=n_bins) + + # Add some text for labels, title and custom x-axis tick labels, etc. 
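+    # note: the histogram is over sqrt(area), so the x-axis reads in pixels rather than
+    # pixels^2 and COCO-style small/medium boundaries (32 and 96) can be read off directly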
+ ax.set_xlabel("Squareroot Area") + ax.set_ylabel("Number of annotations") + ax.set_title(figure_title) + + # Save plot + export_path = str(Path(outDir) / f"{figure_title}.png") + fig.tight_layout() + fig.savefig(export_path) + plt.close(fig) + + return export_path + + +def _analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type, areas=None, max_detections=None, COCOeval=None): + nm = cocoGt.loadCats(catId)[0] + print(f'--------------analyzing {k + 1}-{nm["name"]}---------------') + ps_ = {} + dt = copy.deepcopy(cocoDt) + nm = cocoGt.loadCats(catId)[0] + imgIds = cocoGt.getImgIds() + dt_anns = dt.dataset["annotations"] + select_dt_anns = [] + for ann in dt_anns: + if ann["category_id"] == catId: + select_dt_anns.append(ann) + dt.dataset["annotations"] = select_dt_anns + dt.createIndex() + # compute precision but ignore superclass confusion + gt = copy.deepcopy(cocoGt) + child_catIds = gt.getCatIds(supNms=[nm["supercategory"]]) + for idx, ann in enumerate(gt.dataset["annotations"]): + if ann["category_id"] in child_catIds and ann["category_id"] != catId: + gt.dataset["annotations"][idx]["ignore"] = 1 + gt.dataset["annotations"][idx]["iscrowd"] = 1 + gt.dataset["annotations"][idx]["category_id"] = catId + cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type) + cocoEval.params.imgIds = imgIds + cocoEval.params.maxDets = [max_detections] + cocoEval.params.iouThrs = [0.1] + cocoEval.params.useCats = 1 + if areas: + cocoEval.params.areaRng = [ + [0 ** 2, areas[2]], + [0 ** 2, areas[0]], + [areas[0], areas[1]], + [areas[1], areas[2]], + ] + cocoEval.evaluate() + cocoEval.accumulate() + ps_supercategory = cocoEval.eval["precision"][0, :, catId, :, :] + ps_["ps_supercategory"] = ps_supercategory + # compute precision but ignore any class confusion + gt = copy.deepcopy(cocoGt) + for idx, ann in enumerate(gt.dataset["annotations"]): + if ann["category_id"] != catId: + gt.dataset["annotations"][idx]["ignore"] = 1 + gt.dataset["annotations"][idx]["iscrowd"] = 1 + gt.dataset["annotations"][idx]["category_id"] = catId + cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type) + cocoEval.params.imgIds = imgIds + cocoEval.params.maxDets = [max_detections] + cocoEval.params.iouThrs = [0.1] + cocoEval.params.useCats = 1 + if areas: + cocoEval.params.areaRng = [ + [0 ** 2, areas[2]], + [0 ** 2, areas[0]], + [areas[0], areas[1]], + [areas[1], areas[2]], + ] + cocoEval.evaluate() + cocoEval.accumulate() + ps_allcategory = cocoEval.eval["precision"][0, :, catId, :, :] + ps_["ps_allcategory"] = ps_allcategory + return k, ps_ + + +def _analyse_results( + res_file, + ann_file, + res_types, + out_dir=None, + extraplots=None, + areas=None, + max_detections=500, + COCO=None, + COCOeval=None, +): + for res_type in res_types: + if res_type not in ["bbox", "segm"]: + raise ValueError(f"res_type {res_type} is not supported") + if areas is not None: + if len(areas) != 3: + raise ValueError("3 integers should be specified as areas,representing 3 area regions") + + if out_dir is None: + out_dir = Path(res_file).parent + out_dir = str(out_dir / "coco_error_analysis") + + directory = os.path.dirname(out_dir + "/") + if not os.path.exists(directory): + print(f"-------------create {out_dir}-----------------") + os.makedirs(directory) + + result_type_to_export_paths = {} + + cocoGt = COCO(ann_file) + cocoDt = cocoGt.loadRes(res_file) + imgIds = cocoGt.getImgIds() + for res_type in res_types: + res_out_dir = out_dir + "/" + res_type + "/" + res_directory = os.path.dirname(res_out_dir) + if not os.path.exists(res_directory): + 
print(f"-------------create {res_out_dir}-----------------") + os.makedirs(res_directory) + iou_type = res_type + cocoEval = COCOeval(copy.deepcopy(cocoGt), copy.deepcopy(cocoDt), iou_type) + cocoEval.params.imgIds = imgIds + cocoEval.params.iouThrs = [0.75, 0.5, 0.1] + cocoEval.params.maxDets = [max_detections] + if areas is not None: + cocoEval.params.areaRng = [ + [0 ** 2, areas[2]], + [0 ** 2, areas[0]], + [areas[0], areas[1]], + [areas[1], areas[2]], + ] + cocoEval.evaluate() + cocoEval.accumulate() + + present_cat_ids = [] + catIds = cocoGt.getCatIds() + for k, catId in enumerate(catIds): + image_ids = cocoGt.getImgIds(catIds=[catId]) + if len(image_ids) != 0: + present_cat_ids.append(catId) + matrix_shape = list(cocoEval.eval["precision"].shape) + matrix_shape[2] = len(present_cat_ids) + ps = np.zeros(matrix_shape) + + for k, catId in enumerate(present_cat_ids): + ps[:, :, k, :, :] = cocoEval.eval["precision"][:, :, catId, :, :] + ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))]) + + recThrs = cocoEval.params.recThrs + with Pool(processes=48) as pool: + args = [ + (k, cocoDt, cocoGt, catId, iou_type, areas, max_detections, COCOeval) + for k, catId in enumerate(present_cat_ids) + ] + analyze_results = pool.starmap(_analyze_individual_category, args) + + classname_to_export_path_list = {} + for k, catId in enumerate(present_cat_ids): + + nm = cocoGt.loadCats(catId)[0] + print(f'--------------saving {k + 1}-{nm["name"]}---------------') + analyze_result = analyze_results[k] + if k != analyze_result[0]: + raise ValueError(f"k {k} != analyze_result[0] {analyze_result[0]}") + ps_supercategory = analyze_result[1]["ps_supercategory"] + ps_allcategory = analyze_result[1]["ps_allcategory"] + # compute precision but ignore superclass confusion + ps[3, :, k, :, :] = ps_supercategory + # compute precision but ignore any class confusion + ps[4, :, k, :, :] = ps_allcategory + # fill in background and false negative errors and plot + ps[5, :, k, :, :][ps[4, :, k, :, :] == -1] = -1 + ps[5, :, k, :, :][ps[4, :, k, :, :] > 0] = 1 + ps[6, :, k, :, :] = 1.0 + + normalized_class_name = nm["name"].replace("/", "_").replace(os.sep, "_") + + curve_export_path_list = _makeplot(recThrs, ps[:, :, k], res_out_dir, normalized_class_name, iou_type) + + if extraplots: + bar_plot_path = _makebarplot(recThrs, ps[:, :, k], res_out_dir, normalized_class_name, iou_type) + else: + bar_plot_path = None + classname_to_export_path_list[nm["name"]] = { + "curves": curve_export_path_list, + "bar_plot": bar_plot_path, + } + + curve_export_path_list = _makeplot(recThrs, ps, res_out_dir, "allclass", iou_type) + if extraplots: + bar_plot_path = _makebarplot(recThrs, ps, res_out_dir, "allclass", iou_type) + gt_area_group_numbers_plot_path = _make_gt_area_group_numbers_plot( + cocoEval=cocoEval, outDir=res_out_dir, verbose=True + ) + gt_area_histogram_plot_path = _make_gt_area_histogram_plot(cocoEval=cocoEval, outDir=res_out_dir) + else: + bar_plot_path, gt_area_group_numbers_plot_path, gt_area_histogram_plot_path = None, None, None + + result_type_to_export_paths[res_type] = { + "classwise": classname_to_export_path_list, + "overall": { + "bar_plot": bar_plot_path, + "curves": curve_export_path_list, + "gt_area_group_numbers": gt_area_group_numbers_plot_path, + "gt_area_histogram": gt_area_histogram_plot_path, + }, + } + print(f"COCO error analysis results are successfully exported to {out_dir}") + + return result_type_to_export_paths + + +def analyse( + dataset_json_path: str, + result_json_path: str, + out_dir: str = None, + 
type: str = "bbox", + no_extraplots: bool = False, + areas: List[int] = [1024, 9216, 10000000000], + max_detections: int = 500, + return_dict: bool = False, +): + """ + Args: + dataset_json_path (str): file path for the coco dataset json file + result_json_paths (str): file path for the coco result json file + out_dir (str): dir to save analyse result images + no_extraplots (bool): dont export export extra bar/stat plots + type (str): 'bbox' or 'mask' + areas (List[int]): area regions for coco evaluation calculations + max_detections (int): Maximum number of detections to consider for AP alculation. Default: 500 + return_dict (bool): If True, returns a dict export paths. + """ + try: + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + except ModuleNotFoundError: + raise ModuleNotFoundError( + 'Please run "pip install -U pycocotools" ' "to install pycocotools first for coco evaluation." + ) + try: + import matplotlib.pyplot as plt + except ModuleNotFoundError: + raise ModuleNotFoundError( + 'Please run "pip install -U matplotlib" ' "to install matplotlib first for visualization." + ) + + result = _analyse_results( + result_json_path, + dataset_json_path, + res_types=[type], + out_dir=out_dir, + extraplots=not no_extraplots, + areas=areas, + max_detections=max_detections, + COCO=COCO, + COCOeval=COCOeval, + ) + if return_dict: + return result + + +if __name__ == "__main__": + fire.Fire(analyse) diff --git a/sahi/scripts/coco_evaluation.py b/sahi/scripts/coco_evaluation.py new file mode 100644 index 0000000..ff85008 --- /dev/null +++ b/sahi/scripts/coco_evaluation.py @@ -0,0 +1,395 @@ +import itertools +import json +import warnings +from collections import OrderedDict +from pathlib import Path +from typing import List, Union + +import fire +import numpy as np +from terminaltables import AsciiTable + + +def _cocoeval_summarize( + cocoeval, ap=1, iouThr=None, catIdx=None, areaRng="all", maxDets=100, catName="", nameStrLen=None +): + p = cocoeval.params + if catName: + iStr = " {:<18} {} {:<{nameStrLen}} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" + nameStr = catName + else: + iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" + titleStr = "Average Precision" if ap == 1 else "Average Recall" + typeStr = "(AP)" if ap == 1 else "(AR)" + iouStr = "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1]) if iouThr is None else "{:0.2f}".format(iouThr) + + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + if ap == 1: + # dimension of precision: [TxRxKxAxM] + s = cocoeval.eval["precision"] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + if catIdx is not None: + s = s[:, :, catIdx, aind, mind] + else: + s = s[:, :, :, aind, mind] + else: + # dimension of recall: [TxKxAxM] + s = cocoeval.eval["recall"] + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + if catIdx is not None: + s = s[:, catIdx, aind, mind] + else: + s = s[:, :, aind, mind] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + if catName: + print(iStr.format(titleStr, typeStr, nameStr, iouStr, areaRng, maxDets, mean_s, nameStrLen=nameStrLen)) + else: + print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)) + return mean_s + + +def evaluate_core( + dataset_path, + result_path, + metric: str = "bbox", + classwise: bool = False, + max_detections: int = 500, + iou_thrs=None, + 
metric_items=None, + out_dir: str = None, + areas: List[int] = [1024, 9216, 10000000000], + COCO=None, + COCOeval=None, +): + """Evaluation in COCO protocol. + Args: + dataset_path (str): COCO dataset json path. + result_path (str): COCO result json path. + metric (str | list[str]): Metrics to be evaluated. Options are + 'bbox', 'segm', 'proposal'. + classwise (bool): Whether to evaluating the AP for each class. + max_detections (int): Maximum number of detections to consider for AP + calculation. + Default: 500 + iou_thrs (List[float], optional): IoU threshold used for + evaluating recalls/mAPs. If set to a list, the average of all + IoUs will also be computed. If not specified, [0.50, 0.55, + 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used. + Default: None. + metric_items (list[str] | str, optional): Metric items that will + be returned. If not specified, ``['AR@10', 'AR@100', + 'AR@500', 'AR_s@500', 'AR_m@500', 'AR_l@500' ]`` will be + used when ``metric=='proposal'``, ``['mAP', 'mAP50', 'mAP75', + 'mAP_s', 'mAP_m', 'mAP_l', 'mAP50_s', 'mAP50_m', 'mAP50_l']`` + will be used when ``metric=='bbox' or metric=='segm'``. + out_dir (str): Directory to save evaluation result json. + areas (List[int]): area regions for coco evaluation calculations + Returns: + dict: + eval_results (dict[str, float]): COCO style evaluation metric. + export_path (str): Path for the exported eval result json. + + """ + + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ["bbox", "segm"] + for metric in metrics: + if metric not in allowed_metrics: + raise KeyError(f"metric {metric} is not supported") + if iou_thrs is None: + iou_thrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True) + if metric_items is not None: + if not isinstance(metric_items, list): + metric_items = [metric_items] + if areas is not None: + if len(areas) != 3: + raise ValueError("3 integers should be specified as areas, representing 3 area regions") + eval_results = OrderedDict() + cocoGt = COCO(dataset_path) + cat_ids = list(cocoGt.cats.keys()) + for metric in metrics: + msg = f"Evaluating {metric}..." + msg = "\n" + msg + print(msg) + + iou_type = metric + with open(result_path) as json_file: + results = json.load(json_file) + try: + if iou_type == "segm": + # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa + # When evaluating mask AP, if the results contain bbox, + # cocoapi will use the box area instead of the mask area + # for calculating the instance area. Though the overall AP + # is not affected, this leads to different + # small/medium/large mask AP results. + for x in results: + x.pop("bbox") + warnings.simplefilter("once") + warnings.warn( + 'The key "bbox" is deleted for more accurate mask AP ' + "of small/medium/large instances since v2.12.0. 
This " + "does not change the overall mAP calculation.", + UserWarning, + ) + cocoDt = cocoGt.loadRes(results) + except IndexError: + print("The testing results of the whole dataset is empty.") + break + + cocoEval = COCOeval(cocoGt, cocoDt, iou_type) + if areas is not None: + cocoEval.params.areaRng = [ + [0 ** 2, areas[2]], + [0 ** 2, areas[0]], + [areas[0], areas[1]], + [areas[1], areas[2]], + ] + cocoEval.params.catIds = cat_ids + cocoEval.params.maxDets = [max_detections] + cocoEval.params.iouThrs = ( + [iou_thrs] if not isinstance(iou_thrs, list) and not isinstance(iou_thrs, np.ndarray) else iou_thrs + ) + # mapping of cocoEval.stats + coco_metric_names = { + "mAP": 0, + "mAP75": 1, + "mAP50": 2, + "mAP_s": 3, + "mAP_m": 4, + "mAP_l": 5, + "mAP50_s": 6, + "mAP50_m": 7, + "mAP50_l": 8, + "AR_s": 9, + "AR_m": 10, + "AR_l": 11, + } + if metric_items is not None: + for metric_item in metric_items: + if metric_item not in coco_metric_names: + raise KeyError(f"metric item {metric_item} is not supported") + + cocoEval.evaluate() + cocoEval.accumulate() + # calculate mAP50_s/m/l + mAP = _cocoeval_summarize(cocoEval, ap=1, iouThr=None, areaRng="all", maxDets=max_detections) + mAP50 = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.5, areaRng="all", maxDets=max_detections) + mAP75 = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.75, areaRng="all", maxDets=max_detections) + mAP50_s = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.5, areaRng="small", maxDets=max_detections) + mAP50_m = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.5, areaRng="medium", maxDets=max_detections) + mAP50_l = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.5, areaRng="large", maxDets=max_detections) + mAP_s = _cocoeval_summarize(cocoEval, ap=1, iouThr=None, areaRng="small", maxDets=max_detections) + mAP_m = _cocoeval_summarize(cocoEval, ap=1, iouThr=None, areaRng="medium", maxDets=max_detections) + mAP_l = _cocoeval_summarize(cocoEval, ap=1, iouThr=None, areaRng="large", maxDets=max_detections) + AR_s = _cocoeval_summarize(cocoEval, ap=0, iouThr=None, areaRng="small", maxDets=max_detections) + AR_m = _cocoeval_summarize(cocoEval, ap=0, iouThr=None, areaRng="medium", maxDets=max_detections) + AR_l = _cocoeval_summarize(cocoEval, ap=0, iouThr=None, areaRng="large", maxDets=max_detections) + cocoEval.stats = np.append( + [mAP, mAP75, mAP50, mAP_s, mAP_m, mAP_l, mAP50_s, mAP50_m, mAP50_l, AR_s, AR_m, AR_l], 0 + ) + + if classwise: # Compute per-category AP + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/ + precisions = cocoEval.eval["precision"] + # precision: (iou, recall, cls, area range, max dets) + if len(cat_ids) != precisions.shape[2]: + raise ValueError( + f"The number of categories {len(cat_ids)} is not equal to the number of precisions {precisions.shape[2]}" + ) + max_cat_name_len = 0 + for idx, catId in enumerate(cat_ids): + nm = cocoGt.loadCats(catId)[0] + cat_name_len = len(nm["name"]) + max_cat_name_len = cat_name_len if cat_name_len > max_cat_name_len else max_cat_name_len + + results_per_category = [] + for idx, catId in enumerate(cat_ids): + # skip if no image with this category + image_ids = cocoGt.getImgIds(catIds=[catId]) + if len(image_ids) == 0: + continue + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = cocoGt.loadCats(catId)[0] + ap = _cocoeval_summarize( + cocoEval, + ap=1, + catIdx=idx, + areaRng="all", + maxDets=max_detections, + catName=nm["name"], + nameStrLen=max_cat_name_len, + ) + ap_s = _cocoeval_summarize( + cocoEval, + 
ap=1, + catIdx=idx, + areaRng="small", + maxDets=max_detections, + catName=nm["name"], + nameStrLen=max_cat_name_len, + ) + ap_m = _cocoeval_summarize( + cocoEval, + ap=1, + catIdx=idx, + areaRng="medium", + maxDets=max_detections, + catName=nm["name"], + nameStrLen=max_cat_name_len, + ) + ap_l = _cocoeval_summarize( + cocoEval, + ap=1, + catIdx=idx, + areaRng="large", + maxDets=max_detections, + catName=nm["name"], + nameStrLen=max_cat_name_len, + ) + ap50 = _cocoeval_summarize( + cocoEval, + ap=1, + iouThr=0.5, + catIdx=idx, + areaRng="all", + maxDets=max_detections, + catName=nm["name"], + nameStrLen=max_cat_name_len, + ) + ap50_s = _cocoeval_summarize( + cocoEval, + ap=1, + iouThr=0.5, + catIdx=idx, + areaRng="small", + maxDets=max_detections, + catName=nm["name"], + nameStrLen=max_cat_name_len, + ) + ap50_m = _cocoeval_summarize( + cocoEval, + ap=1, + iouThr=0.5, + catIdx=idx, + areaRng="medium", + maxDets=max_detections, + catName=nm["name"], + nameStrLen=max_cat_name_len, + ) + ap50_l = _cocoeval_summarize( + cocoEval, + ap=1, + iouThr=0.5, + catIdx=idx, + areaRng="large", + maxDets=max_detections, + catName=nm["name"], + nameStrLen=max_cat_name_len, + ) + results_per_category.append((f'{metric}_{nm["name"]}_mAP', f"{float(ap):0.3f}")) + results_per_category.append((f'{metric}_{nm["name"]}_mAP_s', f"{float(ap_s):0.3f}")) + results_per_category.append((f'{metric}_{nm["name"]}_mAP_m', f"{float(ap_m):0.3f}")) + results_per_category.append((f'{metric}_{nm["name"]}_mAP_l', f"{float(ap_l):0.3f}")) + results_per_category.append((f'{metric}_{nm["name"]}_mAP50', f"{float(ap50):0.3f}")) + results_per_category.append((f'{metric}_{nm["name"]}_mAP50_s', f"{float(ap50_s):0.3f}")) + results_per_category.append((f'{metric}_{nm["name"]}_mAP50_m', f"{float(ap50_m):0.3f}")) + results_per_category.append((f'{metric}_{nm["name"]}_mAP50_l', f"{float(ap50_l):0.3f}")) + + num_columns = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + headers = ["category", "AP"] * (num_columns // 2) + results_2d = itertools.zip_longest(*[results_flatten[i::num_columns] for i in range(num_columns)]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + print("\n" + table.table) + + if metric_items is None: + metric_items = ["mAP", "mAP50", "mAP75", "mAP_s", "mAP_m", "mAP_l", "mAP50_s", "mAP50_m", "mAP50_l"] + + for metric_item in metric_items: + key = f"{metric}_{metric_item}" + val = float(f"{cocoEval.stats[coco_metric_names[metric_item]]:.3f}") + eval_results[key] = val + ap = cocoEval.stats + eval_results[f"{metric}_mAP_copypaste"] = ( + f"{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} " + f"{ap[4]:.3f} {ap[5]:.3f} {ap[6]:.3f} {ap[7]:.3f} " + f"{ap[8]:.3f}" + ) + if classwise: + eval_results["results_per_category"] = {key: value for key, value in results_per_category} + # set save path + if not out_dir: + out_dir = Path(result_path).parent + Path(out_dir).mkdir(parents=True, exist_ok=True) + export_path = str(Path(out_dir) / "eval.json") + # export as json + with open(export_path, "w", encoding="utf-8") as outfile: + json.dump(eval_results, outfile, indent=4, separators=(",", ":")) + print(f"COCO evaluation results are successfully exported to {export_path}") + return {"eval_results": eval_results, "export_path": export_path} + + +def evaluate( + dataset_json_path: str, + result_json_path: str, + out_dir: str = None, + type: str = "bbox", + classwise: bool = False, + max_detections: int = 500, + iou_thrs: 
Union[List[float], float] = None,
+    areas: List[int] = [1024, 9216, 10000000000],
+    return_dict: bool = False,
+):
+    """
+    Args:
+        dataset_json_path (str): file path for the coco dataset json file
+        result_json_path (str): file path for the coco result json file
+        out_dir (str): dir to save eval result
+        type (str): 'bbox' or 'segm'
+        classwise (bool): whether to evaluate the AP for each class
+        max_detections (int): Maximum number of detections to consider for AP calculation. Default: 500
+        iou_thrs (float or List[float]): IoU threshold(s) used for evaluating recalls/mAPs
+        areas (List[int]): area regions for coco evaluation calculations
+        return_dict (bool): If True, returns a dict with 'eval_results' and 'export_path' fields.
+    """
+    try:
+        from pycocotools.coco import COCO
+        from pycocotools.cocoeval import COCOeval
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(
+            'Please run "pip install -U pycocotools" ' "to install pycocotools first for coco evaluation."
+        )
+
+    # perform coco eval
+    result = evaluate_core(
+        dataset_json_path,
+        result_json_path,
+        type,
+        classwise,
+        max_detections,
+        iou_thrs,
+        out_dir=out_dir,
+        areas=areas,
+        COCO=COCO,
+        COCOeval=COCOeval,
+    )
+    if return_dict:
+        return result
+
+
+if __name__ == "__main__":
+    fire.Fire(evaluate)
diff --git a/sahi/scripts/predict.py b/sahi/scripts/predict.py
new file mode 100644
index 0000000..f693817
--- /dev/null
+++ b/sahi/scripts/predict.py
@@ -0,0 +1,853 @@
+# OBSS SAHI Tool
+# Code written by Fatih C Akyon, 2020.
+
+import logging
+import os
+import time
+import warnings
+from typing import Dict, List, Optional
+
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+from sahi.auto_model import AutoDetectionModel
+from sahi.model import DetectionModel
+from sahi.postprocess.combine import (
+    GreedyNMMPostprocess,
+    LSNMSPostprocess,
+    NMMPostprocess,
+    NMSPostprocess,
+    PostprocessPredictions,
+)
+from sahi.prediction import ObjectPrediction, PredictionResult
+from sahi.slicing import slice_image
+from sahi.utils.coco import Coco, CocoImage
+from sahi.utils.cv import (
+    IMAGE_EXTENSIONS,
+    VIDEO_EXTENSIONS,
+    crop_object_predictions,
+    cv2,
+    get_video_reader,
+    read_image_as_pil,
+    visualize_object_predictions,
+)
+from sahi.utils.file import Path, increment_path, list_files, save_json, save_pickle
+from sahi.utils.import_utils import check_requirements
+
+POSTPROCESS_NAME_TO_CLASS = {
+    "GREEDYNMM": GreedyNMMPostprocess,
+    "NMM": NMMPostprocess,
+    "NMS": NMSPostprocess,
+    "LSNMS": LSNMSPostprocess,
+}
+
+LOW_MODEL_CONFIDENCE = 0.1
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_prediction(
+    image,
+    detection_model,
+    image_size: int = None,
+    shift_amount: list = [0, 0],
+    full_shape=None,
+    postprocess: Optional[PostprocessPredictions] = None,
+    verbose: int = 0,
+) -> PredictionResult:
+    """
+    Function for performing prediction for given image using given detection_model.
+
+    Arguments:
+        image: str or np.ndarray
+            Location of image or numpy image matrix to slice
+        detection_model: model.DetectionModel
+        image_size: int
+            Inference input size.
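+            Deprecated: set image_size when constructing the DetectionModel instead.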
+        shift_amount: List
+            To shift the box and mask predictions from sliced image to full
+            sized image, should be in the form of [shift_x, shift_y]
+        full_shape: List
+            Size of the full image, should be in the form of [height, width]
+        postprocess: sahi.postprocess.combine.PostprocessPredictions
+        verbose: int
+            0: no print (default)
+            1: print prediction duration
+
+    Returns:
+        A PredictionResult with fields:
+            object_prediction_list: a list of ObjectPrediction
+            durations_in_seconds: a dict containing elapsed times for profiling
+    """
+    if image_size is not None:
+        warnings.warn("Set 'image_size' at DetectionModel init.", DeprecationWarning)
+
+    durations_in_seconds = dict()
+
+    # read image as pil
+    image_as_pil = read_image_as_pil(image)
+    # get prediction
+    time_start = time.time()
+    detection_model.perform_inference(np.ascontiguousarray(image_as_pil), image_size=image_size)
+    time_end = time.time() - time_start
+    durations_in_seconds["prediction"] = time_end
+
+    # process prediction
+    time_start = time.time()
+    # works only with 1 batch
+    detection_model.convert_original_predictions(
+        shift_amount=shift_amount,
+        full_shape=full_shape,
+    )
+    object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
+
+    # postprocess matching predictions
+    if postprocess is not None:
+        object_prediction_list = postprocess(object_prediction_list)
+
+    time_end = time.time() - time_start
+    durations_in_seconds["postprocess"] = time_end
+
+    if verbose == 1:
+        print(
+            "Prediction performed in",
+            durations_in_seconds["prediction"],
+            "seconds.",
+        )
+
+    return PredictionResult(
+        image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
+    )
+
+
+def get_sliced_prediction(
+    image,
+    detection_model=None,
+    image_size: int = None,
+    slice_height: int = 512,
+    slice_width: int = 512,
+    overlap_height_ratio: float = 0.2,
+    overlap_width_ratio: float = 0.2,
+    perform_standard_pred: bool = True,
+    postprocess_type: str = "GREEDYNMM",
+    postprocess_match_metric: str = "IOS",
+    postprocess_match_threshold: float = 0.5,
+    postprocess_class_agnostic: bool = False,
+    verbose: int = 1,
+    merge_buffer_length: int = None,
+) -> PredictionResult:
+    """
+    Function for slice image + get prediction for each slice + combine predictions in full image.
+
+    Args:
+        image: str or np.ndarray
+            Location of image or numpy image matrix to slice
+        detection_model: model.DetectionModel
+        image_size: int
+            Input image size for each inference (image is scaled while preserving its aspect ratio).
+        slice_height: int
+            Height of each slice. Defaults to ``512``.
+        slice_width: int
+            Width of each slice. Defaults to ``512``.
+        overlap_height_ratio: float
+            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
+            of size 512 yields an overlap of 102 pixels).
+            Defaults to ``0.2``.
+        overlap_width_ratio: float
+            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
+            of size 512 yields an overlap of 102 pixels).
+            Defaults to ``0.2``.
+        perform_standard_pred: bool
+            Perform a standard prediction on top of sliced predictions to increase large object
+            detection accuracy. Default: True.
+        postprocess_type: str
+            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
+            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
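+            'NMS' keeps only the highest scoring one of the matching predictions, while
+            'NMM' and 'GREEDYNMM' merge matching predictions into a single prediction.
+        postprocess_match_metric: str
+            Metric to be used during object prediction matching after sliced prediction.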
+ 'IOU' for intersection over union, 'IOS' for intersection over smaller area. + postprocess_match_threshold: float + Sliced predictions having higher iou than postprocess_match_threshold will be + postprocessed after sliced prediction. + postprocess_class_agnostic: bool + If True, postprocess will ignore category ids. + verbose: int + 0: no print + 1: print number of slices (default) + 2: print number of slices and slice/prediction durations + merge_buffer_length: int + The length of buffer for slices to be used during sliced prediction, which is suitable for low memory. + It may affect the AP if it is specified. The higher the amount, the closer results to the non-buffered. + scenario. See [the discussion](https://github.com/obss/sahi/pull/445). + + Returns: + A Dict with fields: + object_prediction_list: a list of sahi.prediction.ObjectPrediction + durations_in_seconds: a dict containing elapsed times for profiling + """ + if image_size is not None: + warnings.warn("Set 'image_size' at DetectionModel init.", DeprecationWarning) + + # for profiling + durations_in_seconds = dict() + + # currently only 1 batch supported + num_batch = 1 + + # create slices from full image + time_start = time.time() + slice_image_result = slice_image( + image=image, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + ) + num_slices = len(slice_image_result) + time_end = time.time() - time_start + durations_in_seconds["slice"] = time_end + + # init match postprocess instance + if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys(): + raise ValueError( + f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}" + ) + elif postprocess_type == "UNIONMERGE": + # deprecated in v0.9.3 + raise ValueError("'UNIONMERGE' postprocess_type is deprecated, use 'GREEDYNMM' instead.") + postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type] + postprocess = postprocess_constructor( + match_threshold=postprocess_match_threshold, + match_metric=postprocess_match_metric, + class_agnostic=postprocess_class_agnostic, + ) + + # create prediction input + num_group = int(num_slices / num_batch) + if verbose == 1 or verbose == 2: + tqdm.write(f"Performing prediction on {num_slices} number of slices.") + object_prediction_list = [] + # perform sliced prediction + for group_ind in range(num_group): + # prepare batch (currently supports only 1 batch) + image_list = [] + shift_amount_list = [] + for image_ind in range(num_batch): + image_list.append(slice_image_result.images[group_ind * num_batch + image_ind]) + shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind]) + # perform batch prediction + prediction_result = get_prediction( + image=image_list[0], + detection_model=detection_model, + image_size=image_size, + shift_amount=shift_amount_list[0], + full_shape=[ + slice_image_result.original_image_height, + slice_image_result.original_image_width, + ], + ) + # convert sliced predictions to full predictions + for object_prediction in prediction_result.object_prediction_list: + if object_prediction: # if not empty + object_prediction_list.append(object_prediction.get_shifted_object_prediction()) + + # merge matching predictions during sliced prediction + if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length: + object_prediction_list = postprocess(object_prediction_list) + + # perform standard 
prediction + if num_slices > 1 and perform_standard_pred: + prediction_result = get_prediction( + image=image, + detection_model=detection_model, + image_size=image_size, + shift_amount=[0, 0], + full_shape=None, + postprocess=None, + ) + object_prediction_list.extend(prediction_result.object_prediction_list) + + if verbose == 2: + print( + "Slicing performed in", + durations_in_seconds["slice"], + "seconds.", + ) + print( + "Prediction performed in", + durations_in_seconds["prediction"], + "seconds.", + ) + + # merge matching predictions + if len(object_prediction_list) > 1: + object_prediction_list = postprocess(object_prediction_list) + + time_end = time.time() - time_start + durations_in_seconds["prediction"] = time_end + + return PredictionResult( + image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds + ) + + +def predict( + detection_model: DetectionModel = None, + model_type: str = "mmdet", + model_path: str = None, + model_config_path: str = None, + model_confidence_threshold: float = 0.25, + model_device: str = None, + model_category_mapping: dict = None, + model_category_remapping: dict = None, + source: str = None, + no_standard_prediction: bool = False, + no_sliced_prediction: bool = False, + image_size: int = None, + slice_height: int = 512, + slice_width: int = 512, + overlap_height_ratio: float = 0.2, + overlap_width_ratio: float = 0.2, + postprocess_type: str = "GREEDYNMM", + postprocess_match_metric: str = "IOS", + postprocess_match_threshold: float = 0.5, + postprocess_class_agnostic: bool = False, + novisual: bool = False, + view_video: bool = False, + frame_skip_interval: int = 0, + export_pickle: bool = False, + export_crop: bool = False, + dataset_json_path: bool = None, + project: str = "runs/predict", + name: str = "exp", + visual_bbox_thickness: int = None, + visual_text_size: float = None, + visual_text_thickness: int = None, + visual_export_format: str = "png", + verbose: int = 1, + return_dict: bool = False, + force_postprocess_type: bool = False, +): + """ + Performs prediction for all present images in given folder. + + Args: + detection_model: sahi.model.DetectionModel + Optionally provide custom DetectionModel to be used for inference. When provided, + model_type, model_path, config_path, model_device, model_category_mapping, image_size + params will be ignored + model_type: str + mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'. + model_path: str + Path for the model weight + model_config_path: str + Path for the detection model config file + model_confidence_threshold: float + All predictions with score < model_confidence_threshold will be discarded. + model_device: str + Torch device, "cpu" or "cuda" + model_category_mapping: dict + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + model_category_remapping: dict: str to int + Remap category ids after performing inference + source: str + Folder directory that contains images or path of the image to be predicted. Also video to be predicted. + no_standard_prediction: bool + Dont perform standard prediction. Default: False. + no_sliced_prediction: bool + Dont perform sliced prediction. Default: False. + image_size: int + Input image size for each inference (image is scaled by preserving asp. rat.). + slice_height: int + Height of each slice. Defaults to ``512``. + slice_width: int + Width of each slice. Defaults to ``512``. + overlap_height_ratio: float + Fractional overlap in height of each window (e.g. 
an overlap of 0.2 for a window
+            of size 512 yields an overlap of 102 pixels).
+            Default to ``0.2``.
+        overlap_width_ratio: float
+            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
+            of size 512 yields an overlap of 102 pixels).
+            Default to ``0.2``.
+        postprocess_type: str
+            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
+            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
+        postprocess_match_metric: str
+            Metric to be used during object prediction matching after sliced prediction.
+            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
+        postprocess_match_threshold: float
+            Sliced predictions having higher iou than postprocess_match_threshold will be
+            postprocessed after sliced prediction.
+        postprocess_class_agnostic: bool
+            If True, postprocess will ignore category ids.
+        novisual: bool
+            Don't export predicted video/image visuals.
+        view_video: bool
+            View result of prediction during video inference.
+        frame_skip_interval: int
+            If view_video or visual export is slow, process only one frame out of every
+            frame_skip_interval frames (e.g. --frame_skip_interval=3 processes one frame out of every 3).
+        export_pickle: bool
+            Export predictions as .pickle
+        export_crop: bool
+            Export predictions as cropped images.
+        dataset_json_path: str
+            If coco file path is provided, detection results will be exported in coco json format.
+        project: str
+            Save results to project/name.
+        name: str
+            Save results to project/name.
+        visual_bbox_thickness: int
+        visual_text_size: float
+        visual_text_thickness: int
+        visual_export_format: str
+            Can be specified as 'jpg' or 'png'
+        verbose: int
+            0: no print
+            1: print slice/prediction durations, number of slices
+            2: print model loading/file exporting durations
+        return_dict: bool
+            If True, returns a dict with 'export_dir' field.
+        force_postprocess_type: bool
+            If True, the automatic postprocess type check will be disabled.
+    """
+    # assert prediction type
+    if no_standard_prediction and no_sliced_prediction:
+        raise ValueError("'no_standard_prediction' and 'no_sliced_prediction' cannot be True at the same time.")
+
+    # auto postprocess type
+    if not force_postprocess_type and model_confidence_threshold < LOW_MODEL_CONFIDENCE and postprocess_type != "NMS":
+        logger.warning(
+            f"Switching postprocess type/metric to NMS/IOU since confidence threshold is low ({model_confidence_threshold})."
+ ) + postprocess_type = "NMS" + postprocess_match_metric = "IOU" + + # for profiling + durations_in_seconds = dict() + + # init export directories + save_dir = Path(increment_path(Path(project) / name, exist_ok=False)) # increment run + crop_dir = save_dir / "crops" + visual_dir = save_dir / "visuals" + visual_with_gt_dir = save_dir / "visuals_with_gt" + pickle_dir = save_dir / "pickles" + if not novisual or export_pickle or export_crop or dataset_json_path is not None: + save_dir.mkdir(parents=True, exist_ok=True) # make dir + + # init image iterator + # TODO: rewrite this as iterator class as in https://github.com/ultralytics/yolov5/blob/d059d1da03aee9a3c0059895aa4c7c14b7f25a9e/utils/datasets.py#L178 + source_is_video = False + num_frames = None + if dataset_json_path: + coco: Coco = Coco.from_coco_dict_or_path(dataset_json_path) + image_iterator = [str(Path(source) / Path(coco_image.file_name)) for coco_image in coco.images] + coco_json = [] + elif os.path.isdir(source): + image_iterator = list_files( + directory=source, + contains=IMAGE_EXTENSIONS, + verbose=verbose, + ) + elif Path(source).suffix in VIDEO_EXTENSIONS: + source_is_video = True + read_video_frame, output_video_writer, video_file_name, num_frames = get_video_reader( + source, save_dir, frame_skip_interval, not novisual, view_video + ) + image_iterator = read_video_frame + else: + image_iterator = [source] + + # init model instance + time_start = time.time() + if detection_model is None: + detection_model = AutoDetectionModel.from_pretrained( + model_type=model_type, + model_path=model_path, + config_path=model_config_path, + confidence_threshold=model_confidence_threshold, + device=model_device, + category_mapping=model_category_mapping, + category_remapping=model_category_remapping, + load_at_init=False, + image_size=image_size, + ) + detection_model.load_model() + time_end = time.time() - time_start + durations_in_seconds["model_load"] = time_end + + # iterate over source images + durations_in_seconds["prediction"] = 0 + durations_in_seconds["slice"] = 0 + + input_type_str = "video frames" if source_is_video else "images" + for ind, image_path in enumerate( + tqdm(image_iterator, f"Performing inference on {input_type_str}", total=num_frames) + ): + # get filename + if source_is_video: + video_name = Path(source).stem + relative_filepath = video_name + "_frame_" + str(ind) + elif os.path.isdir(source): # preserve source folder structure in export + relative_filepath = str(Path(image_path)).split(str(Path(source)))[-1] + relative_filepath = relative_filepath[1:] if relative_filepath[0] == os.sep else relative_filepath + else: # no process if source is single file + relative_filepath = Path(image_path).name + + filename_without_extension = Path(relative_filepath).stem + + # load image + image_as_pil = read_image_as_pil(image_path) + + # perform prediction + if not no_sliced_prediction: + # get sliced prediction + prediction_result = get_sliced_prediction( + image=image_as_pil, + detection_model=detection_model, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + perform_standard_pred=not no_standard_prediction, + postprocess_type=postprocess_type, + postprocess_match_metric=postprocess_match_metric, + postprocess_match_threshold=postprocess_match_threshold, + postprocess_class_agnostic=postprocess_class_agnostic, + verbose=1 if verbose else 0, + ) + object_prediction_list = prediction_result.object_prediction_list + 
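+            # get_sliced_prediction returns a PredictionResult; its durations_in_seconds dict
+            # carries "slice" and "prediction" timings, which are accumulated below for profiling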
durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"] + else: + # get standard prediction + prediction_result = get_prediction( + image=image_as_pil, + detection_model=detection_model, + shift_amount=[0, 0], + full_shape=None, + postprocess=None, + verbose=0, + ) + object_prediction_list = prediction_result.object_prediction_list + + durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"] + # Show prediction time + tqdm.write("Prediction time is: {:.2f} ms".format(prediction_result.durations_in_seconds["prediction"] * 1000)) + + if dataset_json_path: + if source_is_video is True: + raise NotImplementedError("Video input type not supported with coco formatted dataset json") + + # append predictions in coco format + for object_prediction in object_prediction_list: + coco_prediction = object_prediction.to_coco_prediction() + coco_prediction.image_id = coco.images[ind].id + coco_prediction_json = coco_prediction.json + if coco_prediction_json["bbox"]: + coco_json.append(coco_prediction_json) + if not novisual: + # convert ground truth annotations to object_prediction_list + coco_image: CocoImage = coco.images[ind] + object_prediction_gt_list: List[ObjectPrediction] = [] + for coco_annotation in coco_image.annotations: + coco_annotation_dict = coco_annotation.json + category_name = coco_annotation.category_name + full_shape = [coco_image.height, coco_image.width] + object_prediction_gt = ObjectPrediction.from_coco_annotation_dict( + annotation_dict=coco_annotation_dict, category_name=category_name, full_shape=full_shape + ) + object_prediction_gt_list.append(object_prediction_gt) + # export visualizations with ground truths + output_dir = str(visual_with_gt_dir / Path(relative_filepath).parent) + color = (0, 255, 0) # original annotations in green + result = visualize_object_predictions( + np.ascontiguousarray(image_as_pil), + object_prediction_list=object_prediction_gt_list, + rect_th=visual_bbox_thickness, + text_size=visual_text_size, + text_th=visual_text_thickness, + color=color, + output_dir=None, + file_name=None, + export_format=None, + ) + color = (255, 0, 0) # model predictions in red + _ = visualize_object_predictions( + result["image"], + object_prediction_list=object_prediction_list, + rect_th=visual_bbox_thickness, + text_size=visual_text_size, + text_th=visual_text_thickness, + color=color, + output_dir=output_dir, + file_name=filename_without_extension, + export_format=visual_export_format, + ) + + time_start = time.time() + # export prediction boxes + if export_crop: + output_dir = str(crop_dir / Path(relative_filepath).parent) + crop_object_predictions( + image=np.ascontiguousarray(image_as_pil), + object_prediction_list=object_prediction_list, + output_dir=output_dir, + file_name=filename_without_extension, + export_format=visual_export_format, + ) + # export prediction list as pickle + if export_pickle: + save_path = str(pickle_dir / Path(relative_filepath).parent / (filename_without_extension + ".pickle")) + save_pickle(data=object_prediction_list, save_path=save_path) + + # export visualization + if not novisual or view_video: + output_dir = str(visual_dir / Path(relative_filepath).parent) + result = visualize_object_predictions( + np.ascontiguousarray(image_as_pil), + object_prediction_list=object_prediction_list, + rect_th=visual_bbox_thickness, + text_size=visual_text_size, + text_th=visual_text_thickness, + output_dir=output_dir if not source_is_video else None, + file_name=filename_without_extension, + 
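+                # output_dir is set to None for video sources above; rendered frames are written
+                # to the output video writer a few lines below instead of per-frame image files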
export_format=visual_export_format, + ) + if not novisual and source_is_video: # export video + output_video_writer.write(result["image"]) + + # render video inference + if view_video: + cv2.imshow("Prediction of {}".format(str(video_file_name)), result["image"]) + cv2.waitKey(1) + + time_end = time.time() - time_start + durations_in_seconds["export_files"] = time_end + + # export coco results + if dataset_json_path: + save_path = str(save_dir / "result.json") + save_json(coco_json, save_path) + + if not novisual or export_pickle or export_crop or dataset_json_path is not None: + print(f"Prediction results are successfully exported to {save_dir}") + + # print prediction duration + if verbose == 2: + print( + "Model loaded in", + durations_in_seconds["model_load"], + "seconds.", + ) + print( + "Slicing performed in", + durations_in_seconds["slice"], + "seconds.", + ) + print( + "Prediction performed in", + durations_in_seconds["prediction"], + "seconds.", + ) + if not novisual: + print( + "Exporting performed in", + durations_in_seconds["export_files"], + "seconds.", + ) + + if return_dict: + return {"export_dir": save_dir} + + +@check_requirements(["fiftyone"]) +def predict_fiftyone( + model_type: str = "mmdet", + model_path: str = None, + model_config_path: str = None, + model_confidence_threshold: float = 0.25, + model_device: str = None, + model_category_mapping: dict = None, + model_category_remapping: dict = None, + dataset_json_path: str = None, + image_dir: str = None, + no_standard_prediction: bool = False, + no_sliced_prediction: bool = False, + image_size: int = None, + slice_height: int = 256, + slice_width: int = 256, + overlap_height_ratio: float = 0.2, + overlap_width_ratio: float = 0.2, + postprocess_type: str = "GREEDYNMM", + postprocess_match_metric: str = "IOS", + postprocess_match_threshold: float = 0.5, + postprocess_class_agnostic: bool = False, + verbose: int = 1, +): + """ + Performs prediction for all present images in given folder. + + Args: + model_type: str + mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'. + model_path: str + Path for the model weight + model_config_path: str + Path for the detection model config file + model_confidence_threshold: float + All predictions with score < model_confidence_threshold will be discarded. + model_device: str + Torch device, "cpu" or "cuda" + model_category_mapping: dict + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + model_category_remapping: dict: str to int + Remap category ids after performing inference + dataset_json_path: str + If coco file path is provided, detection results will be exported in coco json format. + image_dir: str + Folder directory that contains images or path of the image to be predicted. + no_standard_prediction: bool + Dont perform standard prediction. Default: False. + no_sliced_prediction: bool + Dont perform sliced prediction. Default: False. + image_size: int + Input image size for each inference (image is scaled by preserving asp. rat.). + slice_height: int + Height of each slice. Defaults to ``256``. + slice_width: int + Width of each slice. Defaults to ``256``. + overlap_height_ratio: float + Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window + of size 256 yields an overlap of 51 pixels). + Default to ``0.2``. + overlap_width_ratio: float + Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window + of size 256 yields an overlap of 51 pixels). + Default to ``0.2``. 
+        postprocess_type: str
+            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
+            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
+        postprocess_match_metric: str
+            Metric to be used during object prediction matching after sliced prediction.
+            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
+        postprocess_match_threshold: float
+            Sliced predictions having higher iou than postprocess_match_threshold will be
+            postprocessed after sliced prediction.
+        postprocess_class_agnostic: bool
+            If True, postprocess will ignore category ids.
+        verbose: int
+            0: no print
+            1: print slice/prediction durations, number of slices, model loading/file exporting durations
+    """
+    from sahi.utils.fiftyone import create_fiftyone_dataset_from_coco_file, fo
+
+    # assert prediction type
+    if no_standard_prediction and no_sliced_prediction:
+        raise ValueError("'no_standard_prediction' and 'no_sliced_prediction' cannot be True at the same time.")
+    # for profiling
+    durations_in_seconds = dict()
+
+    dataset = create_fiftyone_dataset_from_coco_file(image_dir, dataset_json_path)
+
+    # init model instance
+    time_start = time.time()
+    detection_model = AutoDetectionModel.from_pretrained(
+        model_type=model_type,
+        model_path=model_path,
+        config_path=model_config_path,
+        confidence_threshold=model_confidence_threshold,
+        device=model_device,
+        category_mapping=model_category_mapping,
+        category_remapping=model_category_remapping,
+        load_at_init=False,
+        image_size=image_size,
+    )
+    detection_model.load_model()
+    time_end = time.time() - time_start
+    durations_in_seconds["model_load"] = time_end
+
+    # iterate over source images
+    durations_in_seconds["prediction"] = 0
+    durations_in_seconds["slice"] = 0
+    # Add predictions to samples
+    with fo.ProgressBar() as pb:
+        for sample in pb(dataset):
+            # perform prediction
+            if not no_sliced_prediction:
+                # get sliced prediction
+                prediction_result = get_sliced_prediction(
+                    image=sample.filepath,
+                    detection_model=detection_model,
+                    slice_height=slice_height,
+                    slice_width=slice_width,
+                    overlap_height_ratio=overlap_height_ratio,
+                    overlap_width_ratio=overlap_width_ratio,
+                    perform_standard_pred=not no_standard_prediction,
+                    postprocess_type=postprocess_type,
+                    postprocess_match_threshold=postprocess_match_threshold,
+                    postprocess_match_metric=postprocess_match_metric,
+                    postprocess_class_agnostic=postprocess_class_agnostic,
+                    verbose=verbose,
+                )
+                durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
+            else:
+                # get standard prediction
+                prediction_result = get_prediction(
+                    image=sample.filepath,
+                    detection_model=detection_model,
+                    shift_amount=[0, 0],
+                    full_shape=None,
+                    postprocess=None,
+                    verbose=0,
+                )
+            durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"]
+
+            # Save predictions to dataset
+            sample[model_type] = fo.Detections(detections=prediction_result.to_fiftyone_detections())
+            sample.save()
+
+    # print prediction duration
+    if verbose == 1:
+        print(
+            "Model loaded in",
+            durations_in_seconds["model_load"],
+            "seconds.",
+        )
+        print(
+            "Slicing performed in",
+            durations_in_seconds["slice"],
+            "seconds.",
+        )
+        print(
+            "Prediction performed in",
+            durations_in_seconds["prediction"],
+            "seconds.",
+        )
+
+    # visualize results
+    session
= fo.launch_app() + session.dataset = dataset + # Evaluate the predictions + results = dataset.evaluate_detections( + model_type, + gt_field="ground_truth", + eval_key="eval", + iou=postprocess_match_threshold, + compute_mAP=True, + ) + # Get the 10 most common classes in the dataset + counts = dataset.count_values("ground_truth.detections.label") + classes_top10 = sorted(counts, key=counts.get, reverse=True)[:10] + # Print a classification report for the top-10 classes + results.print_report(classes=classes_top10) + # Load the view on which we ran the `eval` evaluation + eval_view = dataset.load_evaluation_view("eval") + # Show samples with most false positives + session.view = eval_view.sort_by("eval_fp", reverse=True) + while 1: + time.sleep(3) diff --git a/sahi/scripts/predict_fiftyone.py b/sahi/scripts/predict_fiftyone.py new file mode 100644 index 0000000..ec3aa0d --- /dev/null +++ b/sahi/scripts/predict_fiftyone.py @@ -0,0 +1,11 @@ +import fire + +from sahi.predict import predict_fiftyone + + +def main(): + fire.Fire(predict_fiftyone) + + +if __name__ == "__main__": + main() diff --git a/sahi/scripts/slice_coco.py b/sahi/scripts/slice_coco.py new file mode 100644 index 0000000..d5606a2 --- /dev/null +++ b/sahi/scripts/slice_coco.py @@ -0,0 +1,67 @@ +import os + +import fire + +from sahi.slicing import slice_coco +from sahi.utils.file import Path, save_json + + +def slice( + image_dir: str, + dataset_json_path: str, + slice_size: int = 512, + overlap_ratio: float = 0.2, + ignore_negative_samples: bool = False, + output_dir: str = "runs/slice_coco", + min_area_ratio: float = 0.1, +): + """ + Args: + image_dir (str): directory for coco images + dataset_json_path (str): file path for the coco dataset json file + slice_size (int) + overlap_ratio (float): slice overlap ratio + ignore_negative_samples (bool): ignore images without annotation + output_dir (str): output export dir + min_area_ratio (float): If the cropped annotation area to original + annotation ratio is smaller than this value, the annotation + is filtered out. Default 0.1. 
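+
+    Example (illustrative sketch; 'images/' and 'coco.json' are placeholder paths):
+        from sahi.scripts.slice_coco import slice
+        slice(image_dir="images/", dataset_json_path="coco.json", slice_size=512, overlap_ratio=0.2)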
+ """ + + # assure slice_size is list + slice_size_list = slice_size + if isinstance(slice_size_list, (int, float)): + slice_size_list = [slice_size_list] + + # slice coco dataset images and annotations + print("Slicing step is starting...") + for slice_size in slice_size_list: + # in format: train_images_512_01 + output_images_folder_name = ( + Path(dataset_json_path).stem + f"_images_{str(slice_size)}_{str(overlap_ratio).replace('.','')}" + ) + output_images_dir = str(Path(output_dir) / output_images_folder_name) + sliced_coco_name = Path(dataset_json_path).name.replace( + ".json", f"_{str(slice_size)}_{str(overlap_ratio).replace('.','')}" + ) + coco_dict, coco_path = slice_coco( + coco_annotation_file_path=dataset_json_path, + image_dir=image_dir, + output_coco_annotation_file_name="", + output_dir=output_images_dir, + ignore_negative_samples=ignore_negative_samples, + slice_height=slice_size, + slice_width=slice_size, + min_area_ratio=min_area_ratio, + overlap_height_ratio=overlap_ratio, + overlap_width_ratio=overlap_ratio, + out_ext=".jpg", + verbose=False, + ) + output_coco_annotation_file_path = os.path.join(output_dir, sliced_coco_name + ".json") + save_json(coco_dict, output_coco_annotation_file_path) + print(f"Sliced dataset for 'slice_size: {slice_size}' is exported to {output_dir}") + + +if __name__ == "__main__": + fire.Fire(slice) diff --git a/sahi/slicing.py b/sahi/slicing.py new file mode 100644 index 0000000..7a7401e --- /dev/null +++ b/sahi/slicing.py @@ -0,0 +1,468 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. + +import concurrent.futures +import logging +import os +import time +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +from PIL import Image +from shapely.errors import TopologicalError +from tqdm import tqdm + +from sahi.utils.coco import Coco, CocoAnnotation, CocoImage, create_coco_dict +from sahi.utils.cv import read_image_as_pil +from sahi.utils.file import load_json, save_json + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + +MAX_WORKERS = 20 + + +def get_slice_bboxes( + image_height: int, + image_width: int, + slice_height: int = 512, + slice_width: int = 512, + overlap_height_ratio: int = 0.2, + overlap_width_ratio: int = 0.2, +) -> List[List[int]]: + """Slices `image_pil` in crops. + Corner values of each slice will be generated using the `slice_height`, + `slice_width`, `overlap_height_ratio` and `overlap_width_ratio` arguments. + + Args: + image_height (int): Height of the original image. + image_width (int): Width of the original image. + slice_height (int): Height of each slice. Default 512. + slice_width (int): Width of each slice. Default 512. + overlap_height_ratio(float): Fractional overlap in height of each + slice (e.g. an overlap of 0.2 for a slice of size 100 yields an + overlap of 20 pixels). Default 0.2. + overlap_width_ratio(float): Fractional overlap in width of each + slice (e.g. an overlap of 0.2 for a slice of size 100 yields an + overlap of 20 pixels). Default 0.2. + + Returns: + List[List[int]]: List of 4 corner coordinates for each N slices. + [ + [slice_0_left, slice_0_top, slice_0_right, slice_0_bottom], + ... 
+ [slice_N_left, slice_N_top, slice_N_right, slice_N_bottom] + ] + """ + slice_bboxes = [] + y_max = y_min = 0 + y_overlap = int(overlap_height_ratio * slice_height) + x_overlap = int(overlap_width_ratio * slice_width) + while y_max < image_height: + x_min = x_max = 0 + y_max = y_min + slice_height + while x_max < image_width: + x_max = x_min + slice_width + if y_max > image_height or x_max > image_width: + xmax = min(image_width, x_max) + ymax = min(image_height, y_max) + xmin = max(0, xmax - slice_width) + ymin = max(0, ymax - slice_height) + slice_bboxes.append([xmin, ymin, xmax, ymax]) + else: + slice_bboxes.append([x_min, y_min, x_max, y_max]) + x_min = x_max - x_overlap + y_min = y_max - y_overlap + return slice_bboxes + + +def annotation_inside_slice(annotation: Dict, slice_bbox: List[int]) -> bool: + """Check whether annotation coordinates lie inside slice coordinates. + + Args: + annotation (dict): Single annotation entry in COCO format. + slice_bbox (List[int]): Generated from `get_slice_bboxes`. + Format for each slice bbox: [x_min, y_min, x_max, y_max]. + + Returns: + (bool): True if any annotation coordinate lies inside slice. + """ + left, top, width, height = annotation["bbox"] + + right = left + width + bottom = top + height + + if left >= slice_bbox[2]: + return False + if top >= slice_bbox[3]: + return False + if right <= slice_bbox[0]: + return False + if bottom <= slice_bbox[1]: + return False + + return True + + +def process_coco_annotations(coco_annotation_list: List[CocoAnnotation], slice_bbox: List[int], min_area_ratio) -> bool: + """Slices and filters given list of CocoAnnotation objects with given + 'slice_bbox' and 'min_area_ratio'. + + Args: + coco_annotation_list (List[CocoAnnotation]) + slice_bbox (List[int]): Generated from `get_slice_bboxes`. + Format for each slice bbox: [x_min, y_min, x_max, y_max]. + min_area_ratio (float): If the cropped annotation area to original + annotation ratio is smaller than this value, the annotation is + filtered out. Default 0.1. + + Returns: + (List[CocoAnnotation]): Sliced annotations. + """ + + sliced_coco_annotation_list: List[CocoAnnotation] = [] + for coco_annotation in coco_annotation_list: + if annotation_inside_slice(coco_annotation.json, slice_bbox): + sliced_coco_annotation = coco_annotation.get_sliced_coco_annotation(slice_bbox) + if sliced_coco_annotation.area / coco_annotation.area >= min_area_ratio: + sliced_coco_annotation_list.append(sliced_coco_annotation) + return sliced_coco_annotation_list + + +class SlicedImage: + def __init__(self, image, coco_image, starting_pixel): + """ + image: np.array + Sliced image. + coco_image: CocoImage + Coco styled image object that belong to sliced image. + starting_pixel: list of list of int + Starting pixel coordinates of the sliced image. + """ + self.image = image + self.coco_image = coco_image + self.starting_pixel = starting_pixel + + +class SliceImageResult: + def __init__(self, original_image_size=None, image_dir: str = None): + """ + sliced_image_list: list of SlicedImage + image_dir: str + Directory of the sliced image exports. 
+ original_image_size: list of int + Size of the unsliced original image in [height, width] + """ + self._sliced_image_list: List[SlicedImage] = [] + self.original_image_height = original_image_size[0] + self.original_image_width = original_image_size[1] + self.image_dir = image_dir + + def add_sliced_image(self, sliced_image: SlicedImage): + if not isinstance(sliced_image, SlicedImage): + raise TypeError("sliced_image must be a SlicedImage instance") + + self._sliced_image_list.append(sliced_image) + + @property + def sliced_image_list(self): + return self._sliced_image_list + + @property + def images(self): + """Returns sliced images. + + Returns: + images: a list of np.array + """ + images = [] + for sliced_image in self._sliced_image_list: + images.append(sliced_image.image) + return images + + @property + def coco_images(self) -> List[CocoImage]: + """Returns CocoImage representation of SliceImageResult. + + Returns: + coco_images: a list of CocoImage + """ + coco_images: List = [] + for sliced_image in self._sliced_image_list: + coco_images.append(sliced_image.coco_image) + return coco_images + + @property + def starting_pixels(self) -> List[int]: + """Returns a list of starting pixels for each slice. + + Returns: + starting_pixels: a list of starting pixel coords [x,y] + """ + starting_pixels = [] + for sliced_image in self._sliced_image_list: + starting_pixels.append(sliced_image.starting_pixel) + return starting_pixels + + @property + def filenames(self) -> List[int]: + """Returns a list of filenames for each slice. + + Returns: + filenames: a list of filenames as str + """ + filenames = [] + for sliced_image in self._sliced_image_list: + filenames.append(sliced_image.coco_image.file_name) + return filenames + + def __len__(self): + return len(self._sliced_image_list) + + +def slice_image( + image: Union[str, Image.Image], + coco_annotation_list: Optional[CocoAnnotation] = None, + output_file_name: Optional[str] = None, + output_dir: Optional[str] = None, + slice_height: int = 512, + slice_width: int = 512, + overlap_height_ratio: float = 0.2, + overlap_width_ratio: float = 0.2, + min_area_ratio: float = 0.1, + out_ext: Optional[str] = None, + verbose: bool = False, +) -> SliceImageResult: + + """Slice a large image into smaller windows. If output_file_name is given export + sliced images. + + Args: + image (str or PIL.Image): File path of image or Pillow Image to be sliced. + coco_annotation_list (CocoAnnotation): List of CocoAnnotation objects. + output_file_name (str, optional): Root name of output files (coordinates will + be appended to this) + output_dir (str, optional): Output directory + slice_height (int): Height of each slice. Default 512. + slice_width (int): Width of each slice. Default 512. + overlap_height_ratio (float): Fractional overlap in height of each + slice (e.g. an overlap of 0.2 for a slice of size 100 yields an + overlap of 20 pixels). Default 0.2. + overlap_width_ratio (float): Fractional overlap in width of each + slice (e.g. an overlap of 0.2 for a slice of size 100 yields an + overlap of 20 pixels). Default 0.2. + min_area_ratio (float): If the cropped annotation area to original annotation + ratio is smaller than this value, the annotation is filtered out. Default 0.1. + out_ext (str, optional): Extension of saved images. Default is the + original suffix. + verbose (bool, optional): Switch to print relevant values to screen. + Default 'False'. 
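+
+    Example (illustrative sketch; 'highway.jpg' is a placeholder path):
+        sliced = slice_image(
+            image="highway.jpg",
+            slice_height=512,
+            slice_width=512,
+            overlap_height_ratio=0.2,
+            overlap_width_ratio=0.2,
+        )
+        print(len(sliced), sliced.starting_pixels[0])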
+ + Returns: + sliced_image_result: SliceImageResult: + sliced_image_list: list of SlicedImage + image_dir: str + Directory of the sliced image exports. + original_image_size: list of int + Size of the unsliced original image in [height, width] + num_total_invalid_segmentation: int + Number of invalid segmentation annotations. + """ + + # define verboseprint + verboselog = logger.info if verbose else lambda *a, **k: None + + def _export_single_slice(image: np.ndarray, output_dir: str, slice_file_name: str): + image_pil = read_image_as_pil(image) + slice_file_path = str(Path(output_dir) / slice_file_name) + # export sliced image + image_pil.save(slice_file_path) + verboselog("sliced image path: " + slice_file_path) + + # create outdir if not present + if output_dir is not None: + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # read image + image_pil = read_image_as_pil(image) + verboselog("image.shape: " + str(image_pil.size)) + + image_width, image_height = image_pil.size + if not (image_width != 0 and image_height != 0): + raise RuntimeError(f"invalid image size: {image_pil.size} for 'slice_image'.") + slice_bboxes = get_slice_bboxes( + image_height=image_height, + image_width=image_width, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + ) + + t0 = time.time() + n_ims = 0 + + # init images and annotations lists + sliced_image_result = SliceImageResult(original_image_size=[image_height, image_width], image_dir=output_dir) + + image_pil_arr = np.asarray(image_pil) + # iterate over slices + for slice_bbox in slice_bboxes: + n_ims += 1 + + # extract image + tlx = slice_bbox[0] + tly = slice_bbox[1] + brx = slice_bbox[2] + bry = slice_bbox[3] + image_pil_slice = image_pil_arr[tly:bry, tlx:brx] + + # process annotations if coco_annotations is given + if coco_annotation_list is not None: + sliced_coco_annotation_list = process_coco_annotations(coco_annotation_list, slice_bbox, min_area_ratio) + + # set image file suffixes + slice_suffixes = "_".join(map(str, slice_bbox)) + if out_ext: + suffix = out_ext + else: + try: + suffix = Path(image_pil.filename).suffix + except AttributeError: + suffix = ".jpg" + + # set image file name and path + slice_file_name = f"{output_file_name}_{slice_suffixes}{suffix}" + + # create coco image + slice_width = slice_bbox[2] - slice_bbox[0] + slice_height = slice_bbox[3] - slice_bbox[1] + coco_image = CocoImage(file_name=slice_file_name, height=slice_height, width=slice_width) + + # append coco annotations (if present) to coco image + if coco_annotation_list: + for coco_annotation in sliced_coco_annotation_list: + coco_image.add_annotation(coco_annotation) + + # create sliced image and append to sliced_image_result + sliced_image = SlicedImage( + image=image_pil_slice, + coco_image=coco_image, + starting_pixel=[slice_bbox[0], slice_bbox[1]], + ) + sliced_image_result.add_sliced_image(sliced_image) + + # export slices if output directory is provided + if output_file_name and output_dir: + conc_exec = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) + conc_exec.map( + _export_single_slice, + sliced_image_result.images, + [output_dir] * len(sliced_image_result), + sliced_image_result.filenames, + ) + + verboselog( + "Num slices: " + str(n_ims) + " slice_height: " + str(slice_height) + " slice_width: " + str(slice_width), + ) + + return sliced_image_result + + +def slice_coco( + coco_annotation_file_path: str, + image_dir: str, + 
output_coco_annotation_file_name: str, + output_dir: Optional[str] = None, + ignore_negative_samples: bool = False, + slice_height: int = 512, + slice_width: int = 512, + overlap_height_ratio: float = 0.2, + overlap_width_ratio: float = 0.2, + min_area_ratio: float = 0.1, + out_ext: Optional[str] = None, + verbose: bool = False, +) -> List[Union[Dict, str]]: + + """ + Slice large images given in a directory, into smaller windows. If out_name is given export sliced images and coco file. + + Args: + coco_annotation_file_pat (str): Location of the coco annotation file + image_dir (str): Base directory for the images + output_coco_annotation_file_name (str): File name of the exported coco + datatset json. + output_dir (str, optional): Output directory + ignore_negative_samples (bool): If True, images without annotations + are ignored. Defaults to False. + slice_height (int): Height of each slice. Default 512. + slice_width (int): Width of each slice. Default 512. + overlap_height_ratio (float): Fractional overlap in height of each + slice (e.g. an overlap of 0.2 for a slice of size 100 yields an + overlap of 20 pixels). Default 0.2. + overlap_width_ratio (float): Fractional overlap in width of each + slice (e.g. an overlap of 0.2 for a slice of size 100 yields an + overlap of 20 pixels). Default 0.2. + min_area_ratio (float): If the cropped annotation area to original annotation + ratio is smaller than this value, the annotation is filtered out. Default 0.1. + out_ext (str, optional): Extension of saved images. Default is the + original suffix. + verbose (bool, optional): Switch to print relevant values to screen. + Default 'False'. + + Returns: + coco_dict: dict + COCO dict for sliced images and annotations + save_path: str + Path to the saved coco file + """ + + # read coco file + coco_dict: Dict = load_json(coco_annotation_file_path) + # create image_id_to_annotation_list mapping + coco = Coco.from_coco_dict_or_path(coco_dict) + # init sliced coco_utils.CocoImage list + sliced_coco_images: List = [] + + # iterate over images and slice + for coco_image in tqdm(coco.images): + # get image path + image_path: str = os.path.join(image_dir, coco_image.file_name) + # get annotation json list corresponding to selected coco image + # slice image + try: + slice_image_result = slice_image( + image=image_path, + coco_annotation_list=coco_image.annotations, + output_file_name=Path(coco_image.file_name).stem, + output_dir=output_dir, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + min_area_ratio=min_area_ratio, + out_ext=out_ext, + verbose=verbose, + ) + # append slice outputs + sliced_coco_images.extend(slice_image_result.coco_images) + except TopologicalError: + logger.warning(f"Invalid annotation found, skipping this image: {image_path}") + + # create and save coco dict + coco_dict = create_coco_dict( + sliced_coco_images, + coco_dict["categories"], + ignore_negative_samples=ignore_negative_samples, + ) + save_path = "" + if output_coco_annotation_file_name and output_dir: + save_path = Path(output_dir) / (output_coco_annotation_file_name + "_coco.json") + save_json(coco_dict, save_path) + + return coco_dict, save_path diff --git a/sahi/utils/__init__.py b/sahi/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sahi/utils/coco.py b/sahi/utils/coco.py new file mode 100644 index 0000000..4503fae --- /dev/null +++ b/sahi/utils/coco.py @@ -0,0 +1,2407 @@ +# OBSS SAHI Tool +# Code 
written by Fatih C Akyon, 2020. +# Modified by Sinan O Altinuc, 2020. + +import copy +import logging +import os +from collections import Counter, defaultdict +from dataclasses import dataclass +from multiprocessing import Pool +from pathlib import Path +from typing import Dict, List, Optional, Set, Union + +import numpy as np +from tqdm import tqdm + +from sahi.utils.file import load_json, save_json +from sahi.utils.shapely import ShapelyAnnotation, box, get_shapely_multipolygon + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + + +class CocoCategory: + """ + COCO formatted category. + """ + + def __init__(self, id=None, name=None, supercategory=None): + self.id = int(id) + self.name = name + self.supercategory = supercategory if supercategory else name + + @classmethod + def from_coco_category(cls, category): + """ + Creates CocoCategory object using coco category. + + Args: + category: Dict + {"supercategory": "person", "id": 1, "name": "person"}, + """ + return cls( + id=category["id"], + name=category["name"], + supercategory=category["supercategory"] if "supercategory" in category else category["name"], + ) + + @property + def json(self): + return { + "id": self.id, + "name": self.name, + "supercategory": self.supercategory, + } + + def __repr__(self): + return f"""CocoCategory< + id: {self.id}, + name: {self.name}, + supercategory: {self.supercategory}>""" + + +class CocoAnnotation: + """ + COCO formatted annotation. + """ + + @classmethod + def from_coco_segmentation(cls, segmentation, category_id, category_name, iscrowd=0): + """ + Creates CocoAnnotation object using coco segmentation. + + Args: + segmentation: List[List] + [[1, 1, 325, 125, 250, 200, 5, 200]] + category_id: int + Category id of the annotation + category_name: str + Category name of the annotation + iscrowd: int + 0 or 1 + """ + return cls( + segmentation=segmentation, + category_id=category_id, + category_name=category_name, + iscrowd=iscrowd, + ) + + @classmethod + def from_coco_bbox(cls, bbox, category_id, category_name, iscrowd=0): + """ + Creates CocoAnnotation object using coco bbox + + Args: + bbox: List + [xmin, ymin, width, height] + category_id: int + Category id of the annotation + category_name: str + Category name of the annotation + iscrowd: int + 0 or 1 + """ + return cls( + bbox=bbox, + category_id=category_id, + category_name=category_name, + iscrowd=iscrowd, + ) + + @classmethod + def from_coco_annotation_dict(cls, annotation_dict: Dict, category_name: Optional[str] = None): + """ + Creates CocoAnnotation object from category name and COCO formatted + annotation dict (with fields "bbox", "segmentation", "category_id"). + + Args: + category_name: str + Category name of the annotation + annotation_dict: dict + COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id") + """ + if annotation_dict.__contains__("segmentation") and not isinstance(annotation_dict["segmentation"], list): + has_rle_segmentation = True + logger.warning( + f"Segmentation annotation for id {annotation_dict['id']} is skipped since RLE segmentation format is not supported." 
+ ) + else: + has_rle_segmentation = False + + if ( + annotation_dict.__contains__("segmentation") + and annotation_dict["segmentation"] + and not has_rle_segmentation + ): + return cls( + segmentation=annotation_dict["segmentation"], + category_id=annotation_dict["category_id"], + category_name=category_name, + ) + else: + return cls( + bbox=annotation_dict["bbox"], + category_id=annotation_dict["category_id"], + category_name=category_name, + ) + + @classmethod + def from_shapely_annotation( + cls, + shapely_annotation: ShapelyAnnotation, + category_id: int, + category_name: str, + iscrowd: int, + ): + """ + Creates CocoAnnotation object from ShapelyAnnotation object. + + Args: + shapely_annotation (ShapelyAnnotation) + category_id (int): Category id of the annotation + category_name (str): Category name of the annotation + iscrowd (int): 0 or 1 + """ + coco_annotation = cls( + bbox=[0, 0, 0, 0], + category_id=category_id, + category_name=category_name, + iscrowd=iscrowd, + ) + coco_annotation._segmentation = shapely_annotation.to_coco_segmentation() + coco_annotation._shapely_annotation = shapely_annotation + return coco_annotation + + def __init__( + self, + segmentation=None, + bbox=None, + category_id=None, + category_name=None, + image_id=None, + iscrowd=0, + ): + """ + Creates coco annotation object using bbox or segmentation + + Args: + segmentation: List[List] + [[1, 1, 325, 125, 250, 200, 5, 200]] + bbox: List + [xmin, ymin, width, height] + category_id: int + Category id of the annotation + category_name: str + Category name of the annotation + image_id: int + Image ID of the annotation + iscrowd: int + 0 or 1 + """ + if bbox is None and segmentation is None: + raise ValueError("you must provide a bbox or polygon") + + self._segmentation = segmentation + bbox = [round(point) for point in bbox] if bbox else bbox + self._category_id = category_id + self._category_name = category_name + self._image_id = image_id + self._iscrowd = iscrowd + + if self._segmentation: + shapely_annotation = ShapelyAnnotation.from_coco_segmentation(segmentation=self._segmentation) + else: + shapely_annotation = ShapelyAnnotation.from_coco_bbox(bbox=bbox) + self._shapely_annotation = shapely_annotation + + def get_sliced_coco_annotation(self, slice_bbox: List[int]): + shapely_polygon = box(slice_bbox[0], slice_bbox[1], slice_bbox[2], slice_bbox[3]) + intersection_shapely_annotation = self._shapely_annotation.get_intersection(shapely_polygon) + return CocoAnnotation.from_shapely_annotation( + intersection_shapely_annotation, + category_id=self.category_id, + category_name=self.category_name, + iscrowd=self.iscrowd, + ) + + @property + def area(self): + """ + Returns area of annotation polygon (or bbox if no polygon available) + """ + return self._shapely_annotation.area + + @property + def bbox(self): + """ + Returns coco formatted bbox of the annotation as [xmin, ymin, width, height] + """ + return self._shapely_annotation.to_coco_bbox() + + @property + def segmentation(self): + """ + Returns coco formatted segmentation of the annotation as [[1, 1, 325, 125, 250, 200, 5, 200]] + """ + if self._segmentation: + return self._shapely_annotation.to_coco_segmentation() + else: + return [] + + @property + def category_id(self): + """ + Returns category id of the annotation as int + """ + return self._category_id + + @category_id.setter + def category_id(self, i): + if not isinstance(i, int): + raise Exception("category_id must be an integer") + self._category_id = i + + @property + def image_id(self): + """ + 
Returns image id of the annotation as int + """ + return self._image_id + + @image_id.setter + def image_id(self, i): + if not isinstance(i, int): + raise Exception("image_id must be an integer") + self._image_id = i + + @property + def category_name(self): + """ + Returns category name of the annotation as str + """ + return self._category_name + + @category_name.setter + def category_name(self, n): + if not isinstance(n, str): + raise Exception("category_name must be a string") + self._category_name = n + + @property + def iscrowd(self): + """ + Returns iscrowd info of the annotation + """ + return self._iscrowd + + @property + def json(self): + return { + "image_id": self.image_id, + "bbox": self.bbox, + "category_id": self.category_id, + "segmentation": self.segmentation, + "iscrowd": self.iscrowd, + "area": self.area, + } + + def serialize(self): + print(".serialize() is deprectaed, use .json instead") + + def __repr__(self): + return f"""CocoAnnotation< + image_id: {self.image_id}, + bbox: {self.bbox}, + segmentation: {self.segmentation}, + category_id: {self.category_id}, + category_name: {self.category_name}, + iscrowd: {self.iscrowd}, + area: {self.area}>""" + + +class CocoPrediction(CocoAnnotation): + """ + Class for handling predictions in coco format. + """ + + @classmethod + def from_coco_segmentation(cls, segmentation, category_id, category_name, score, iscrowd=0, image_id=None): + """ + Creates CocoAnnotation object using coco segmentation. + + Args: + segmentation: List[List] + [[1, 1, 325, 125, 250, 200, 5, 200]] + category_id: int + Category id of the annotation + category_name: str + Category name of the annotation + score: float + Prediction score between 0 and 1 + iscrowd: int + 0 or 1 + """ + return cls( + segmentation=segmentation, + category_id=category_id, + category_name=category_name, + score=score, + iscrowd=iscrowd, + image_id=image_id, + ) + + @classmethod + def from_coco_bbox(cls, bbox, category_id, category_name, score, iscrowd=0, image_id=None): + """ + Creates CocoAnnotation object using coco bbox + + Args: + bbox: List + [xmin, ymin, width, height] + category_id: int + Category id of the annotation + category_name: str + Category name of the annotation + score: float + Prediction score between 0 and 1 + iscrowd: int + 0 or 1 + """ + return cls( + bbox=bbox, + category_id=category_id, + category_name=category_name, + score=score, + iscrowd=iscrowd, + image_id=image_id, + ) + + @classmethod + def from_coco_annotation_dict(cls, category_name, annotation_dict, score, image_id=None): + """ + Creates CocoAnnotation object from category name and COCO formatted + annotation dict (with fields "bbox", "segmentation", "category_id"). 
+ + Args: + category_name: str + Category name of the annotation + annotation_dict: dict + COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id") + score: float + Prediction score between 0 and 1 + """ + if annotation_dict["segmentation"]: + return cls( + segmentation=annotation_dict["segmentation"], + category_id=annotation_dict["category_id"], + category_name=category_name, + score=score, + image_id=image_id, + ) + else: + return cls( + bbox=annotation_dict["bbox"], + category_id=annotation_dict["category_id"], + category_name=category_name, + image_id=image_id, + ) + + def __init__( + self, + segmentation=None, + bbox=None, + category_id=None, + category_name=None, + image_id=None, + score=None, + iscrowd=0, + ): + """ + + Args: + segmentation: List[List] + [[1, 1, 325, 125, 250, 200, 5, 200]] + bbox: List + [xmin, ymin, width, height] + category_id: int + Category id of the annotation + category_name: str + Category name of the annotation + image_id: int + Image ID of the annotation + score: float + Prediction score between 0 and 1 + iscrowd: int + 0 or 1 + """ + self.score = score + super().__init__( + segmentation=segmentation, + bbox=bbox, + category_id=category_id, + category_name=category_name, + image_id=image_id, + iscrowd=iscrowd, + ) + + @property + def json(self): + return { + "image_id": self.image_id, + "bbox": self.bbox, + "score": self.score, + "category_id": self.category_id, + "category_name": self.category_name, + "segmentation": self.segmentation, + "iscrowd": self.iscrowd, + "area": self.area, + } + + def serialize(self): + print(".serialize() is deprectaed, use .json instead") + + def __repr__(self): + return f"""CocoPrediction< + image_id: {self.image_id}, + bbox: {self.bbox}, + segmentation: {self.segmentation}, + score: {self.score}, + category_id: {self.category_id}, + category_name: {self.category_name}, + iscrowd: {self.iscrowd}, + area: {self.area}>""" + + +class CocoVidAnnotation(CocoAnnotation): + """ + COCOVid formatted annotation. 
+ https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file + """ + + def __init__( + self, + bbox=None, + category_id=None, + category_name=None, + image_id=None, + instance_id=None, + iscrowd=0, + id=None, + ): + """ + Args: + bbox: List + [xmin, ymin, width, height] + category_id: int + Category id of the annotation + category_name: str + Category name of the annotation + image_id: int + Image ID of the annotation + instance_id: int + Used for tracking + iscrowd: int + 0 or 1 + id: int + Annotation id + """ + super(CocoVidAnnotation, self).__init__( + bbox=bbox, + category_id=category_id, + category_name=category_name, + image_id=image_id, + iscrowd=iscrowd, + ) + self.instance_id = instance_id + self.id = id + + @property + def json(self): + return { + "id": self.id, + "image_id": self.image_id, + "bbox": self.bbox, + "segmentation": self.segmentation, + "category_id": self.category_id, + "category_name": self.category_name, + "instance_id": self.instance_id, + "iscrowd": self.iscrowd, + "area": self.area, + } + + def __repr__(self): + return f"""CocoAnnotation< + id: {self.id}, + image_id: {self.image_id}, + bbox: {self.bbox}, + segmentation: {self.segmentation}, + category_id: {self.category_id}, + category_name: {self.category_name}, + instance_id: {self.instance_id}, + iscrowd: {self.iscrowd}, + area: {self.area}>""" + + +class CocoImage: + @classmethod + def from_coco_image_dict(cls, image_dict): + """ + Creates CocoImage object from COCO formatted image dict (with fields "id", "file_name", "height" and "weight"). + + Args: + image_dict: dict + COCO formatted image dict (with fields "id", "file_name", "height" and "weight") + """ + return cls( + id=image_dict["id"], + file_name=image_dict["file_name"], + height=image_dict["height"], + width=image_dict["width"], + ) + + def __init__(self, file_name: str, height: int, width: int, id: int = None): + """ + Creates CocoImage object + + Args: + id : int + Image id + file_name : str + Image path + height : int + Image height in pixels + width : int + Image width in pixels + """ + self.id = int(id) if id else id + self.file_name = file_name + self.height = int(height) + self.width = int(width) + self.annotations = [] # list of CocoAnnotation that belong to this image + self.predictions = [] # list of CocoPrediction that belong to this image + + def add_annotation(self, annotation): + """ + Adds annotation to this CocoImage instance + + annotation : CocoAnnotation + """ + + if not isinstance(annotation, CocoAnnotation): + raise TypeError("annotation must be a CocoAnnotation instance") + self.annotations.append(annotation) + + def add_prediction(self, prediction): + """ + Adds prediction to this CocoImage instance + + prediction : CocoPrediction + """ + + if not isinstance(prediction, CocoPrediction): + raise TypeError("prediction must be a CocoPrediction instance") + self.predictions.append(prediction) + + @property + def json(self): + return { + "id": self.id, + "file_name": self.file_name, + "height": self.height, + "width": self.width, + } + + def __repr__(self): + return f"""CocoImage< + id: {self.id}, + file_name: {self.file_name}, + height: {self.height}, + width: {self.width}, + annotations: List[CocoAnnotation], + predictions: List[CocoPrediction]>""" + + +class CocoVidImage(CocoImage): + """ + COCOVid formatted image. 
+ https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file + """ + + def __init__( + self, + file_name, + height, + width, + video_id=None, + frame_id=None, + id=None, + ): + """ + Creates CocoVidImage object + + Args: + id: int + Image id + file_name: str + Image path + height: int + Image height in pixels + width: int + Image width in pixels + frame_id: int + 0-indexed frame id + video_id: int + Video id + """ + super(CocoVidImage, self).__init__(file_name=file_name, height=height, width=width, id=id) + self.frame_id = frame_id + self.video_id = video_id + + @classmethod + def from_coco_image(cls, coco_image, video_id=None, frame_id=None): + """ + Creates CocoVidImage object using CocoImage object. + Args: + coco_image: CocoImage + frame_id: int + 0-indexed frame id + video_id: int + Video id + + """ + return cls( + file_name=coco_image.file_name, + height=coco_image.height, + width=coco_image.width, + id=coco_image.id, + video_id=video_id, + frame_id=frame_id, + ) + + def add_annotation(self, annotation): + """ + Adds annotation to this CocoImage instance + annotation : CocoVidAnnotation + """ + + if not isinstance(annotation, CocoVidAnnotation): + raise TypeError("annotation must be a CocoVidAnnotation instance") + self.annotations.append(annotation) + + @property + def json(self): + return { + "file_name": self.file_name, + "height": self.height, + "width": self.width, + "id": self.id, + "video_id": self.video_id, + "frame_id": self.frame_id, + } + + def __repr__(self): + return f"""CocoVidImage< + file_name: {self.file_name}, + height: {self.height}, + width: {self.width}, + id: {self.id}, + video_id: {self.video_id}, + frame_id: {self.frame_id}, + annotations: List[CocoVidAnnotation]>""" + + +class CocoVideo: + """ + COCO formatted video. 
+ https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file + """ + + def __init__( + self, + name: str, + id: int = None, + fps: float = None, + height: int = None, + width: int = None, + ): + """ + Creates CocoVideo object + + Args: + name: str + Video name + id: int + Video id + fps: float + Video fps + height: int + Video height in pixels + width: int + Video width in pixels + """ + self.name = name + self.id = id + self.fps = fps + self.height = height + self.width = width + self.images = [] # list of CocoImage that belong to this video + + def add_image(self, image): + """ + Adds image to this CocoVideo instance + Args: + image: CocoImage + """ + + if not isinstance(image, CocoImage): + raise TypeError("image must be a CocoImage instance") + self.images.append(CocoVidImage.from_coco_image(image)) + + def add_cocovidimage(self, cocovidimage): + """ + Adds CocoVidImage to this CocoVideo instance + Args: + cocovidimage: CocoVidImage + """ + + if not isinstance(cocovidimage, CocoVidImage): + raise TypeError("cocovidimage must be a CocoVidImage instance") + self.images.append(cocovidimage) + + @property + def json(self): + return { + "name": self.name, + "id": self.id, + "fps": self.fps, + "height": self.height, + "width": self.width, + } + + def __repr__(self): + return f"""CocoVideo< + id: {self.id}, + name: {self.name}, + fps: {self.fps}, + height: {self.height}, + width: {self.width}, + images: List[CocoVidImage]>""" + + +class Coco: + def __init__( + self, + name=None, + image_dir=None, + remapping_dict=None, + ignore_negative_samples=False, + clip_bboxes_to_img_dims=False, + image_id_setting="auto", + ): + """ + Creates Coco object. + + Args: + name: str + Name of the Coco dataset, it determines exported json name. + image_dir: str + Base file directory that contains dataset images. Required for dataset merging. + remapping_dict: dict + {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1 + ignore_negative_samples: bool + If True ignores images without annotations in all operations. + image_id_setting: str + how to assign image ids while exporting can be + auto --> will assign id from scratch (.id will be ignored) + manual --> you will need to provide image ids in instances (.id can not be None) + """ + if image_id_setting not in ["auto", "manual"]: + raise ValueError("image_id_setting must be either 'auto' or 'manual'") + self.name = name + self.image_dir = image_dir + self.remapping_dict = remapping_dict + self.ignore_negative_samples = ignore_negative_samples + self.categories = [] + self.images = [] + self._stats = None + self.clip_bboxes_to_img_dims = clip_bboxes_to_img_dims + self.image_id_setting = image_id_setting + + def add_categories_from_coco_category_list(self, coco_category_list): + """ + Creates CocoCategory object using coco category list. 
+ + Args: + coco_category_list: List[Dict] + [ + {"supercategory": "person", "id": 1, "name": "person"}, + {"supercategory": "vehicle", "id": 2, "name": "bicycle"} + ] + """ + + for coco_category in coco_category_list: + if self.remapping_dict is not None: + for source_id in self.remapping_dict.keys(): + if coco_category["id"] == source_id: + target_id = self.remapping_dict[source_id] + coco_category["id"] = target_id + + self.add_category(CocoCategory.from_coco_category(coco_category)) + + def add_category(self, category): + """ + Adds category to this Coco instance + + Args: + category: CocoCategory + """ + + # assert type(category) == CocoCategory, "category must be a CocoCategory instance" + if not isinstance(category, CocoCategory): + raise TypeError("category must be a CocoCategory instance") + self.categories.append(category) + + def add_image(self, image): + """ + Adds image to this Coco instance + + Args: + image: CocoImage + """ + + if self.image_id_setting == "manual" and image.id is None: + raise ValueError("image id should be manually set for image_id_setting='manual'") + self.images.append(image) + + def update_categories(self, desired_name2id, update_image_filenames=False): + """ + Rearranges category mapping of given COCO object based on given desired_name2id. + Can also be used to filter some of the categories. + + Args: + desired_name2id: dict + {"big_vehicle": 1, "car": 2, "human": 3} + update_image_filenames: bool + If True, updates coco image file_names with absolute file paths. + """ + # init vars + currentid2desiredid_mapping = {} + updated_coco = Coco( + name=self.name, + image_dir=self.image_dir, + remapping_dict=self.remapping_dict, + ignore_negative_samples=self.ignore_negative_samples, + ) + # create category id mapping (currentid2desiredid_mapping) + for coco_category in copy.deepcopy(self.categories): + current_category_id = coco_category.id + current_category_name = coco_category.name + if current_category_name in desired_name2id.keys(): + currentid2desiredid_mapping[current_category_id] = desired_name2id[current_category_name] + else: + # ignore categories that are not included in desired_name2id + currentid2desiredid_mapping[current_category_id] = None + + # add updated categories + for name in desired_name2id.keys(): + updated_coco_category = CocoCategory(id=desired_name2id[name], name=name, supercategory=name) + updated_coco.add_category(updated_coco_category) + + # add updated images & annotations + for coco_image in copy.deepcopy(self.images): + updated_coco_image = CocoImage.from_coco_image_dict(coco_image.json) + # update filename to abspath + file_name_is_abspath = True if os.path.abspath(coco_image.file_name) == coco_image.file_name else False + if update_image_filenames and not file_name_is_abspath: + updated_coco_image.file_name = str(Path(os.path.abspath(self.image_dir)) / coco_image.file_name) + # update annotations + for coco_annotation in coco_image.annotations: + current_category_id = coco_annotation.category_id + desired_category_id = currentid2desiredid_mapping[current_category_id] + # append annotations with category id present in desired_name2id + if desired_category_id is not None: + # update cetegory id + coco_annotation.category_id = desired_category_id + # append updated annotation to target coco dict + updated_coco_image.add_annotation(coco_annotation) + updated_coco.add_image(updated_coco_image) + + # overwrite instance + self.__class__ = updated_coco.__class__ + self.__dict__ = updated_coco.__dict__ + + def merge(self, coco, 
desired_name2id=None, verbose=1): + """ + Combines the images/annotations/categories of given coco object with current one. + + Args: + coco : sahi.utils.coco.Coco instance + A COCO dataset object + desired_name2id : dict + {"human": 1, "car": 2, "big_vehicle": 3} + verbose: bool + If True, merging info is printed + """ + if self.image_dir is None or coco.image_dir is None: + raise ValueError("image_dir should be provided for merging.") + if verbose: + if not desired_name2id: + print("'desired_name2id' is not specified, combining all categories.") + + # create desired_name2id by combining all categories, if desired_name2id is not specified + coco1 = self + coco2 = coco + category_ind = 0 + if desired_name2id is None: + desired_name2id = {} + for coco in [coco1, coco2]: + temp_categories = copy.deepcopy(coco.json_categories) + for temp_category in temp_categories: + if temp_category["name"] not in desired_name2id: + desired_name2id[temp_category["name"]] = category_ind + category_ind += 1 + else: + continue + + # update categories and image paths + for coco in [coco1, coco2]: + coco.update_categories(desired_name2id=desired_name2id, update_image_filenames=True) + + # combine images and categories + coco1.images.extend(coco2.images) + self.images: List[CocoImage] = coco1.images + self.categories = coco1.categories + + # print categories + if verbose: + print( + "Categories are formed as:\n", + self.json_categories, + ) + + @classmethod + def from_coco_dict_or_path( + cls, + coco_dict_or_path: Union[Dict, str], + image_dir: Optional[str] = None, + remapping_dict: Optional[Dict] = None, + ignore_negative_samples: bool = False, + clip_bboxes_to_img_dims: bool = False, + ): + """ + Creates coco object from COCO formatted dict or COCO dataset file path. + + Args: + coco_dict_or_path: dict/str or List[dict/str] + COCO formatted dict or COCO dataset file path + List of COCO formatted dict or COCO dataset file path + image_dir: str + Base file directory that contains dataset images. Required for merging and yolov5 conversion. + remapping_dict: dict + {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1 + ignore_negative_samples: bool + If True ignores images without annotations in all operations. + clip_bboxes_to_img_dims: bool = False + Limits bounding boxes to image dimensions. 
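# --- Usage sketch (not part of the patch): loading two COCO files and merging them with
# the from_coco_dict_or_path()/merge() methods above. Paths are hypothetical; image_dir
# must be set on both objects for merging.
from sahi.utils.coco import Coco

coco_a = Coco.from_coco_dict_or_path("dataset_a/annotations.json", image_dir="dataset_a/images/")
coco_b = Coco.from_coco_dict_or_path("dataset_b/annotations.json", image_dir="dataset_b/images/")

# merge() works in place: categories are unified (or remapped via desired_name2id)
# and the images of coco_b are appended to coco_a.
coco_a.merge(coco_b)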
+ + Properties: + images: list of CocoImage + category_mapping: dict + """ + # init coco object + coco = cls( + image_dir=image_dir, + remapping_dict=remapping_dict, + ignore_negative_samples=ignore_negative_samples, + clip_bboxes_to_img_dims=clip_bboxes_to_img_dims, + ) + + if type(coco_dict_or_path) not in [str, dict]: + raise TypeError("coco_dict_or_path should be a dict or str") + + # load coco dict if path is given + if type(coco_dict_or_path) == str: + coco_dict = load_json(coco_dict_or_path) + else: + coco_dict = coco_dict_or_path + + # arrange image id to annotation id mapping + coco.add_categories_from_coco_category_list(coco_dict["categories"]) + image_id_to_annotation_list = get_imageid2annotationlist_mapping(coco_dict) + category_mapping = coco.category_mapping + + # https://github.com/obss/sahi/issues/98 + image_id_set: Set = set() + + for coco_image_dict in tqdm(coco_dict["images"], "Loading coco annotations"): + coco_image = CocoImage.from_coco_image_dict(coco_image_dict) + image_id = coco_image_dict["id"] + # https://github.com/obss/sahi/issues/98 + if image_id in image_id_set: + print(f"duplicate image_id: {image_id}, will be ignored.") + continue + else: + image_id_set.add(image_id) + # select annotations of the image + annotation_list = image_id_to_annotation_list[image_id] + for coco_annotation_dict in annotation_list: + # apply category remapping if remapping_dict is provided + if coco.remapping_dict is not None: + # apply category remapping (id:id) + category_id = coco.remapping_dict[coco_annotation_dict["category_id"]] + # update category id + coco_annotation_dict["category_id"] = category_id + else: + category_id = coco_annotation_dict["category_id"] + # get category name (id:name) + category_name = category_mapping[category_id] + coco_annotation = CocoAnnotation.from_coco_annotation_dict( + category_name=category_name, annotation_dict=coco_annotation_dict + ) + coco_image.add_annotation(coco_annotation) + coco.add_image(coco_image) + + if clip_bboxes_to_img_dims: + coco = coco.get_coco_with_clipped_bboxes() + return coco + + @property + def json_categories(self): + categories = [] + for category in self.categories: + categories.append(category.json) + return categories + + @property + def category_mapping(self): + category_mapping = {} + for category in self.categories: + category_mapping[category.id] = category.name + return category_mapping + + @property + def json(self): + return create_coco_dict( + images=self.images, + categories=self.json_categories, + ignore_negative_samples=self.ignore_negative_samples, + image_id_setting=self.image_id_setting, + ) + + @property + def prediction_array(self): + return create_coco_prediction_array( + images=self.images, + ignore_negative_samples=self.ignore_negative_samples, + image_id_setting=self.image_id_setting, + ) + + @property + def stats(self): + if not self._stats: + self.calculate_stats() + return self._stats + + def calculate_stats(self): + """ + Iterates over all annotations and calculates total number of + """ + # init all stats + num_annotations = 0 + num_images = len(self.images) + num_negative_images = 0 + num_categories = len(self.json_categories) + category_name_to_zero = {category["name"]: 0 for category in self.json_categories} + category_name_to_inf = {category["name"]: float("inf") for category in self.json_categories} + num_images_per_category = copy.deepcopy(category_name_to_zero) + num_annotations_per_category = copy.deepcopy(category_name_to_zero) + min_annotation_area_per_category = 
copy.deepcopy(category_name_to_inf) + max_annotation_area_per_category = copy.deepcopy(category_name_to_zero) + min_num_annotations_in_image = float("inf") + max_num_annotations_in_image = 0 + total_annotation_area = 0 + min_annotation_area = 1e10 + max_annotation_area = 0 + for image in self.images: + image_contains_category = {} + for annotation in image.annotations: + annotation_area = annotation.area + total_annotation_area += annotation_area + num_annotations_per_category[annotation.category_name] += 1 + image_contains_category[annotation.category_name] = 1 + # update min&max annotation area + if annotation_area > max_annotation_area: + max_annotation_area = annotation_area + if annotation_area < min_annotation_area: + min_annotation_area = annotation_area + if annotation_area > max_annotation_area_per_category[annotation.category_name]: + max_annotation_area_per_category[annotation.category_name] = annotation_area + if annotation_area < min_annotation_area_per_category[annotation.category_name]: + min_annotation_area_per_category[annotation.category_name] = annotation_area + # update num_negative_images + if len(image.annotations) == 0: + num_negative_images += 1 + # update num_annotations + num_annotations += len(image.annotations) + # update num_images_per_category + num_images_per_category = dict(Counter(num_images_per_category) + Counter(image_contains_category)) + # update min&max_num_annotations_in_image + num_annotations_in_image = len(image.annotations) + if num_annotations_in_image > max_num_annotations_in_image: + max_num_annotations_in_image = num_annotations_in_image + if num_annotations_in_image < min_num_annotations_in_image: + min_num_annotations_in_image = num_annotations_in_image + if (num_images - num_negative_images) > 0: + avg_num_annotations_in_image = num_annotations / (num_images - num_negative_images) + avg_annotation_area = total_annotation_area / num_annotations + else: + avg_num_annotations_in_image = 0 + avg_annotation_area = 0 + + self._stats = { + "num_images": num_images, + "num_annotations": num_annotations, + "num_categories": num_categories, + "num_negative_images": num_negative_images, + "num_images_per_category": num_images_per_category, + "num_annotations_per_category": num_annotations_per_category, + "min_num_annotations_in_image": min_num_annotations_in_image, + "max_num_annotations_in_image": max_num_annotations_in_image, + "avg_num_annotations_in_image": avg_num_annotations_in_image, + "min_annotation_area": min_annotation_area, + "max_annotation_area": max_annotation_area, + "avg_annotation_area": avg_annotation_area, + "min_annotation_area_per_category": min_annotation_area_per_category, + "max_annotation_area_per_category": max_annotation_area_per_category, + } + + def split_coco_as_train_val(self, train_split_rate=0.9, numpy_seed=0): + """ + Split images into train-val and returns them as sahi.utils.coco.Coco objects. + + Args: + train_split_rate: float + numpy_seed: int + To fix the numpy seed. 
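# --- Usage sketch (not part of the patch): inspecting dataset statistics via the stats
# property, which lazily calls calculate_stats(). The annotation path is hypothetical.
from sahi.utils.coco import Coco

coco = Coco.from_coco_dict_or_path("annotations/instances.json", image_dir="images/")
stats = coco.stats
print(stats["num_images"], stats["num_annotations"])
print(stats["num_annotations_per_category"])  # e.g. {"human": 1523, "vehicle": 841}
print(stats["avg_annotation_area"])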
+ + Returns: + result : dict + { + "train_coco": "", + "val_coco": "", + } + """ + # fix numpy numpy seed + np.random.seed(numpy_seed) + + # divide images + num_images = len(self.images) + shuffled_images = copy.deepcopy(self.images) + np.random.shuffle(shuffled_images) + num_train = int(num_images * train_split_rate) + train_images = shuffled_images[:num_train] + val_images = shuffled_images[num_train:] + + # form train val coco objects + train_coco = Coco( + name=self.name if self.name else "split" + "_train", + image_dir=self.image_dir, + ) + train_coco.images = train_images + train_coco.categories = self.categories + + val_coco = Coco(name=self.name if self.name else "split" + "_val", image_dir=self.image_dir) + val_coco.images = val_images + val_coco.categories = self.categories + + # return result + return { + "train_coco": train_coco, + "val_coco": val_coco, + } + + def export_as_yolov5(self, output_dir, train_split_rate=1, numpy_seed=0, mp=False): + """ + Exports current COCO dataset in ultralytics/yolov5 format. + Creates train val folders with image symlinks and txt files and a data yaml file. + + Args: + output_dir: str + Export directory. + train_split_rate: float + If given 1, will be exported as train split. + If given 0, will be exported as val split. + If in between 0-1, both train/val splits will be calculated and exported. + numpy_seed: int + To fix the numpy seed. + mp: bool + If True, multiprocess mode is on. + Should be called in 'if __name__ == __main__:' block. + """ + try: + import yaml + except ImportError: + raise ImportError( + 'Please run "pip install -U pyyaml" ' "to install yaml first for yolov5 formatted exporting." + ) + + # set split_mode + if 0 < train_split_rate and train_split_rate < 1: + split_mode = "TRAINVAL" + elif train_split_rate == 0: + split_mode = "VAL" + elif train_split_rate == 1: + split_mode = "TRAIN" + else: + raise ValueError("train_split_rate cannot be <0 or >1") + + # split dataset + if split_mode == "TRAINVAL": + result = self.split_coco_as_train_val( + train_split_rate=train_split_rate, + numpy_seed=numpy_seed, + ) + train_coco = result["train_coco"] + val_coco = result["val_coco"] + elif split_mode == "TRAIN": + train_coco = self + val_coco = None + elif split_mode == "VAL": + train_coco = None + val_coco = self + + # create train val image dirs + train_dir = "" + val_dir = "" + if split_mode in ["TRAINVAL", "TRAIN"]: + train_dir = Path(os.path.abspath(output_dir)) / "train/" + train_dir.mkdir(parents=True, exist_ok=True) # create dir + if split_mode in ["TRAINVAL", "VAL"]: + val_dir = Path(os.path.abspath(output_dir)) / "val/" + val_dir.mkdir(parents=True, exist_ok=True) # create dir + + # create image symlinks and annotation txts + if split_mode in ["TRAINVAL", "TRAIN"]: + export_yolov5_images_and_txts_from_coco_object( + output_dir=train_dir, + coco=train_coco, + ignore_negative_samples=self.ignore_negative_samples, + mp=mp, + ) + if split_mode in ["TRAINVAL", "VAL"]: + export_yolov5_images_and_txts_from_coco_object( + output_dir=val_dir, + coco=val_coco, + ignore_negative_samples=self.ignore_negative_samples, + mp=mp, + ) + + # create yolov5 data yaml + data = { + "train": str(train_dir), + "val": str(val_dir), + "nc": len(self.category_mapping), + "names": list(self.category_mapping.values()), + } + yaml_path = str(Path(output_dir) / "data.yml") + with open(yaml_path, "w") as outfile: + yaml.dump(data, outfile, default_flow_style=None) + + def get_subsampled_coco(self, subsample_ratio: int = 2, category_id: int = None): + """ + 
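# --- Usage sketch (not part of the patch): exporting a loaded Coco instance ('coco', as in
# the sketches above) in yolov5 format. Requires pyyaml; the output directory is hypothetical.
coco.export_as_yolov5(
    output_dir="runs/yolov5_dataset",
    train_split_rate=0.85,  # 0 < rate < 1 exports both train and val splits
    numpy_seed=0,
)
# Creates train/ and val/ folders with image symlinks, one .txt label file per image,
# and a data.yml listing 'train', 'val', 'nc' and 'names'.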
Subsamples images with subsample_ratio and returns as sahi.utils.coco.Coco object. + + Args: + subsample_ratio: int + 10 means take every 10th image with its annotations + category_id: int + subsample only images containing given category_id, if -1 then subsamples negative samples + Returns: + subsampled_coco: sahi.utils.coco.Coco + """ + subsampled_coco = Coco( + name=self.name, + image_dir=self.image_dir, + remapping_dict=self.remapping_dict, + ignore_negative_samples=self.ignore_negative_samples, + ) + subsampled_coco.add_categories_from_coco_category_list(self.json_categories) + + if category_id is not None: + # get images that contain given category id + images_that_contain_category: List[CocoImage] = [] + for image in self.images: + category_id_to_contains = defaultdict(lambda: 0) + annotation: CocoAnnotation + for annotation in image.annotations: + category_id_to_contains[annotation.category_id] = 1 + if category_id_to_contains[category_id]: + add_this_image = True + elif category_id == -1 and len(image.annotations) == 0: + # if category_id is given as -1, select negative samples + add_this_image = True + else: + add_this_image = False + + if add_this_image: + images_that_contain_category.append(image) + + # get images that does not contain given category id + images_that_doesnt_contain_category: List[CocoImage] = [] + for image in self.images: + category_id_to_contains = defaultdict(lambda: 0) + annotation: CocoAnnotation + for annotation in image.annotations: + category_id_to_contains[annotation.category_id] = 1 + if category_id_to_contains[category_id]: + add_this_image = False + elif category_id == -1 and len(image.annotations) == 0: + # if category_id is given as -1, dont select negative samples + add_this_image = False + else: + add_this_image = True + + if add_this_image: + images_that_doesnt_contain_category.append(image) + + if category_id: + selected_images = images_that_contain_category + # add images that does not contain given category without subsampling + for image_ind in range(len(images_that_doesnt_contain_category)): + subsampled_coco.add_image(images_that_doesnt_contain_category[image_ind]) + else: + selected_images = self.images + for image_ind in range(0, len(selected_images), subsample_ratio): + subsampled_coco.add_image(selected_images[image_ind]) + + return subsampled_coco + + def get_upsampled_coco(self, upsample_ratio: int = 2, category_id: int = None): + """ + Upsamples images with upsample_ratio and returns as sahi.utils.coco.Coco object. 
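# --- Usage sketch (not part of the patch): subsampling a loaded Coco instance ('coco', as
# in the sketches above).
subsampled_coco = coco.get_subsampled_coco(subsample_ratio=10)  # keep every 10th image
# Subsample only images containing category id 1; all other images are kept as-is:
subsampled_coco = coco.get_subsampled_coco(subsample_ratio=10, category_id=1)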
+ + Args: + upsample_ratio: int + 10 means copy each sample 10 times + category_id: int + upsample only images containing given category_id, if -1 then upsamples negative samples + Returns: + upsampled_coco: sahi.utils.coco.Coco + """ + upsampled_coco = Coco( + name=self.name, + image_dir=self.image_dir, + remapping_dict=self.remapping_dict, + ignore_negative_samples=self.ignore_negative_samples, + ) + upsampled_coco.add_categories_from_coco_category_list(self.json_categories) + for ind in range(upsample_ratio): + for image_ind in range(len(self.images)): + # calculate add_this_image + if category_id is not None: + category_id_to_contains = defaultdict(lambda: 0) + annotation: CocoAnnotation + for annotation in self.images[image_ind].annotations: + category_id_to_contains[annotation.category_id] = 1 + if category_id_to_contains[category_id]: + add_this_image = True + elif category_id == -1 and len(self.images[image_ind].annotations) == 0: + # if category_id is given as -1, select negative samples + add_this_image = True + elif ind == 0: + # in first iteration add all images + add_this_image = True + else: + add_this_image = False + else: + add_this_image = True + + if add_this_image: + upsampled_coco.add_image(self.images[image_ind]) + + return upsampled_coco + + def get_area_filtered_coco(self, min=0, max=float("inf"), intervals_per_category=None): + """ + Filters annotation areas with given min and max values and returns remaining + images as sahi.utils.coco.Coco object. + + Args: + min: int + minimum allowed area + max: int + maximum allowed area + intervals_per_category: dict of dicts + { + "human": {"min": 20, "max": 10000}, + "vehicle": {"min": 50, "max": 15000}, + } + Returns: + area_filtered_coco: sahi.utils.coco.Coco + """ + area_filtered_coco = Coco( + name=self.name, + image_dir=self.image_dir, + remapping_dict=self.remapping_dict, + ignore_negative_samples=self.ignore_negative_samples, + ) + area_filtered_coco.add_categories_from_coco_category_list(self.json_categories) + for image in self.images: + is_valid_image = True + for annotation in image.annotations: + if intervals_per_category is not None and annotation.category_name in intervals_per_category.keys(): + category_based_min = intervals_per_category[annotation.category_name]["min"] + category_based_max = intervals_per_category[annotation.category_name]["max"] + if annotation.area < category_based_min or annotation.area > category_based_max: + is_valid_image = False + if annotation.area < min or annotation.area > max: + is_valid_image = False + if is_valid_image: + area_filtered_coco.add_image(image) + + return area_filtered_coco + + def get_coco_with_clipped_bboxes(self): + """ + Limits overflowing bounding boxes to image dimensions. 
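# --- Usage sketch (not part of the patch): dropping images that contain out-of-range
# annotations ('coco' as in the sketches above). The global min/max apply to every
# annotation; per-category intervals are checked in addition where given.
filtered_coco = coco.get_area_filtered_coco(
    min=32,
    max=10**6,
    intervals_per_category={"human": {"min": 20, "max": 10000}},
)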
+ """ + from sahi.slicing import annotation_inside_slice + + coco = Coco( + name=self.name, + image_dir=self.image_dir, + remapping_dict=self.remapping_dict, + ignore_negative_samples=self.ignore_negative_samples, + ) + coco.add_categories_from_coco_category_list(self.json_categories) + + for coco_img in self.images: + img_dims = [0, 0, coco_img.width, coco_img.height] + coco_image = CocoImage( + file_name=coco_img.file_name, height=coco_img.height, width=coco_img.width, id=coco_img.id + ) + for coco_ann in coco_img.annotations: + ann_dict: Dict = coco_ann.json + if annotation_inside_slice(annotation=ann_dict, slice_bbox=img_dims): + shapely_ann = coco_ann.get_sliced_coco_annotation(img_dims) + bbox = ShapelyAnnotation.to_coco_bbox(shapely_ann._shapely_annotation) + coco_ann_from_shapely = CocoAnnotation( + bbox=bbox, + category_id=coco_ann.category_id, + category_name=coco_ann.category_name, + image_id=coco_ann.image_id, + ) + coco_image.add_annotation(coco_ann_from_shapely) + else: + continue + coco.add_image(coco_image) + return coco + + +def export_yolov5_images_and_txts_from_coco_object(output_dir, coco, ignore_negative_samples=False, mp=False): + """ + Creates image symlinks and annotation txts in yolo format from coco dataset. + + Args: + output_dir: str + Export directory. + coco: sahi.utils.coco.Coco + Initialized Coco object that contains images and categories. + ignore_negative_samples: bool + If True ignores images without annotations in all operations. + mp: bool + If True, multiprocess mode is on. + Should be called in 'if __name__ == __main__:' block. + """ + + print("generating image symlinks and annotation files for yolov5..."), + if mp: + with Pool(processes=48) as pool: + args = [(coco_image, coco.image_dir, output_dir, ignore_negative_samples) for coco_image in coco.images] + pool.starmap( + export_single_yolov5_image_and_corresponding_txt, + tqdm(args, total=len(args)), + ) + else: + for coco_image in tqdm(coco.images): + export_single_yolov5_image_and_corresponding_txt( + coco_image, coco.image_dir, output_dir, ignore_negative_samples + ) + + +def export_single_yolov5_image_and_corresponding_txt( + coco_image, coco_image_dir, output_dir, ignore_negative_samples=False +): + """ + Generates yolov5 formatted image symlink and annotation txt file. + + Args: + coco_image: sahi.utils.coco.CocoImage + coco_image_dir: str + output_dir: str + Export directory. + ignore_negative_samples: bool + If True ignores images without annotations in all operations. 
+ """ + if not ignore_negative_samples or len(coco_image.annotations) > 0: + # skip images without suffix + # https://github.com/obss/sahi/issues/114 + if Path(coco_image.file_name).suffix == "": + print(f"image file has no suffix, skipping it: '{coco_image.file_name}'") + return + elif Path(coco_image.file_name).suffix in [".txt"]: # TODO: extend this list + print(f"image file has incorrect suffix, skipping it: '{coco_image.file_name}'") + return + # set coco and yolo image paths + if Path(coco_image.file_name).is_file(): + coco_image_path = os.path.abspath(coco_image.file_name) + else: + if coco_image_dir is None: + raise ValueError("You have to specify image_dir of Coco object for yolov5 conversion.") + + coco_image_path = os.path.abspath(str(Path(coco_image_dir) / coco_image.file_name)) + + yolo_image_path_temp = str(Path(output_dir) / Path(coco_image.file_name).name) + # increment target file name if already present + yolo_image_path = copy.deepcopy(yolo_image_path_temp) + name_increment = 2 + while Path(yolo_image_path).is_file(): + parent_dir = Path(yolo_image_path_temp).parent + filename = Path(yolo_image_path_temp).stem + filesuffix = Path(yolo_image_path_temp).suffix + filename = filename + "_" + str(name_increment) + yolo_image_path = str(parent_dir / (filename + filesuffix)) + name_increment += 1 + # create a symbolic link pointing to coco_image_path named yolo_image_path + os.symlink(coco_image_path, yolo_image_path) + # calculate annotation normalization ratios + width = coco_image.width + height = coco_image.height + dw = 1.0 / (width) + dh = 1.0 / (height) + # set annotation filepath + image_file_suffix = Path(yolo_image_path).suffix + yolo_annotation_path = yolo_image_path.replace(image_file_suffix, ".txt") + # create annotation file + annotations = coco_image.annotations + with open(yolo_annotation_path, "w") as outfile: + for annotation in annotations: + # convert coco bbox to yolo bbox + x_center = annotation.bbox[0] + annotation.bbox[2] / 2.0 + y_center = annotation.bbox[1] + annotation.bbox[3] / 2.0 + bbox_width = annotation.bbox[2] + bbox_height = annotation.bbox[3] + x_center = x_center * dw + y_center = y_center * dh + bbox_width = bbox_width * dw + bbox_height = bbox_height * dh + category_id = annotation.category_id + yolo_bbox = (x_center, y_center, bbox_width, bbox_height) + # save yolo annotation + outfile.write(str(category_id) + " " + " ".join([str(value) for value in yolo_bbox]) + "\n") + + +def update_categories(desired_name2id: dict, coco_dict: dict) -> dict: + """ + Rearranges category mapping of given COCO dictionary based on given category_mapping. + Can also be used to filter some of the categories. + + Arguments: + --------- + desired_name2id : dict + {"big_vehicle": 1, "car": 2, "human": 3} + coco_dict : dict + COCO formatted dictionary. + Returns: + --------- + coco_target : dict + COCO dict with updated/filtred categories. 
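# --- Illustration (not part of the patch): the COCO -> YOLO box conversion performed above,
# restated as a hypothetical standalone helper. COCO boxes are absolute
# [x_min, y_min, width, height]; YOLO expects [x_center, y_center, width, height]
# normalized by the image width/height.
def coco_bbox_to_yolo(bbox, image_width, image_height):
    x_center = (bbox[0] + bbox[2] / 2.0) / image_width
    y_center = (bbox[1] + bbox[3] / 2.0) / image_height
    return x_center, y_center, bbox[2] / image_width, bbox[3] / image_height

coco_bbox_to_yolo([491, 1035, 153, 182], image_width=1920, image_height=1280)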
+ """ + # so that original variable doesnt get affected + coco_source = copy.deepcopy(coco_dict) + + # init target coco dict + coco_target = {"images": [], "annotations": [], "categories": []} + + # init vars + currentid2desiredid_mapping = {} + # create category id mapping (currentid2desiredid_mapping) + for category in coco_source["categories"]: + current_category_id = category["id"] + current_category_name = category["name"] + if current_category_name in desired_name2id.keys(): + currentid2desiredid_mapping[current_category_id] = desired_name2id[current_category_name] + else: + # ignore categories that are not included in desired_name2id + currentid2desiredid_mapping[current_category_id] = -1 + + # update annotations + for annotation in coco_source["annotations"]: + current_category_id = annotation["category_id"] + desired_category_id = currentid2desiredid_mapping[current_category_id] + # append annotations with category id present in desired_name2id + if desired_category_id != -1: + # update cetegory id + annotation["category_id"] = desired_category_id + # append updated annotation to target coco dict + coco_target["annotations"].append(annotation) + + # create desired categories + categories = [] + for name in desired_name2id.keys(): + category = {} + category["name"] = category["supercategory"] = name + category["id"] = desired_name2id[name] + categories.append(category) + + # update categories + coco_target["categories"] = categories + + # update images + coco_target["images"] = coco_source["images"] + + return coco_target + + +def update_categories_from_file(desired_name2id: dict, coco_path: str, save_path: str) -> None: + """ + Rearranges category mapping of a COCO dictionary in coco_path based on given category_mapping. + Can also be used to filter some of the categories. + Arguments: + --------- + desired_name2id : dict + {"human": 1, "car": 2, "big_vehicle": 3} + coco_path : str + "dirname/coco.json" + """ + # load source coco dict + coco_source = load_json(coco_path) + + # update categories + coco_target = update_categories(desired_name2id, coco_source) + + # save modified coco file + save_json(coco_target, save_path) + + +def merge(coco_dict1: dict, coco_dict2: dict, desired_name2id: dict = None) -> dict: + """ + Combines 2 coco formatted annotations dicts, and returns the combined coco dict. + + Arguments: + --------- + coco_dict1 : dict + First coco dictionary. + coco_dict2 : dict + Second coco dictionary. + desired_name2id : dict + {"human": 1, "car": 2, "big_vehicle": 3} + Returns: + --------- + merged_coco_dict : dict + Merged COCO dict. 
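# --- Usage sketch (not part of the patch): remapping/filtering the categories of a raw
# COCO dict. Categories missing from desired_name2id (and their annotations) are dropped.
from sahi.utils.coco import update_categories

coco_dict = {
    "images": [{"id": 1, "file_name": "a.jpg", "height": 480, "width": 640}],
    "annotations": [{"id": 1, "image_id": 1, "category_id": 5, "bbox": [0, 0, 10, 10]}],
    "categories": [{"id": 5, "name": "car", "supercategory": "car"}],
}
updated = update_categories(desired_name2id={"car": 2}, coco_dict=coco_dict)
# updated["annotations"][0]["category_id"] -> 2
# updated["categories"]                    -> [{"name": "car", "supercategory": "car", "id": 2}]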
+ """ + + # copy input dicts so that original dicts are not affected + temp_coco_dict1 = copy.deepcopy(coco_dict1) + temp_coco_dict2 = copy.deepcopy(coco_dict2) + + # rearrange categories if any desired_name2id mapping is given + if desired_name2id is not None: + temp_coco_dict1 = update_categories(desired_name2id, temp_coco_dict1) + temp_coco_dict2 = update_categories(desired_name2id, temp_coco_dict2) + + # rearrange categories of the second coco based on first, if their categories are not the same + if temp_coco_dict1["categories"] != temp_coco_dict2["categories"]: + desired_name2id = {category["name"]: category["id"] for category in temp_coco_dict1["categories"]} + temp_coco_dict2 = update_categories(desired_name2id, temp_coco_dict2) + + # calculate first image and annotation index of the second coco file + max_image_id = np.array([image["id"] for image in coco_dict1["images"]]).max() + max_annotation_id = np.array([annotation["id"] for annotation in coco_dict1["annotations"]]).max() + + merged_coco_dict = temp_coco_dict1 + + for image in temp_coco_dict2["images"]: + image["id"] += max_image_id + 1 + merged_coco_dict["images"].append(image) + + for annotation in temp_coco_dict2["annotations"]: + annotation["image_id"] += max_image_id + 1 + annotation["id"] += max_annotation_id + 1 + merged_coco_dict["annotations"].append(annotation) + + return merged_coco_dict + + +def merge_from_list(coco_dict_list, desired_name2id=None, verbose=1): + """ + Combines a list of coco formatted annotations dicts, and returns the combined coco dict. + + Arguments: + --------- + coco_dict_list: list of dict + A list of coco dicts + desired_name2id: dict + {"human": 1, "car": 2, "big_vehicle": 3} + verbose: bool + If True, merging info is printed + Returns: + --------- + merged_coco_dict: dict + Merged COCO dict. + """ + if verbose: + if not desired_name2id: + print("'desired_name2id' is not specified, combining all categories.") + + # create desired_name2id by combinin all categories, if desired_name2id is not specified + if desired_name2id is None: + desired_name2id = {} + ind = 0 + for coco_dict in coco_dict_list: + temp_categories = copy.deepcopy(coco_dict["categories"]) + for temp_category in temp_categories: + if temp_category["name"] not in desired_name2id: + desired_name2id[temp_category["name"]] = ind + ind += 1 + else: + continue + + for ind, coco_dict in enumerate(coco_dict_list): + if ind == 0: + merged_coco_dict = copy.deepcopy(coco_dict) + else: + merged_coco_dict = merge(merged_coco_dict, coco_dict, desired_name2id) + + # print categories + if verbose: + print( + "Categories are formed as:\n", + merged_coco_dict["categories"], + ) + + return merged_coco_dict + + +def merge_from_file(coco_path1: str, coco_path2: str, save_path: str): + """ + Combines 2 coco formatted annotations files given their paths, and saves the combined file to save_path. + + Arguments: + --------- + coco_path1 : str + Path for the first coco file. + coco_path2 : str + Path for the second coco file. + save_path : str + "dirname/coco.json" + """ + + # load coco files to be combined + coco_dict1 = load_json(coco_path1) + coco_dict2 = load_json(coco_path2) + + # merge coco dicts + merged_coco_dict = merge(coco_dict1, coco_dict2) + + # save merged coco dict + save_json(merged_coco_dict, save_path) + + +def get_imageid2annotationlist_mapping( + coco_dict: dict, +) -> Dict[int, List[CocoAnnotation]]: + """ + Get image_id to annotationlist mapping for faster indexing. 
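# --- Usage sketch (not part of the patch): combining several COCO dicts in one call with
# merge_from_list(). Paths are hypothetical; when desired_name2id is left as None, a combined
# name -> id mapping is built from all input categories.
import json

from sahi.utils.coco import merge_from_list

coco_dicts = []
for path in ["annotations/a.json", "annotations/b.json", "annotations/c.json"]:
    with open(path) as f:
        coco_dicts.append(json.load(f))

merged_coco_dict = merge_from_list(coco_dicts)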
+ + Arguments + --------- + coco_dict : dict + coco dict with fields "images", "annotations", "categories" + Returns + ------- + image_id_to_annotation_list : dict + { + 1: [CocoAnnotation, CocoAnnotation, CocoAnnotation], + 2: [CocoAnnotation] + } + + where + CocoAnnotation = { + 'area': 2795520, + 'bbox': [491.0, 1035.0, 153.0, 182.0], + 'category_id': 1, + 'id': 1, + 'image_id': 1, + 'iscrowd': 0, + 'segmentation': [[491.0, 1035.0, 644.0, 1035.0, 644.0, 1217.0, 491.0, 1217.0]] + } + """ + image_id_to_annotation_list: Dict = defaultdict(list) + print("indexing coco dataset annotations...") + for annotation in coco_dict["annotations"]: + image_id = annotation["image_id"] + image_id_to_annotation_list[image_id].append(annotation) + + return image_id_to_annotation_list + + +def create_coco_dict(images, categories, ignore_negative_samples=False, image_id_setting="auto"): + """ + Creates COCO dict with fields "images", "annotations", "categories". + + Arguments + --------- + images : List of CocoImage containing a list of CocoAnnotation + categories : List of Dict + COCO categories + ignore_negative_samples : Bool + If True, images without annotations are ignored + image_id_setting: str + how to assign image ids while exporting can be + auto --> will assign id from scratch (.id will be ignored) + manual --> you will need to provide image ids in instances (.id can not be None) + Returns + ------- + coco_dict : Dict + COCO dict with fields "images", "annotations", "categories" + """ + # assertion of parameters + if image_id_setting not in ["auto", "manual"]: + raise ValueError(f"'image_id_setting' should be one of ['auto', 'manual']") + + # define accumulators + image_index = 1 + annotation_id = 1 + coco_dict = dict(images=[], annotations=[], categories=categories) + for coco_image in images: + # get coco annotations + coco_annotations = coco_image.annotations + # get num annotations + num_annotations = len(coco_annotations) + # if ignore_negative_samples is True and no annotations, skip image + if ignore_negative_samples and num_annotations == 0: + continue + else: + # get image_id + if image_id_setting == "auto": + image_id = image_index + image_index += 1 + elif image_id_setting == "manual": + if coco_image.id is None: + raise ValueError("'coco_image.id' should be set manually when image_id_setting == 'manual'") + image_id = coco_image.id + + # create coco image object + out_image = { + "height": coco_image.height, + "width": coco_image.width, + "id": image_id, + "file_name": coco_image.file_name, + } + coco_dict["images"].append(out_image) + + # do the same for image annotations + for coco_annotation in coco_annotations: + # create coco annotation object + out_annotation = { + "iscrowd": 0, + "image_id": image_id, + "bbox": coco_annotation.bbox, + "segmentation": coco_annotation.segmentation, + "category_id": coco_annotation.category_id, + "id": annotation_id, + "area": coco_annotation.area, + } + coco_dict["annotations"].append(out_annotation) + # increment annotation id + annotation_id += 1 + + # return coco dict + return coco_dict + + +def create_coco_prediction_array(images, ignore_negative_samples=False, image_id_setting="auto"): + """ + Creates COCO prediction array which is list of predictions + + Arguments + --------- + images : List of CocoImage containing a list of CocoAnnotation + ignore_negative_samples : Bool + If True, images without predictions are ignored + image_id_setting: str + how to assign image ids while exporting can be + auto --> will assign id from scratch (.id will 
be ignored) + manual --> you will need to provide image ids in instances (.id can not be None) + Returns + ------- + coco_prediction_array : List + COCO predictions array + """ + # assertion of parameters + if image_id_setting not in ["auto", "manual"]: + raise ValueError(f"'image_id_setting' should be one of ['auto', 'manual']") + # define accumulators + image_index = 1 + prediction_id = 1 + predictions_array = [] + for coco_image in images: + # get coco predictions + coco_predictions = coco_image.predictions + # get num predictions + num_predictions = len(coco_predictions) + # if ignore_negative_samples is True and no annotations, skip image + if ignore_negative_samples and num_predictions == 0: + continue + else: + # get image_id + if image_id_setting == "auto": + image_id = image_index + image_index += 1 + elif image_id_setting == "manual": + if coco_image.id is None: + raise ValueError("'coco_image.id' should be set manually when image_id_setting == 'manual'") + image_id = coco_image.id + + # create coco prediction object + for prediction_index, coco_prediction in enumerate(coco_predictions): + # create coco prediction object + out_prediction = { + "id": prediction_id, + "image_id": image_id, + "bbox": coco_prediction.bbox, + "score": coco_prediction.score, + "category_id": coco_prediction.category_id, + "segmentation": coco_prediction.segmentation, + "iscrowd": coco_prediction.iscrowd, + "area": coco_prediction.area, + } + predictions_array.append(out_prediction) + + # increment prediction id + prediction_id += 1 + + # return predictions array + return predictions_array + + +def add_bbox_and_area_to_coco( + source_coco_path: str = "", + target_coco_path: str = "", + add_bbox: bool = True, + add_area: bool = True, +) -> dict: + """ + Takes single coco dataset file path, calculates and fills bbox and area fields of the annotations + and exports the updated coco dict. 
+ Returns: + coco_dict : dict + Updated coco dict + """ + coco_dict = load_json(source_coco_path) + coco_dict = copy.deepcopy(coco_dict) + + annotations = coco_dict["annotations"] + for ind, annotation in enumerate(annotations): + # assign annotation bbox + if add_bbox: + coco_polygons = [] + [coco_polygons.extend(coco_polygon) for coco_polygon in annotation["segmentation"]] + minx, miny, maxx, maxy = list( + [ + min(coco_polygons[0::2]), + min(coco_polygons[1::2]), + max(coco_polygons[0::2]), + max(coco_polygons[1::2]), + ] + ) + x, y, width, height = ( + round(minx), + round(miny), + round(maxx - minx), + round(maxy - miny), + ) + annotations[ind]["bbox"] = [x, y, width, height] + + # assign annotation area + if add_area: + shapely_multipolygon = get_shapely_multipolygon(coco_segmentation=annotation["segmentation"]) + annotations[ind]["area"] = shapely_multipolygon.area + + coco_dict["annotations"] = annotations + save_json(coco_dict, target_coco_path) + return coco_dict + + +@dataclass +class DatasetClassCounts: + """Stores the number of images that include each category in a dataset""" + + counts: dict + total_images: int + + def frequencies(self): + """calculates the frequenct of images that contain each category""" + return {cid: count / self.total_images for cid, count in self.counts.items()} + + def __add__(self, o): + total = self.total_images + o.total_images + exclusive_keys = set(o.counts.keys()) - set(self.counts.keys()) + counts = {} + for k, v in self.counts.items(): + counts[k] = v + o.counts.get(k, 0) + for k in exclusive_keys: + counts[k] = o.counts[k] + return DatasetClassCounts(counts, total) + + +def count_images_with_category(coco_file_path): + """Reads a coco dataset file and returns an DatasetClassCounts object + that stores the number of images that include each category in a dataset + Returns: DatasetClassCounts object + coco_file_path : str + path to coco dataset file + """ + + image_id_2_category_2_count = defaultdict(lambda: defaultdict(lambda: 0)) + coco = load_json(coco_file_path) + for annotation in coco["annotations"]: + image_id = annotation["image_id"] + cid = annotation["category_id"] + image_id_2_category_2_count[image_id][cid] = image_id_2_category_2_count[image_id][cid] + 1 + + category_2_count = defaultdict(lambda: 0) + for image_id, image_category_2_count in image_id_2_category_2_count.items(): + for cid, count in image_category_2_count.items(): + if count > 0: + category_2_count[cid] = category_2_count[cid] + 1 + + category_2_count = dict(category_2_count) + total_images = len(image_id_2_category_2_count.keys()) + return DatasetClassCounts(category_2_count, total_images) + + +class CocoVid: + def __init__(self, name=None, remapping_dict=None): + """ + Creates CocoVid object. + + Args: + name: str + Name of the CocoVid dataset, it determines exported json name. + remapping_dict: dict + {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1 + """ + self.name = name + self.remapping_dict = remapping_dict + self.categories = [] + self.videos = [] + + def add_categories_from_coco_category_list(self, coco_category_list): + """ + Creates CocoCategory object using coco category list. 
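# --- Usage sketch (not part of the patch): per-category image counts across two splits.
# Paths are hypothetical; DatasetClassCounts objects support '+' for combining splits.
from sahi.utils.coco import count_images_with_category

train_counts = count_images_with_category("annotations/train.json")
val_counts = count_images_with_category("annotations/val.json")
combined = train_counts + val_counts
print(combined.counts)         # {category_id: number of images containing that category}
print(combined.frequencies())  # {category_id: fraction of images containing that category}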
+ + Args: + coco_category_list: List[Dict] + [ + {"supercategory": "person", "id": 1, "name": "person"}, + {"supercategory": "vehicle", "id": 2, "name": "bicycle"} + ] + """ + + for coco_category in coco_category_list: + if self.remapping_dict is not None: + for source_id in self.remapping_dict.keys(): + if coco_category["id"] == source_id: + target_id = self.remapping_dict[source_id] + coco_category["id"] = target_id + + self.add_category(CocoCategory.from_coco_category(coco_category)) + + def add_category(self, category): + """ + Adds category to this CocoVid instance + + Args: + category: CocoCategory + """ + + if type(category) != CocoCategory: + raise TypeError("category must be a CocoCategory instance") + self.categories.append(category) + + @property + def json_categories(self): + categories = [] + for category in self.categories: + categories.append(category.json) + return categories + + @property + def category_mapping(self): + category_mapping = {} + for category in self.categories: + category_mapping[category.id] = category.name + return category_mapping + + def add_video(self, video): + """ + Adds video to this CocoVid instance + + Args: + video: CocoVideo + """ + + if type(video) != CocoVideo: + raise TypeError("video must be a CocoVideo instance") + self.videos.append(video) + + @property + def json(self): + coco_dict = { + "videos": [], + "images": [], + "annotations": [], + "categories": self.json_categories, + } + annotation_id = 1 + image_id = 1 + video_id = 1 + global_instance_id = 1 + for coco_video in self.videos: + coco_video.id = video_id + coco_dict["videos"].append(coco_video.json) + + frame_id = 0 + instance_id_set = set() + for cocovid_image in coco_video.images: + cocovid_image.id = image_id + cocovid_image.frame_id = frame_id + cocovid_image.video_id = coco_video.id + coco_dict["images"].append(cocovid_image.json) + + for cocovid_annotation in cocovid_image.annotations: + instance_id_set.add(cocovid_annotation.instance_id) + cocovid_annotation.instance_id += global_instance_id + + cocovid_annotation.id = annotation_id + cocovid_annotation.image_id = cocovid_image.id + coco_dict["annotations"].append(cocovid_annotation.json) + + # increment annotation_id + annotation_id = copy.deepcopy(annotation_id + 1) + # increment image_id and frame_id + image_id = copy.deepcopy(image_id + 1) + frame_id = copy.deepcopy(frame_id + 1) + # increment video_id and global_instance_id + video_id = copy.deepcopy(video_id + 1) + global_instance_id += len(instance_id_set) + + return coco_dict + + +def remove_invalid_coco_results(result_list_or_path: Union[List, str], dataset_dict_or_path: Union[Dict, str] = None): + """ + Removes invalid predictions from coco result such as: + - negative bbox value + - extreme bbox value + + Args: + result_list_or_path: path or list for coco result json + dataset_dict_or_path (optional): path or dict for coco dataset json + """ + + # prepare coco results + if isinstance(result_list_or_path, str): + result_list = load_json(result_list_or_path) + elif isinstance(result_list_or_path, list): + result_list = result_list_or_path + else: + raise TypeError('incorrect type for "result_list_or_path"') + + # prepare image info from coco dataset + if dataset_dict_or_path is not None: + if isinstance(dataset_dict_or_path, str): + dataset_dict = load_json(dataset_dict_or_path) + elif isinstance(dataset_dict_or_path, dict): + dataset_dict = dataset_dict_or_path + else: + raise TypeError('incorrect type for "dataset_dict"') + image_id_to_height = {} + 
image_id_to_width = {} + for coco_image in dataset_dict["images"]: + image_id_to_height[coco_image["id"]] = coco_image["height"] + image_id_to_width[coco_image["id"]] = coco_image["width"] + + # remove invalid predictions + fixed_result_list = [] + for coco_result in result_list: + bbox = coco_result["bbox"] + # ignore invalid predictions + if not bbox: + print("ignoring invalid prediction with empty bbox") + continue + if bbox[0] < 0 or bbox[1] < 0 or bbox[2] < 0 or bbox[3] < 0: + print(f"ignoring invalid prediction with bbox: {bbox}") + continue + if dataset_dict_or_path is not None: + if ( + bbox[1] > image_id_to_height[coco_result["image_id"]] + or bbox[3] > image_id_to_height[coco_result["image_id"]] + or bbox[0] > image_id_to_width[coco_result["image_id"]] + or bbox[2] > image_id_to_width[coco_result["image_id"]] + ): + print(f"ignoring invalid prediction with bbox: {bbox}") + continue + fixed_result_list.append(coco_result) + return fixed_result_list + + +def export_coco_as_yolov5( + output_dir: str, train_coco: Coco = None, val_coco: Coco = None, train_split_rate: float = 0.9, numpy_seed=0 +): + """ + Exports current COCO dataset in ultralytics/yolov5 format. + Creates train val folders with image symlinks and txt files and a data yaml file. + + Args: + output_dir: str + Export directory. + train_coco: Coco + coco object for training + val_coco: Coco + coco object for val + train_split_rate: float + train split rate between 0 and 1. will be used when val_coco is None. + numpy_seed: int + To fix the numpy seed. + + Returns: + yaml_path: str + Path for the exported yolov5 data.yml + """ + try: + import yaml + except ImportError: + raise ImportError('Please run "pip install -U pyyaml" ' "to install yaml first for yolov5 formatted exporting.") + + # set split_mode + if train_coco and not val_coco: + split_mode = True + elif train_coco and val_coco: + split_mode = False + else: + raise ValueError("'train_coco' have to be provided") + + # check train_split_rate + if split_mode and not (0 < train_split_rate < 1): + raise ValueError("train_split_rate cannot be <0 or >1") + + # split dataset + if split_mode: + result = train_coco.split_coco_as_train_val( + train_split_rate=train_split_rate, + numpy_seed=numpy_seed, + ) + train_coco = result["train_coco"] + val_coco = result["val_coco"] + + # create train val image dirs + train_dir = Path(os.path.abspath(output_dir)) / "train/" + train_dir.mkdir(parents=True, exist_ok=True) # create dir + val_dir = Path(os.path.abspath(output_dir)) / "val/" + val_dir.mkdir(parents=True, exist_ok=True) # create dir + + # create image symlinks and annotation txts + export_yolov5_images_and_txts_from_coco_object( + output_dir=train_dir, + coco=train_coco, + ignore_negative_samples=train_coco.ignore_negative_samples, + mp=False, + ) + export_yolov5_images_and_txts_from_coco_object( + output_dir=val_dir, + coco=val_coco, + ignore_negative_samples=val_coco.ignore_negative_samples, + mp=False, + ) + + # create yolov5 data yaml + data = { + "train": str(train_dir), + "val": str(val_dir), + "nc": len(train_coco.category_mapping), + "names": list(train_coco.category_mapping.values()), + } + yaml_path = str(Path(output_dir) / "data.yml") + with open(yaml_path, "w") as outfile: + yaml.dump(data, outfile, default_flow_style=None) + + return yaml_path + + +def export_coco_as_yolov5_via_yml(yml_path: str, output_dir: str, train_split_rate: float = 0.9, numpy_seed=0): + """ + Exports current COCO dataset in ultralytics/yolov5 format. 
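# --- Usage sketch (not part of the patch): exporting separate train/val Coco objects with
# export_coco_as_yolov5(). Paths are hypothetical; the generated data.yml path is returned.
from sahi.utils.coco import Coco, export_coco_as_yolov5

train_coco = Coco.from_coco_dict_or_path("annotations/train.json", image_dir="images/")
val_coco = Coco.from_coco_dict_or_path("annotations/val.json", image_dir="images/")
yaml_path = export_coco_as_yolov5(
    output_dir="runs/yolov5_dataset",
    train_coco=train_coco,
    val_coco=val_coco,
)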
+ Creates train val folders with image symlinks and txt files and a data yaml file. + Uses a yml file as input. + + Args: + yml_path: str + file should contain these fields: + train_json_path: str + train_image_dir: str + val_json_path: str + val_image_dir: str + output_dir: str + Export directory. + train_split_rate: float + train split rate between 0 and 1. will be used when val_json_path is None. + numpy_seed: int + To fix the numpy seed. + + Returns: + yaml_path: str + Path for the exported yolov5 data.yml + """ + try: + import yaml + except ImportError: + raise ImportError('Please run "pip install -U pyyaml" ' "to install yaml first for yolov5 formatted exporting.") + + with open(yml_path, "r") as stream: + config_dict = yaml.safe_load(stream) + + if config_dict["train_json_path"]: + if not config_dict["train_image_dir"]: + raise ValueError(f"{yml_path} is missing `train_image_dir`") + train_coco = Coco.from_coco_dict_or_path( + config_dict["train_json_path"], image_dir=config_dict["train_image_dir"] + ) + else: + train_coco = None + + if config_dict["val_json_path"]: + if not config_dict["val_image_dir"]: + raise ValueError(f"{yml_path} is missing `val_image_dir`") + val_coco = Coco.from_coco_dict_or_path(config_dict["val_json_path"], image_dir=config_dict["val_image_dir"]) + else: + val_coco = None + + yaml_path = export_coco_as_yolov5( + output_dir=output_dir, + train_coco=train_coco, + val_coco=val_coco, + train_split_rate=train_split_rate, + numpy_seed=numpy_seed, + ) + + return yaml_path diff --git a/sahi/utils/compatibility.py b/sahi/utils/compatibility.py new file mode 100644 index 0000000..e734a08 --- /dev/null +++ b/sahi/utils/compatibility.py @@ -0,0 +1,12 @@ +def fix_shift_amount_list(shift_amount_list): + # compatilibty for sahi v0.8.15 + if isinstance(shift_amount_list[0], (int, float)): + shift_amount_list = [shift_amount_list] + return shift_amount_list + + +def fix_full_shape_list(full_shape_list): + # compatilibty for sahi v0.8.15 + if full_shape_list is not None and isinstance(full_shape_list[0], (int, float)): + full_shape_list = [full_shape_list] + return full_shape_list diff --git a/sahi/utils/cv.py b/sahi/utils/cv.py new file mode 100644 index 0000000..c388748 --- /dev/null +++ b/sahi/utils/cv.py @@ -0,0 +1,596 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. 
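# --- Usage sketch (not part of the patch): driving the yolov5 export from a config file
# read by export_coco_as_yolov5_via_yml() (defined earlier in sahi/utils/coco.py).
# All paths are hypothetical.
import yaml

from sahi.utils.coco import export_coco_as_yolov5_via_yml

config = {
    "train_json_path": "annotations/train.json",
    "train_image_dir": "images/",
    "val_json_path": "annotations/val.json",
    "val_image_dir": "images/",
}
with open("coco_to_yolov5.yml", "w") as f:
    yaml.safe_dump(config, f)

export_coco_as_yolov5_via_yml(yml_path="coco_to_yolov5.yml", output_dir="runs/yolov5_dataset")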
+ +import copy +import os +import random +import time +from typing import List, Optional, Union + +import cv2 +import numpy as np +import requests +from PIL import Image + +from sahi.utils.file import Path + +IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tiff", ".bmp"] +VIDEO_EXTENSIONS = [".mp4", ".mkv", ".flv", ".avi", ".ts", ".mpg", ".mov", "wmv"] + + +class Colors: + # color palette + def __init__(self): + hex = ( + "FF3838", + "2C99A8", + "FF701F", + "6473FF", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "FF9D97", + "00C2FF", + "344593", + "FFB21D", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb("#" + c) for c in hex] + self.n = len(self.palette) + + def __call__(self, i, bgr=False): + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): # rgb order + return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4)) + + +def crop_object_predictions( + image: np.ndarray, + object_prediction_list, + output_dir: str = "", + file_name: str = "prediction_visual", + export_format: str = "png", +): + """ + Crops bounding boxes over the source image and exports it to output folder. + Arguments: + object_predictions: a list of prediction.ObjectPrediction + output_dir: directory for resulting visualization to be exported + file_name: exported file will be saved as: output_dir+file_name+".png" + export_format: can be specified as 'jpg' or 'png' + """ + # create output folder if not present + Path(output_dir).mkdir(parents=True, exist_ok=True) + # add bbox and mask to image if present + for ind, object_prediction in enumerate(object_prediction_list): + # deepcopy object_prediction_list so that original is not altered + object_prediction = object_prediction.deepcopy() + bbox = object_prediction.bbox.to_voc_bbox() + category_id = object_prediction.category.id + # crop detections + # deepcopy crops so that original is not altered + cropped_img = copy.deepcopy( + image[ + int(bbox[1]) : int(bbox[3]), + int(bbox[0]) : int(bbox[2]), + :, + ] + ) + save_path = os.path.join( + output_dir, + file_name + "_box" + str(ind) + "_class" + str(category_id) + "." + export_format, + ) + cv2.imwrite(save_path, cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR)) + + +def convert_image_to(read_path, extension: str = "jpg", grayscale: bool = False): + """ + Reads image from path and saves as given extension. + """ + image = cv2.imread(read_path) + pre, ext = os.path.splitext(read_path) + if grayscale: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + pre = pre + "_gray" + save_path = pre + "." + extension + cv2.imwrite(save_path, image) + + +def read_large_image(image_path: str): + use_cv2 = True + # read image, cv2 fails on large files + try: + # convert to rgb (cv2 reads in bgr) + img_cv2 = cv2.imread(image_path, 1) + image0 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB) + except: + try: + import skimage.io + except ImportError: + raise ImportError( + 'Please run "pip install -U scikit-image" ' "to install scikit-image first for large image handling." + ) + image0 = skimage.io.imread(image_path, as_grey=False).astype(np.uint8) # [::-1] + use_cv2 = False + return image0, use_cv2 + + +def read_image(image_path: str): + """ + Loads image as numpy array from given path. 
+ """ + # read image + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + # return image + return image + + +def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = False): + """ + Loads an image as PIL.Image.Image. + + Args: + image : Can be image path or url (str), numpy image (np.ndarray) or PIL.Image + """ + # https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil + Image.MAX_IMAGE_PIXELS = None + + if isinstance(image, Image.Image): + image_pil = image + elif isinstance(image, str): + # read image if str image path is provided + try: + image_pil = Image.open( + requests.get(image, stream=True).raw if str(image).startswith("http") else image + ).convert("RGB") + if exif_fix: + image_pil = exif_transpose(image_pil) + except: # handle large/tiff image reading + try: + import skimage.io + except ImportError: + raise ImportError("Please run 'pip install -U scikit-image imagecodecs' for large image handling.") + image_sk = skimage.io.imread(image).astype(np.uint8) + if len(image_sk.shape) == 2: # b&w + image_pil = Image.fromarray(image_sk, mode="1") + elif image_sk.shape[2] == 4: # rgba + image_pil = Image.fromarray(image_sk, mode="RGBA") + elif image_sk.shape[2] == 3: # rgb + image_pil = Image.fromarray(image_sk, mode="RGB") + else: + raise TypeError(f"image with shape: {image_sk.shape[3]} is not supported.") + elif isinstance(image, np.ndarray): + if image.shape[0] < 5: # image in CHW + image = image[:, :, ::-1] + image_pil = Image.fromarray(image) + else: + raise TypeError("read image with 'pillow' using 'Image.open()'") + return image_pil + + +def select_random_color(): + """ + Selects random color. + """ + colors = [ + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [0, 255, 255], + [255, 255, 0], + [255, 0, 255], + [80, 70, 180], + [250, 80, 190], + [245, 145, 50], + [70, 150, 250], + [50, 190, 190], + ] + return colors[random.randrange(0, 10)] + + +def apply_color_mask(image: np.ndarray, color: tuple): + """ + Applies color mask to given input image. + """ + r = np.zeros_like(image).astype(np.uint8) + g = np.zeros_like(image).astype(np.uint8) + b = np.zeros_like(image).astype(np.uint8) + + (r[image == 1], g[image == 1], b[image == 1]) = color + colored_mask = np.stack([r, g, b], axis=2) + return colored_mask + + +def get_video_reader( + source: str, + save_dir: str, + frame_skip_interval: int, + export_visual: bool = False, + view_visual: bool = False, +): + """ + Creates OpenCV video capture object from given video file path. 
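# --- Usage sketch (not part of the patch): read_image_as_pil() accepts an image path or URL,
# a numpy array, or an existing PIL image. The file path below is hypothetical.
import numpy as np

from sahi.utils.cv import read_image_as_pil

pil_from_path = read_image_as_pil("demo_image.jpg", exif_fix=True)
pil_from_array = read_image_as_pil(np.zeros((480, 640, 3), dtype=np.uint8))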
+ + Args: + source: Video file path + save_dir: Video export directory + frame_skip_interval: Frame skip interval + export_visual: Set True if you want to export visuals + view_visual: Set True if you want to render visual + + Returns: + iterator: Pillow Image + video_writer: cv2.VideoWriter + video_file_name: video name with extension + """ + # get video name with extension + video_file_name = os.path.basename(source) + # get video from video path + video_capture = cv2.VideoCapture(source) + + num_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) + if view_visual: + num_frames /= frame_skip_interval + 1 + num_frames = int(num_frames) + + def read_video_frame(video_capture, frame_skip_interval): + if view_visual: + cv2.imshow("Prediction of {}".format(str(video_file_name)), cv2.WINDOW_AUTOSIZE) + + while video_capture.isOpened: + + frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES) + video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num + frame_skip_interval) + + k = cv2.waitKey(20) + frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES) + + if k == 27: + print( + "\n===========================Closing===========================" + ) # Exit the prediction, Key = Esc + exit() + if k == 100: + frame_num += 100 # Skip 100 frames, Key = d + if k == 97: + frame_num -= 100 # Prev 100 frames, Key = a + if k == 103: + frame_num += 20 # Skip 20 frames, Key = g + if k == 102: + frame_num -= 20 # Prev 20 frames, Key = f + video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num) + + ret, frame = video_capture.read() + if not ret: + print("\n=========================== Video Ended ===========================") + break + yield Image.fromarray(frame) + + else: + while video_capture.isOpened: + frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES) + video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num + frame_skip_interval) + + ret, frame = video_capture.read() + if not ret: + print("\n=========================== Video Ended ===========================") + break + yield Image.fromarray(frame) + + if export_visual: + # get video properties and create VideoWriter object + if frame_skip_interval != 0: + fps = video_capture.get(cv2.CAP_PROP_FPS) # original fps of video + # The fps of export video is increasing during view_image because frame is skipped + fps = ( + fps / frame_skip_interval + ) # How many time_interval equals to original fps. One time_interval skip x frames. + else: + fps = video_capture.get(cv2.CAP_PROP_FPS) + + w = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + size = (w, h) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + video_writer = cv2.VideoWriter(os.path.join(save_dir, video_file_name), fourcc, fps, size) + else: + video_writer = None + + return read_video_frame(video_capture, frame_skip_interval), video_writer, video_file_name, num_frames + + +def visualize_prediction( + image: np.ndarray, + boxes: List[List], + classes: List[str], + masks: Optional[List[np.ndarray]] = None, + rect_th: float = None, + text_size: float = None, + text_th: float = None, + color: tuple = None, + output_dir: Optional[str] = None, + file_name: Optional[str] = "prediction_visual", +): + """ + Visualizes prediction classes, bounding boxes over the source image + and exports it to output folder. 
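# --- Usage sketch (not part of the patch): iterating video frames with get_video_reader().
# The video path is hypothetical; frames are yielded as PIL images, and video_writer is None
# because export_visual is False.
from sahi.utils.cv import get_video_reader

frame_iterator, video_writer, video_file_name, num_frames = get_video_reader(
    source="input.mp4",
    save_dir="outputs/",
    frame_skip_interval=0,
    export_visual=False,
    view_visual=False,
)
for frame in frame_iterator:
    pass  # run prediction / drawing on each PIL frame here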
+ """ + elapsed_time = time.time() + # deepcopy image so that original is not altered + image = copy.deepcopy(image) + # select predefined classwise color palette if not specified + if color is None: + colors = Colors() + else: + colors = None + # set rect_th for boxes + rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.003), 2) + # set text_th for category names + text_th = text_th or max(rect_th - 1, 1) + # set text_size for category names + text_size = text_size or rect_th / 3 + # add bbox and mask to image if present + for i in range(len(boxes)): + # deepcopy boxso that original is not altered + box = copy.deepcopy(boxes[i]) + class_ = classes[i] + + # set color + if colors is not None: + color = colors(class_) + # visualize masks if present + if masks is not None: + # deepcopy mask so that original is not altered + mask = copy.deepcopy(masks[i]) + # draw mask + rgb_mask = apply_color_mask(np.squeeze(mask), color) + image = cv2.addWeighted(image, 1, rgb_mask, 0.7, 0) + # set bbox points + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + # visualize boxes + cv2.rectangle( + image, + p1, + p2, + color=color, + thickness=rect_th, + ) + # arange bounding box text location + label = f"{class_}" + w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0] # label width, height + outside = p1[1] - h - 3 >= 0 # label fits outside box + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + # add bounding box text + cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + image, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + text_size, + (255, 255, 255), + thickness=text_th, + ) + if output_dir: + # create output folder if not present + Path(output_dir).mkdir(parents=True, exist_ok=True) + # save inference result + save_path = os.path.join(output_dir, file_name + ".png") + cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + + elapsed_time = time.time() - elapsed_time + return {"image": image, "elapsed_time": elapsed_time} + + +def visualize_object_predictions( + image: np.array, + object_prediction_list, + rect_th: int = None, + text_size: float = None, + text_th: float = None, + color: tuple = None, + output_dir: Optional[str] = None, + file_name: str = "prediction_visual", + export_format: str = "png", +): + """ + Visualizes prediction category names, bounding boxes over the source image + and exports it to output folder. 
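# --- Usage sketch (not part of the patch): rendering and exporting a list of ObjectPrediction
# instances over an image. ObjectPrediction comes from sahi.prediction (not shown in this
# excerpt) and its keyword arguments below are assumed; the image path is hypothetical.
from sahi.prediction import ObjectPrediction
from sahi.utils.cv import read_image, visualize_object_predictions

image = read_image("demo_image.jpg")  # RGB numpy array
object_prediction_list = [
    ObjectPrediction(bbox=[100, 100, 300, 250], category_id=0, category_name="human", score=0.87)
]
result = visualize_object_predictions(
    image,
    object_prediction_list=object_prediction_list,
    output_dir="outputs/",
    file_name="prediction_visual",
    export_format="png",
)
# result["image"] holds the drawn array, result["elapsed_time"] the rendering time in seconds.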
+ Arguments: + object_prediction_list: a list of prediction.ObjectPrediction + rect_th: rectangle thickness + text_size: size of the category name over box + text_th: text thickness + color: annotation color in the form: (0, 255, 0) + output_dir: directory for resulting visualization to be exported + file_name: exported file will be saved as: output_dir+file_name+".png" + export_format: can be specified as 'jpg' or 'png' + """ + elapsed_time = time.time() + # deepcopy image so that original is not altered + image = copy.deepcopy(image) + # select predefined classwise color palette if not specified + if color is None: + colors = Colors() + else: + colors = None + # set rect_th for boxes + rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.001), 1) + # set text_th for category names + text_th = text_th or max(rect_th - 1, 1) + # set text_size for category names + text_size = text_size or rect_th / 3 + # add bbox and mask to image if present + for object_prediction in object_prediction_list: + # deepcopy object_prediction_list so that original is not altered + object_prediction = object_prediction.deepcopy() + + bbox = object_prediction.bbox.to_voc_bbox() + category_name = object_prediction.category.name + score = object_prediction.score.value + + # set color + if colors is not None: + color = colors(object_prediction.category.id) + # visualize masks if present + if object_prediction.mask is not None: + # deepcopy mask so that original is not altered + mask = object_prediction.mask.bool_mask + # draw mask + rgb_mask = apply_color_mask(mask, color) + image = cv2.addWeighted(image, 1, rgb_mask, 0.4, 0) + # set bbox points + p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])) + # visualize boxes + cv2.rectangle( + image, + p1, + p2, + color=color, + thickness=rect_th, + ) + # arange bounding box text location + label = f"{category_name} {score:.2f}" + w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0] # label width, height + outside = p1[1] - h - 3 >= 0 # label fits outside box + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + # add bounding box text + cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + image, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + text_size, + (255, 255, 255), + thickness=text_th, + ) + + # export if output_dir is present + if output_dir is not None: + # export image with predictions + Path(output_dir).mkdir(parents=True, exist_ok=True) + # save inference result + save_path = str(Path(output_dir) / (file_name + "." + export_format)) + cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + + elapsed_time = time.time() - elapsed_time + return {"image": image, "elapsed_time": elapsed_time} + + +def get_coco_segmentation_from_bool_mask(bool_mask): + """ + Convert boolean mask to coco segmentation format + [ + [x1, y1, x2, y2, x3, y3, ...], + [x1, y1, x2, y2, x3, y3, ...], + ... 
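A sketch of visualize_object_predictions above, assuming an already existing list of sahi ObjectPrediction instances (for example the object_prediction_list attribute of a prediction result, an assumption here) and an RGB numpy image; export_format switches between png and jpg output.

import numpy as np
from PIL import Image

image_rgb = np.asarray(Image.open("demo/demo_data/small-vehicles1.jpeg").convert("RGB"))  # hypothetical path

visual = visualize_object_predictions(
    image=image_rgb,
    object_prediction_list=object_prediction_list,   # assumed to exist, e.g. from a prediction result
    output_dir="runs/visuals/",
    file_name="sliced_prediction_visual",
    export_format="jpg",                             # saves runs/visuals/sliced_prediction_visual.jpg
)
print(f"drew {len(object_prediction_list)} predictions in {visual['elapsed_time']:.3f} s")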
+ ] + """ + # Generate polygons from mask + mask = np.squeeze(bool_mask) + mask = mask.astype(np.uint8) + mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) + polygons = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE, offset=(-1, -1)) + polygons = polygons[0] if len(polygons) == 2 else polygons[1] + # Convert polygon to coco segmentation + coco_segmentation = [] + for polygon in polygons: + segmentation = polygon.flatten().tolist() + # at least 3 points needed for a polygon + if len(segmentation) >= 6: + coco_segmentation.append(segmentation) + return coco_segmentation + + +def get_bool_mask_from_coco_segmentation(coco_segmentation, width, height): + """ + Convert coco segmentation to 2D boolean mask of given height and width + """ + size = [height, width] + points = [np.array(point).reshape(-1, 2).round().astype(int) for point in coco_segmentation] + bool_mask = np.zeros(size) + bool_mask = cv2.fillPoly(bool_mask, points, 1) + bool_mask.astype(bool) + return bool_mask + + +def get_bbox_from_bool_mask(bool_mask): + """ + Generate voc bbox ([xmin, ymin, xmax, ymax]) from given bool_mask (2D np.ndarray) + """ + rows = np.any(bool_mask, axis=1) + cols = np.any(bool_mask, axis=0) + + if not np.any(rows) or not np.any(cols): + return None + + ymin, ymax = np.where(rows)[0][[0, -1]] + xmin, xmax = np.where(cols)[0][[0, -1]] + width = xmax - xmin + height = ymax - ymin + + if width == 0 or height == 0: + return None + + return [xmin, ymin, xmax, ymax] + + +def normalize_numpy_image(image: np.ndarray): + """ + Normalizes numpy image + """ + return image / np.max(image) + + +def ipython_display(image: np.ndarray): + """ + Displays numpy image in notebook. + + If input image is in range 0..1, please first multiply img by 255 + Assumes image is ndarray of shape [height, width, channels] where channels can be 1, 3 or 4 + """ + import IPython + + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + _, ret = cv2.imencode(".png", image) + i = IPython.display.Image(data=ret) + IPython.display.display(i) + + +def exif_transpose(image: Image.Image): + """ + Transpose a PIL image accordingly if it has an EXIF Orientation tag. + Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose() + :param image: The image to transpose. + :return: An image. + """ + exif = image.getexif() + orientation = exif.get(0x0112, 1) # default 1 + if orientation > 1: + method = { + 2: Image.FLIP_LEFT_RIGHT, + 3: Image.ROTATE_180, + 4: Image.FLIP_TOP_BOTTOM, + 5: Image.TRANSPOSE, + 6: Image.ROTATE_270, + 7: Image.TRANSVERSE, + 8: Image.ROTATE_90, + }.get(orientation) + if method is not None: + image = image.transpose(method) + del exif[0x0112] + image.info["exif"] = exif.tobytes() + return image diff --git a/sahi/utils/detectron2.py b/sahi/utils/detectron2.py new file mode 100644 index 0000000..096270a --- /dev/null +++ b/sahi/utils/detectron2.py @@ -0,0 +1,21 @@ +from pathlib import Path + + +class Detectron2TestConstants: + FASTERCNN_MODEL_ZOO_NAME = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml" + RETINANET_MODEL_ZOO_NAME = "COCO-Detection/retinanet_R_50_FPN_3x.yaml" + MASKRCNN_MODEL_ZOO_NAME = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml" + + +def export_cfg_as_yaml(cfg, export_path: str = "config.yaml"): + """ + Exports Detectron2 config object in yaml format so that it can be used later. + Args: + cfg (detectron2.config.CfgNode): Detectron2 config object. + export_path (str): Path to export the Detectron2 config. 
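A minimal round-trip sketch for the mask helpers above, using a made-up square polygon inside a 100x100 image: COCO segmentation to boolean mask, mask to VOC bbox, and mask back to COCO polygons.

import numpy as np

coco_segmentation = [[10, 10, 60, 10, 60, 60, 10, 60]]   # one square polygon, illustrative values

bool_mask = get_bool_mask_from_coco_segmentation(coco_segmentation, width=100, height=100)
print(bool_mask.shape)                                   # (100, 100)

voc_bbox = get_bbox_from_bool_mask(bool_mask.astype(bool))
print(voc_bbox)                                          # roughly [10, 10, 60, 60]

reconstructed = get_coco_segmentation_from_bool_mask(bool_mask.astype(bool))
print(len(reconstructed))                                # 1 polygon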
+ Related Detectron2 doc: https://detectron2.readthedocs.io/en/stable/modules/config.html#detectron2.config.CfgNode.dump + """ + Path(export_path).parent.mkdir(exist_ok=True, parents=True) + + with open(export_path, "w") as f: + f.write(cfg.dump()) diff --git a/sahi/utils/fiftyone.py b/sahi/utils/fiftyone.py new file mode 100644 index 0000000..b3a469c --- /dev/null +++ b/sahi/utils/fiftyone.py @@ -0,0 +1,78 @@ +import os +import subprocess +import sys + +from sahi.utils.import_utils import is_available + +if is_available("fiftyone"): + # to fix https://github.com/voxel51/fiftyone/issues/845 + if sys.platform == "win32": + _ = subprocess.run("tskill mongod", stderr=subprocess.DEVNULL) + else: + _ = subprocess.run(["pkill", "mongod"], stderr=subprocess.DEVNULL) + + # import fo utilities + import fiftyone as fo + from fiftyone.utils.coco import COCODetectionDatasetImporter as BaseCOCODetectionDatasetImporter + from fiftyone.utils.coco import _get_matching_image_ids, load_coco_detection_annotations + + class COCODetectionDatasetImporter(BaseCOCODetectionDatasetImporter): + def setup(self): + if self.labels_path is not None and os.path.isfile(self.labels_path): + ( + info, + classes, + supercategory_map, + images, + annotations, + ) = load_coco_detection_annotations(self.labels_path, extra_attrs=self.extra_attrs) + + if classes is not None: + info["classes"] = classes + + image_ids = _get_matching_image_ids( + classes, + images, + annotations, + image_ids=self.image_ids, + classes=self.classes, + shuffle=self.shuffle, + seed=self.seed, + max_samples=self.max_samples, + ) + + filenames = [images[_id]["file_name"] for _id in image_ids] + + _image_ids = set(image_ids) + image_dicts_map = {i["file_name"]: i for _id, i in images.items() if _id in _image_ids} + else: + info = {} + classes = None + supercategory_map = None + image_dicts_map = {} + annotations = None + filenames = [] + + self._image_paths_map = { + image["file_name"]: os.path.join(self.data_path, image["file_name"]) for image in images.values() + } + + self._info = info + self._classes = classes + self._supercategory_map = supercategory_map + self._image_dicts_map = image_dicts_map + self._annotations = annotations + self._filenames = filenames + + def create_fiftyone_dataset_from_coco_file(coco_image_dir: str, coco_json_path: str): + coco_importer = COCODetectionDatasetImporter( + data_path=coco_image_dir, labels_path=coco_json_path, include_id=True + ) + dataset = fo.Dataset.from_importer(coco_importer, label_field="gt") + return dataset + + def launch_fiftyone_app(coco_image_dir: str, coco_json_path: str): + dataset = create_fiftyone_dataset_from_coco_file(coco_image_dir, coco_json_path) + session = fo.launch_app() + session.dataset = dataset + return session diff --git a/sahi/utils/file.py b/sahi/utils/file.py new file mode 100644 index 0000000..9266167 --- /dev/null +++ b/sahi/utils/file.py @@ -0,0 +1,234 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. + +import glob +import json +import ntpath +import os +import pickle +import re +import urllib.request +import zipfile +from pathlib import Path + +import numpy as np + + +def unzip(file_path: str, dest_dir: str): + """ + Unzips compressed .zip file. 
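A short sketch of the FiftyOne helpers above; it requires the optional fiftyone dependency, and the dataset paths are placeholders.

from sahi.utils.fiftyone import launch_fiftyone_app

session = launch_fiftyone_app(
    coco_image_dir="data/coco_images/",           # hypothetical image folder
    coco_json_path="data/coco_annotations.json",  # hypothetical COCO annotation file
)
session.wait()   # keep the app open until the browser session is closed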
+ Example inputs: + file_path: 'data/01_alb_id.zip' + dest_dir: 'data/' + """ + + # unzip file + with zipfile.ZipFile(file_path) as zf: + zf.extractall(dest_dir) + + +def save_json(data, save_path): + """ + Saves json formatted data (given as "data") as save_path + Example inputs: + data: {"image_id": 5} + save_path: "dirname/coco.json" + """ + # create dir if not present + Path(save_path).parent.mkdir(parents=True, exist_ok=True) + + # export as json + with open(save_path, "w", encoding="utf-8") as outfile: + json.dump(data, outfile, separators=(",", ":"), cls=NumpyEncoder) + + +# type check when save json files +class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return super(NumpyEncoder, self).default(obj) + + +def load_json(load_path: str, encoding: str = "utf-8"): + """ + Loads json formatted data (given as "data") from load_path + Encoding type can be specified with 'encoding' argument + + Example inputs: + load_path: "dirname/coco.json" + """ + # read from path + with open(load_path, encoding=encoding) as json_file: + data = json.load(json_file) + return data + + +def list_files( + directory: str, + contains: list = [".json"], + verbose: int = 1, +) -> list: + """ + Walk given directory and return a list of file path with desired extension + + Args: + directory: str + "data/coco/" + contains: list + A list of strings to check if the target file contains them, example: ["coco.png", ".jpg", "jpeg"] + verbose: int + 0: no print + 1: print number of files + + Returns: + filepath_list : list + List of file paths + """ + # define verboseprint + verboseprint = print if verbose else lambda *a, **k: None + + filepath_list = [] + + for file in os.listdir(directory): + # check if filename contains any of the terms given in contains list + if any(strtocheck in file for strtocheck in contains): + filepath = os.path.join(directory, file) + filepath_list.append(filepath) + + number_of_files = len(filepath_list) + folder_name = Path(directory).name + + verboseprint(f"There are {str(number_of_files)} listed files in folder: {folder_name}/") + + return filepath_list + + +def list_files_recursively(directory: str, contains: list = [".json"], verbose: str = True) -> (list, list): + """ + Walk given directory recursively and return a list of file path with desired extension + + Arguments + ------- + directory : str + "data/coco/" + contains : list + A list of strings to check if the target file contains them, example: ["coco.png", ".jpg", "jpeg"] + verbose : bool + If true, prints some results + Returns + ------- + relative_filepath_list : list + List of file paths relative to given directory + abs_filepath_list : list + List of absolute file paths + """ + + # define verboseprint + verboseprint = print if verbose else lambda *a, **k: None + + # walk directories recursively and find json files + abs_filepath_list = [] + relative_filepath_list = [] + + # r=root, d=directories, f=files + for r, _, f in os.walk(directory): + for file in f: + # check if filename contains any of the terms given in contains list + if any(strtocheck in file for strtocheck in contains): + abs_filepath = os.path.join(r, file) + abs_filepath_list.append(abs_filepath) + relative_filepath = abs_filepath.split(directory)[-1] + relative_filepath_list.append(relative_filepath) + + number_of_files = len(relative_filepath_list) + folder_name = 
directory.split(os.sep)[-1] + + verboseprint("There are {} listed files in folder {}.".format(number_of_files, folder_name)) + + return relative_filepath_list, abs_filepath_list + + +def get_base_filename(path: str): + """ + Takes a file path, returns (base_filename_with_extension, base_filename_without_extension) + """ + base_filename_with_extension = ntpath.basename(path) + base_filename_without_extension, _ = os.path.splitext(base_filename_with_extension) + return base_filename_with_extension, base_filename_without_extension + + +def get_file_extension(path: str): + filename, file_extension = os.path.splitext(path) + return file_extension + + +def load_pickle(load_path): + """ + Loads pickle formatted data (given as "data") from load_path + Example inputs: + load_path: "dirname/coco.pickle" + """ + # read from path + with open(load_path) as json_file: + data = pickle.load(json_file) + return data + + +def save_pickle(data, save_path): + """ + Saves pickle formatted data (given as "data") as save_path + Example inputs: + data: {"image_id": 5} + save_path: "dirname/coco.pickle" + """ + # create dir if not present + Path(save_path).parent.mkdir(parents=True, exist_ok=True) + + # export as json + with open(save_path, "wb") as outfile: + pickle.dump(data, outfile) + + +def import_model_class(class_name): + """ + Imports a predefined detection class by class name. + + Args: + model_name: str + Name of the detection model class (example: "MmdetDetectionModel") + Returns: + class_: class with given path + """ + module = __import__("sahi.model", fromlist=[class_name]) + class_ = getattr(module, class_name) + return class_ + + +def increment_path(path, exist_ok=True, sep=""): + # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc. + path = Path(path) # os-agnostic + if (path.exists() and exist_ok) or (not path.exists()): + return str(path) + else: + dirs = glob.glob(f"{path}{sep}*") # similar paths + matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs] + i = [int(m.groups()[0]) for m in matches if m] # indices + n = max(i) + 1 if i else 2 # increment number + return f"{path}{sep}{n}" # update path + + +def download_from_url(from_url: str, to_path: str): + + Path(to_path).parent.mkdir(parents=True, exist_ok=True) + + if not os.path.exists(to_path): + urllib.request.urlretrieve( + from_url, + to_path, + ) diff --git a/sahi/utils/huggingface.py b/sahi/utils/huggingface.py new file mode 100644 index 0000000..9f90660 --- /dev/null +++ b/sahi/utils/huggingface.py @@ -0,0 +1,2 @@ +class HuggingfaceTestConstants: + YOLOS_TINY_MODEL_PATH = "hustvl/yolos-tiny" diff --git a/sahi/utils/import_utils.py b/sahi/utils/import_utils.py new file mode 100644 index 0000000..43d7368 --- /dev/null +++ b/sahi/utils/import_utils.py @@ -0,0 +1,71 @@ +import contextlib +import importlib.util +import logging +import os + +# adapted from https://github.com/huggingface/transformers/src/transformers/utils/import_utils.py + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + + +def get_package_info(package_name: str, verbose: bool = True): + """ + Returns the package version as a string and the package name as a string. 
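A quick sketch of the JSON helpers above: NumpyEncoder lets save_json serialize numpy scalars and arrays transparently, load_json reads the file back, and list_files collects matching paths. The file names are illustrative.

import numpy as np

data = {"image_id": np.int64(5), "score": np.float32(0.87), "bbox": np.array([10, 20, 30, 40])}

save_json(data, "runs/demo/coco_like.json")          # numpy types are converted by NumpyEncoder
loaded = load_json("runs/demo/coco_like.json")
print(loaded["bbox"])                                # [10, 20, 30, 40] as a plain list

json_files = list_files(directory="runs/demo/", contains=[".json"], verbose=1)
print(json_files)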
+ """ + _is_available = is_available(package_name) + + if _is_available: + try: + import importlib.metadata as _importlib_metadata + + _version = _importlib_metadata.version(package_name) + except (ModuleNotFoundError, AttributeError): + try: + _version = importlib.import_module(package_name).__version__ + except AttributeError: + _version = "unknown" + if verbose: + logger.info(f"{package_name} version {_version} is available.") + else: + _version = "N/A" + + return _is_available, _version + + +def print_enviroment_info(): + _torch_available, _torch_version = get_package_info("torch") + _torchvision_available, _torchvision_version = get_package_info("torchvision") + _tensorflow_available, _tensorflow_version = get_package_info("tensorflow") + _tensorflow_hub_available, _tensorflow_hub_version = get_package_info("tensorflow-hub") + _yolov5_available, _yolov5_version = get_package_info("yolov5") + _mmdet_available, _mmdet_version = get_package_info("mmdet") + _mmcv_available, _mmcv_version = get_package_info("mmcv") + _detectron2_available, _detectron2_version = get_package_info("detectron2") + _transformers_available, _transformers_version = get_package_info("transformers") + _timm_available, _timm_version = get_package_info("timm") + _layer_available, _layer_version = get_package_info("layer") + _fiftyone_available, _fiftyone_version = get_package_info("fiftyone") + _norfair_available, _norfair_version = get_package_info("norfair") + + +def is_available(module_name: str): + return importlib.util.find_spec(module_name) is not None + + +@contextlib.contextmanager +def check_requirements(package_names): + """ + Raise error if module is not installed. + """ + missing_packages = [] + for package_name in package_names: + if importlib.util.find_spec(package_name) is None: + missing_packages.append(package_name) + if missing_packages: + raise ImportError(f"The following packages are required to use this module: {missing_packages}") + yield diff --git a/sahi/utils/mmdet.py b/sahi/utils/mmdet.py new file mode 100644 index 0000000..229efc0 --- /dev/null +++ b/sahi/utils/mmdet.py @@ -0,0 +1,192 @@ +import shutil +import sys +import urllib.request +from importlib import import_module +from os import path +from pathlib import Path +from typing import Optional + +from sahi.utils.file import download_from_url + + +def mmdet_version_as_integer(): + import mmdet + + return int(mmdet.__version__.replace(".", "")) + + +class MmdetTestConstants: + MMDET_CASCADEMASKRCNN_MODEL_URL = "http://download.openmmlab.com/mmdetection/v2.0/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco/cascade_mask_rcnn_r50_fpn_1x_coco_20200203-9d4dcb24.pth" + MMDET_CASCADEMASKRCNN_MODEL_PATH = ( + "tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco_20200203-9d4dcb24.pth" + ) + MMDET_RETINANET_MODEL_URL = "http://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth" + MMDET_RETINANET_MODEL_PATH = "tests/data/models/mmdet_retinanet/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth" + MMDET_YOLOX_TINY_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth" + MMDET_YOLOX_TINY_MODEL_PATH = "tests/data/models/mmdet_yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth" + + MMDET_CASCADEMASKRCNN_CONFIG_PATH = "tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py" + MMDET_RETINANET_CONFIG_PATH = 
"tests/data/models/mmdet_retinanet/retinanet_r50_fpn_1x_coco.py" + MMDET_YOLOX_TINY_CONFIG_PATH = "tests/data/models/mmdet_yolox/yolox_tiny_8x8_300e_coco.py" + + +def download_mmdet_cascade_mask_rcnn_model(destination_path: Optional[str] = None): + + if destination_path is None: + destination_path = MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH + + Path(destination_path).parent.mkdir(parents=True, exist_ok=True) + + download_from_url(MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_URL, destination_path) + + +def download_mmdet_retinanet_model(destination_path: Optional[str] = None): + + if destination_path is None: + destination_path = MmdetTestConstants.MMDET_RETINANET_MODEL_PATH + + Path(destination_path).parent.mkdir(parents=True, exist_ok=True) + + download_from_url(MmdetTestConstants.MMDET_RETINANET_MODEL_URL, destination_path) + + +def download_mmdet_yolox_tiny_model(destination_path: Optional[str] = None): + + if destination_path is None: + destination_path = MmdetTestConstants.MMDET_YOLOX_TINY_MODEL_PATH + + Path(destination_path).parent.mkdir(parents=True, exist_ok=True) + + download_from_url(MmdetTestConstants.MMDET_YOLOX_TINY_MODEL_URL, destination_path) + + +def download_mmdet_config( + model_name: str = "cascade_rcnn", + config_file_name: str = "cascade_mask_rcnn_r50_fpn_1x_coco.py", + verbose: bool = True, +) -> str: + """ + Merges config files starting from given main config file name. Saves as single file. + + Args: + model_name (str): mmdet model name. check https://github.com/open-mmlab/mmdetection/tree/master/configs. + config_file_name (str): mdmet config file name. + verbose (bool): if True, print save path. + + Returns: + (str) abs path of the downloaded config file. + """ + + # get mmdet version + from mmdet import __version__ + + mmdet_ver = "v" + __version__ + + # set main config url + base_config_url = ( + "https://raw.githubusercontent.com/open-mmlab/mmdetection/" + mmdet_ver + "/configs/" + model_name + "/" + ) + main_config_url = base_config_url + config_file_name + + # set final config dirs + configs_dir = Path("mmdet_configs") / mmdet_ver + model_config_dir = configs_dir / model_name + + # create final config dir + configs_dir.mkdir(parents=True, exist_ok=True) + model_config_dir.mkdir(parents=True, exist_ok=True) + + # get final config file name + filename = Path(main_config_url).name + + # set final config file path + final_config_path = str(model_config_dir / filename) + + if not Path(final_config_path).exists(): + # set config dirs + temp_configs_dir = Path("temp_mmdet_configs") + main_config_dir = temp_configs_dir / model_name + + # create config dirs + temp_configs_dir.mkdir(parents=True, exist_ok=True) + main_config_dir.mkdir(parents=True, exist_ok=True) + + # get main config file name + filename = Path(main_config_url).name + + # set main config file path + main_config_path = str(main_config_dir / filename) + + # download main config file + urllib.request.urlretrieve( + main_config_url, + main_config_path, + ) + + # read main config file + sys.path.insert(0, str(main_config_dir)) + temp_module_name = path.splitext(filename)[0] + mod = import_module(temp_module_name) + sys.path.pop(0) + config_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")} + + # handle when config_dict["_base_"] is string + if not isinstance(config_dict["_base_"], list): + config_dict["_base_"] = [config_dict["_base_"]] + + # iterate over secondary config files + for secondary_config_file_path in config_dict["_base_"]: + # set config url 
+ config_url = base_config_url + secondary_config_file_path + config_path = main_config_dir / secondary_config_file_path + + # create secondary config dir + config_path.parent.mkdir(parents=True, exist_ok=True) + + # download secondary config files + urllib.request.urlretrieve( + config_url, + str(config_path), + ) + + # read secondary config file + secondary_config_dir = config_path.parent + sys.path.insert(0, str(secondary_config_dir)) + temp_module_name = path.splitext(Path(config_path).name)[0] + mod = import_module(temp_module_name) + sys.path.pop(0) + secondary_config_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")} + + # go deeper if there are more steps + if secondary_config_dict.get("_base_") is not None: + # handle when config_dict["_base_"] is string + if not isinstance(secondary_config_dict["_base_"], list): + secondary_config_dict["_base_"] = [secondary_config_dict["_base_"]] + + # iterate over third config files + for third_config_file_path in secondary_config_dict["_base_"]: + # set config url + config_url = base_config_url + third_config_file_path + config_path = main_config_dir / third_config_file_path + + # create secondary config dir + config_path.parent.mkdir(parents=True, exist_ok=True) + # download secondary config files + urllib.request.urlretrieve( + config_url, + str(config_path), + ) + + # dump final config as single file + from mmcv import Config + + config = Config.fromfile(main_config_path) + config.dump(final_config_path) + + if verbose: + print(f"mmdet config file has been downloaded to {path.abspath(final_config_path)}") + + # remove temp config dir + shutil.rmtree(temp_configs_dir) + + return path.abspath(final_config_path) diff --git a/sahi/utils/mot.py b/sahi/utils/mot.py new file mode 100644 index 0000000..3c62716 --- /dev/null +++ b/sahi/utils/mot.py @@ -0,0 +1,349 @@ +import os +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np + +from sahi.utils.file import increment_path +from sahi.utils.import_utils import check_requirements, is_available + +if is_available("norfair"): + from norfair.metrics import PredictionsTextFile + + +@check_requirements(["norfair"]) +class MotTextFile(PredictionsTextFile): + from norfair.tracker import TrackedObject + + def __init__(self, save_dir: str = ".", save_name: str = "gt"): + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + self.out_file_name = os.path.join(save_dir, save_name + ".txt") + + self.frame_number = 0 + + def update(self, predictions: List[TrackedObject], frame_number: int = None): + if frame_number is None: + frame_number = self.frame_number + """ + Write tracked object information in the output file (for this frame), in the format + frame_number, id, bb_left, bb_top, bb_width, bb_height, 1, -1, -1, -1 + """ + text_file = open(self.out_file_name, "a+") + + for obj in predictions: + frame_str = str(int(frame_number)) + id_str = str(int(obj.id)) + bb_left_str = str((obj.estimate[0, 0])) + bb_top_str = str((obj.estimate[0, 1])) # [0,1] + bb_width_str = str((obj.estimate[1, 0] - obj.estimate[0, 0])) + bb_height_str = str((obj.estimate[1, 1] - obj.estimate[0, 1])) + row_text_out = ( + frame_str + + "," + + id_str + + "," + + bb_left_str + + "," + + bb_top_str + + "," + + bb_width_str + + "," + + bb_height_str + + ",1,-1,-1,-1" + ) + text_file.write(row_text_out) + text_file.write("\n") + + self.frame_number += 1 + + text_file.close() + + +def euclidean_distance(detection, tracked_object): + return 
np.linalg.norm(detection.points - tracked_object.estimate) + + +class MotAnnotation: + def __init__(self, bbox: List[int], track_id: Optional[int] = -1, score: Optional[float] = 1): + """ + Args: + bbox (List[int]): [x_min, y_min, width, height] + track_id: (Optional[int]): track id of the annotation + score (Optional[float]) + """ + self.bbox = bbox + self.track_id = track_id + self.score = score + + +@check_requirements(["norfair"]) +class MotFrame: + def __init__(self, file_name: Optional[str] = None): + self.annotation_list: List[MotAnnotation] = [] + self.file_name = file_name + + def add_annotation(self, detection: MotAnnotation): + if not isinstance(detection, MotAnnotation): + raise TypeError("'detection' should be a MotAnnotation object.") + self.annotation_list.append(detection) + + def to_norfair_detections(self, track_points: str = "bbox"): + """ + Args: + track_points (str): 'centroid' or 'bbox'. Defaults to 'bbox'. + """ + from norfair import Detection + + norfair_detections: List[Detection] = [] + # convert all detections to norfair detections + for annotation in self.annotation_list: + # calculate bbox points + xmin = annotation.bbox[0] + ymin = annotation.bbox[1] + xmax = annotation.bbox[0] + annotation.bbox[2] + ymax = annotation.bbox[1] + annotation.bbox[3] + scores = None + # calculate points as bbox or centroid + if track_points == "bbox": + points = np.array([[xmin, ymin], [xmax, ymax]]) # bbox + if annotation.score is not None: + scores = np.array([annotation.score, annotation.score]) + + elif track_points == "centroid": + points = np.array([(xmin + xmax) / 2, (ymin + ymax) / 2]) # centroid + if annotation.score is not None: + scores = np.array([annotation.score]) + else: + raise ValueError("'track_points' should be one of ['centroid', 'bbox'].") + # create norfair formatted detection + norfair_detections.append(Detection(points=points, scores=scores)) + return norfair_detections + + def to_norfair_trackedobjects(self, track_points: str = "bbox"): + """ + Args: + track_points (str): 'centroid' or 'bbox'. Defaults to 'bbox'. 
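A small sketch of building a single MOT frame with the classes above and converting it to norfair detections; it needs the optional norfair package, and the box values are made up.

frame = MotFrame(file_name="000001.jpg")
frame.add_annotation(MotAnnotation(bbox=[100, 120, 40, 80], score=0.9))   # [x_min, y_min, width, height]
frame.add_annotation(MotAnnotation(bbox=[300, 60, 55, 110], score=0.8))

detections = frame.to_norfair_detections(track_points="bbox")
print(len(detections))   # 2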
+ """ + from norfair import Detection, Tracker + from norfair.tracker import TrackedObject + + tracker = Tracker( + distance_function=euclidean_distance, + distance_threshold=30, + detection_threshold=0, + hit_counter_max=12, + pointwise_hit_counter_max=4, + ) + + tracked_object_list: List[TrackedObject] = [] + # convert all detections to norfair detections + for annotation in self.annotation_list: + # ensure annotation.track_id is not None + if annotation.track_id is None: + raise ValueError("to_norfair_trackedobjects() requires annotation.track_id to be set.") + # calculate bbox points + xmin = annotation.bbox[0] + ymin = annotation.bbox[1] + xmax = annotation.bbox[0] + annotation.bbox[2] + ymax = annotation.bbox[1] + annotation.bbox[3] + track_id = annotation.track_id + scores = None + # calculate points as bbox or centroid + if track_points == "bbox": + points = np.array([[xmin, ymin], [xmax, ymax]]) # bbox + if annotation.score is not None: + scores = np.array([annotation.score, annotation.score]) + + elif track_points == "centroid": + points = np.array([(xmin + xmax) / 2, (ymin + ymax) / 2]) # centroid + if annotation.score is not None: + scores = np.array([annotation.score]) + else: + raise ValueError("'track_points' should be one of ['centroid', 'bbox'].") + # create norfair formatted detection + detection = Detection(points=points, scores=scores) + # create trackedobject from norfair detection + tracked_object = TrackedObject( + detection, + tracker.hit_counter_max, + tracker.initialization_delay, + pointwise_hit_counter_max=tracker.pointwise_hit_counter_max, + detection_threshold=tracker.detection_threshold, + period=1, + filter_factory=tracker.filter_factory, + past_detections_length=0, + ) + tracked_object.id = track_id + tracked_object.point_hit_counter = np.ones(tracked_object.num_points) * 1 + # append to tracked_object_list + tracked_object_list.append(tracked_object) + return tracked_object_list + + +@check_requirements(["norfair"]) +class MotVideo: + def __init__( + self, + name: str = "sequence_name", + frame_rate: Optional[int] = 30, + image_height: int = 720, + image_width: int = 1280, + tracker_kwargs: Optional[Dict] = dict(), + ): + """ + Args + name (str): Name of the video file. + frame_rate (int): FPS of the video. + image_height (int): Frame height of the video. + image_width (int): Frame width of the video. + tracker_kwargs (dict): a dict contains the tracker keys as below: + - max_distance_between_points (int) + - min_detection_threshold (float) + - hit_inertia_min (int) + - hit_inertia_max (int) + - point_transience (int) + For details: https://github.com/tryolabs/norfair/tree/master/docs#arguments + """ + + self.name = name + self.frame_rate = frame_rate + self.image_height = image_height + self.image_width = image_width + self.tracker_kwargs = tracker_kwargs + + self.frame_list: List[MotFrame] = [] + + def _create_info_file(self, seq_length: int, export_dir: str): + """ + Args: + seq_length (int): Number of frames present in video (seqLength parameter in seqinfo.ini) + For details: https://github.com/tryolabs/norfair/issues/42#issuecomment-819211873 + export_dir (str): Folder directory that will contain exported file. 
+ """ + # set file path + filepath = Path(export_dir) / "seqinfo.ini" + # create folder directory if not exists + filepath.parent.mkdir(exist_ok=True) + # create seqinfo.ini file with seqLength + with open(str(filepath), "w") as file: + file.write("[Sequence]\n") + file.write(f"name={self.name}\n") + file.write(f"imDir=img1\n") + file.write(f"frameRate={self.frame_rate}\n") + file.write(f"seqLength={seq_length}\n") + file.write(f"imWidth={self.image_width}\n") + file.write(f"imHeight={self.image_height}") + + def _create_frame_symlinks(self, images_dir: str, export_dir: str): + """ + Args: + images_dir (str): Image directory of source data to be converted. + export_dir (str): Symlink directory that will contain symbolic links + pointing to source image files. + """ + + i = 1 + + img1 = Path(os.path.abspath(export_dir)) / "img1/" + img1.mkdir(parents=True, exist_ok=True) + + for mot_frame in self.frame_list: + if not isinstance(mot_frame.file_name, str): + raise TypeError(f"mot_frame.file_name expected to be string but got: {type(mot_frame.file_name)}") + + if not Path(mot_frame.file_name).suffix: + print(f"image file has no suffix, skipping it: '{mot_frame.file_name}'") + return + elif Path(mot_frame.file_name).suffix not in [".jpg", ".jpeg", ".bmp", ".gif", ".png", ".tiff"]: + print(f"image file has incorrect suffix, skipping it: '{mot_frame.file_name}'") + return + # set source and mot image paths + suffix = Path(mot_frame.file_name).suffix + + if os.path.isabs(mot_frame.file_name): + if not Path(mot_frame.file_name).is_file(): + raise ValueError(f"there is not any image file in path: {str(Path(mot_frame.file_name))}") + source_image_path = str(Path(mot_frame.file_name)) + else: + if not images_dir: + raise ValueError("you have to specify `images_dir` for mot conversion.") + source_image_path_tmp = os.path.abspath(str(Path(images_dir) / mot_frame.file_name)) + if not Path(source_image_path_tmp).is_file(): + raise ValueError(f"there is not any image file in path: {source_image_path_tmp}") + source_image_path = str(Path(source_image_path_tmp)) + + # generate symlink names as indicated at https://arxiv.org/pdf/1603.00831.pdf + frame_link_name = "0" * (6 - len(str(i))) + str(i) + suffix + + mot_image_path = str(Path(export_dir) / Path("img1") / Path(frame_link_name)) + os.symlink(source_image_path, mot_image_path) + i += 1 + + def add_frame(self, frame: MotFrame): + if not isinstance(frame, type(MotFrame())): + raise TypeError("'frame' should be a MotFrame object.") + self.frame_list.append(frame) + + def export( + self, + images_dir: str = None, + export_dir: str = "runs/mot", + type: str = "gt", + use_tracker: bool = None, + exist_ok=False, + ): + """ + Args + export_dir (str): Folder directory that will contain exported mot challenge formatted data. + type (str): Type of the MOT challenge export. 'gt' for groundturth data export, 'det' for detection data export. + use_tracker (bool): Determines whether to apply kalman based tracker over frame detections or not. + Default is True for type='gt'. + It is always False for type='det'. + exist_ok (bool): If True overwrites given directory. 
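A hedged end-to-end sketch of MotVideo.export: two frames of ground-truth boxes are collected and written in MOT-challenge layout (gt.txt plus seqinfo.ini) under the export directory. It requires norfair, and all values below are illustrative.

mot_video = MotVideo(name="sequence_01", frame_rate=30, image_height=720, image_width=1280)

for frame_index in range(2):
    mot_frame = MotFrame()
    # [x_min, y_min, width, height]; track_id would only be needed for use_tracker=False exports
    mot_frame.add_annotation(MotAnnotation(bbox=[100 + 5 * frame_index, 120, 40, 80], track_id=1))
    mot_video.add_frame(mot_frame)

mot_video.export(export_dir="runs/mot", type="gt", use_tracker=True)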
+ """ + from norfair import Detection, Tracker + from norfair.filter import FilterPyKalmanFilterFactory + + if type not in ["gt", "det"]: + raise ValueError(f"'type' can be one of ['gt', 'det'], you provided: {type}") + export_dir: str = str(increment_path(Path(export_dir), exist_ok=exist_ok)) + + if type == "gt": + gt_dir = os.path.join(export_dir, self.name if self.name else "", "gt") + mot_text_file: MotTextFile = MotTextFile(save_dir=gt_dir, save_name="gt") + if not use_tracker: + use_tracker = True + elif type == "det": + det_dir = os.path.join(export_dir, self.name if self.name else "", "det") + mot_text_file: MotTextFile = MotTextFile(save_dir=det_dir, save_name="det") + use_tracker = False + + tracker = Tracker( + distance_function=self.tracker_kwargs.get("distance_function", euclidean_distance), + distance_threshold=self.tracker_kwargs.get("distance_threshold", 50), + hit_counter_max=self.tracker_kwargs.get("hit_counter_max", 1), + initialization_delay=self.tracker_kwargs.get("initialization_delay", 0), + detection_threshold=self.tracker_kwargs.get("detection_threshold", 0), + pointwise_hit_counter_max=self.tracker_kwargs.get("pointwise_hit_counter_max", 4), + filter_factory=self.tracker_kwargs.get("filter_factory", FilterPyKalmanFilterFactory(R=0.2)), + ) + + for mot_frame in self.frame_list: + if use_tracker: + norfair_detections: List[Detection] = mot_frame.to_norfair_detections(track_points="bbox") + tracked_objects = tracker.update(detections=norfair_detections) + else: + tracked_objects = mot_frame.to_norfair_trackedobjects(track_points="bbox") + mot_text_file.update(predictions=tracked_objects) + + if type == "gt": + info_dir = os.path.join(export_dir, self.name if self.name else "") + self._create_info_file(seq_length=mot_text_file.frame_number, export_dir=info_dir) + # create symlinks if mot frames contain file_name + if self.frame_list[0].file_name: + self._create_frame_symlinks(images_dir=images_dir, export_dir=info_dir) + else: + print("skipping frame symlink creation since file_name is not set for mot frames") diff --git a/sahi/utils/shapely.py b/sahi/utils/shapely.py new file mode 100644 index 0000000..8bbc59a --- /dev/null +++ b/sahi/utils/shapely.py @@ -0,0 +1,291 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. 
+ +from typing import List + +from shapely.geometry import CAP_STYLE, JOIN_STYLE, MultiPolygon, Polygon, box + + +def get_shapely_box(x: int, y: int, width: int, height: int) -> Polygon: + """ + Accepts coco style bbox coords and converts it to shapely box object + """ + minx = x + miny = y + maxx = x + width + maxy = y + height + shapely_box = box(minx, miny, maxx, maxy) + + return shapely_box + + +def get_shapely_multipolygon(coco_segmentation: List[List]) -> MultiPolygon: + """ + Accepts coco style polygon coords and converts it to shapely multipolygon object + """ + polygon_list = [] + for coco_polygon in coco_segmentation: + point_list = list(zip(coco_polygon[0::2], coco_polygon[1::2])) + shapely_polygon = Polygon(point_list) + polygon_list.append(shapely_polygon) + shapely_multipolygon = MultiPolygon(polygon_list) + + return shapely_multipolygon + + +def get_bbox_from_shapely(shapely_object): + """ + Accepts shapely box/poly object and returns its bounding box in coco and voc formats + """ + minx, miny, maxx, maxy = shapely_object.bounds + width = maxx - minx + height = maxy - miny + coco_bbox = [minx, miny, width, height] + coco_bbox = [round(point) for point in coco_bbox] if coco_bbox else coco_bbox + voc_bbox = [minx, miny, maxx, maxy] + voc_bbox = [round(point) for point in voc_bbox] if voc_bbox else voc_bbox + + return coco_bbox, voc_bbox + + +class ShapelyAnnotation: + """ + Creates ShapelyAnnotation (as shapely MultiPolygon). + Can convert this instance annotation to various formats. + """ + + @classmethod + def from_coco_segmentation(cls, segmentation, slice_bbox=None): + """ + Init ShapelyAnnotation from coco segmentation. + + segmentation : List[List] + [[1, 1, 325, 125, 250, 200, 5, 200]] + slice_bbox (List[int]): [xmin, ymin, width, height] + Should have the same format as the output of the get_bbox_from_shapely function. + Is used to calculate sliced coco coordinates. + """ + shapely_multipolygon = get_shapely_multipolygon(segmentation) + return cls(multipolygon=shapely_multipolygon, slice_bbox=slice_bbox) + + @classmethod + def from_coco_bbox(cls, bbox: List[int], slice_bbox: List[int] = None): + """ + Init ShapelyAnnotation from coco bbox. + + bbox (List[int]): [xmin, ymin, width, height] + slice_bbox (List[int]): [x_min, y_min, x_max, y_max] Is used + to calculate sliced coco coordinates. + """ + shapely_polygon = get_shapely_box(x=bbox[0], y=bbox[1], width=bbox[2], height=bbox[3]) + shapely_multipolygon = MultiPolygon([shapely_polygon]) + return cls(multipolygon=shapely_multipolygon, slice_bbox=slice_bbox) + + def __init__(self, multipolygon: MultiPolygon, slice_bbox=None): + self.multipolygon = multipolygon + self.slice_bbox = slice_bbox + + @property + def multipolygon(self): + return self.__multipolygon + + @property + def area(self): + return int(self.__area) + + @multipolygon.setter + def multipolygon(self, multipolygon: MultiPolygon): + self.__multipolygon = multipolygon + # calculate areas of all polygons + area = 0 + for shapely_polygon in multipolygon.geoms: + area += shapely_polygon.area + # set instance area + self.__area = area + + def to_list(self): + """ + [ + [(x1, y1), (x2, y2), (x3, y3), ...], + [(x1, y1), (x2, y2), (x3, y3), ...], + ... 
+ ] + """ + list_of_list_of_points: List = [] + for shapely_polygon in self.multipolygon.geoms: + # create list_of_points for selected shapely_polygon + if shapely_polygon.area != 0: + x_coords = shapely_polygon.exterior.coords.xy[0] + y_coords = shapely_polygon.exterior.coords.xy[1] + # fix coord by slice_bbox + if self.slice_bbox: + minx = self.slice_bbox[0] + miny = self.slice_bbox[1] + x_coords = [x_coord - minx for x_coord in x_coords] + y_coords = [y_coord - miny for y_coord in y_coords] + list_of_points = list(zip(x_coords, y_coords)) + else: + list_of_points = [] + # append list_of_points to list_of_list_of_points + list_of_list_of_points.append(list_of_points) + # return result + return list_of_list_of_points + + def to_coco_segmentation(self): + """ + [ + [x1, y1, x2, y2, x3, y3, ...], + [x1, y1, x2, y2, x3, y3, ...], + ... + ] + """ + coco_segmentation: List = [] + for shapely_polygon in self.multipolygon.geoms: + # create list_of_points for selected shapely_polygon + if shapely_polygon.area != 0: + x_coords = shapely_polygon.exterior.coords.xy[0] + y_coords = shapely_polygon.exterior.coords.xy[1] + # fix coord by slice_bbox + if self.slice_bbox: + minx = self.slice_bbox[0] + miny = self.slice_bbox[1] + x_coords = [x_coord - minx for x_coord in x_coords] + y_coords = [y_coord - miny for y_coord in y_coords] + # convert intersection to coco style segmentation annotation + coco_polygon = [None] * len(x_coords) * 2 + coco_polygon[0::2] = [int(coord) for coord in x_coords] + coco_polygon[1::2] = [int(coord) for coord in y_coords] + else: + coco_polygon = [] + # remove if first and last points are duplicate + if coco_polygon[:2] == coco_polygon[-2:]: + del coco_polygon[-2:] + # append coco_polygon to coco_segmentation + coco_polygon = [round(point) for point in coco_polygon] if coco_polygon else coco_polygon + coco_segmentation.append(coco_polygon) + return coco_segmentation + + def to_opencv_contours(self): + """ + [ + [[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]]], + [[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]]] + ] + """ + opencv_contours: List = [] + for shapely_polygon in self.multipolygon.geoms: + # create opencv_contour for selected shapely_polygon + if shapely_polygon.area != 0: + x_coords = shapely_polygon.exterior.coords.xy[0] + y_coords = shapely_polygon.exterior.coords.xy[1] + # fix coord by slice_bbox + if self.slice_bbox: + minx = self.slice_bbox[0] + miny = self.slice_bbox[1] + x_coords = [x_coord - minx for x_coord in x_coords] + y_coords = [y_coord - miny for y_coord in y_coords] + opencv_contour = [[[int(x_coords[ind]), int(y_coords[ind])]] for ind in range(len(x_coords))] + else: + opencv_contour: List = [] + # append opencv_contour to opencv_contours + opencv_contours.append(opencv_contour) + # return result + return opencv_contours + + def to_coco_bbox(self): + """ + [xmin, ymin, width, height] + """ + if self.multipolygon.area != 0: + coco_bbox, _ = get_bbox_from_shapely(self.multipolygon) + # fix coord by slice box + if self.slice_bbox: + minx = round(self.slice_bbox[0]) + miny = round(self.slice_bbox[1]) + coco_bbox[0] = round(coco_bbox[0] - minx) + coco_bbox[1] = round(coco_bbox[1] - miny) + else: + coco_bbox: List = [] + return coco_bbox + + def to_voc_bbox(self): + """ + [xmin, ymin, xmax, ymax] + """ + if self.multipolygon.area != 0: + _, voc_bbox = get_bbox_from_shapely(self.multipolygon) + # fix coord by slice box + if self.slice_bbox: + minx = self.slice_bbox[0] + miny = self.slice_bbox[1] + voc_bbox[0] = round(voc_bbox[0] - minx) + voc_bbox[2] 
= round(voc_bbox[2] - minx) + voc_bbox[1] = round(voc_bbox[1] - miny) + voc_bbox[3] = round(voc_bbox[3] - miny) + else: + voc_bbox = [] + return voc_bbox + + def get_convex_hull_shapely_annotation(self): + shapely_multipolygon = MultiPolygon([self.multipolygon.convex_hull]) + shapely_annotation = ShapelyAnnotation(shapely_multipolygon) + return shapely_annotation + + def get_simplified_shapely_annotation(self, tolerance=1): + shapely_multipolygon = MultiPolygon([self.multipolygon.simplify(tolerance)]) + shapely_annotation = ShapelyAnnotation(shapely_multipolygon) + return shapely_annotation + + def get_buffered_shapely_annotation( + self, + distance=3, + resolution=16, + quadsegs=None, + cap_style=CAP_STYLE.round, + join_style=JOIN_STYLE.round, + mitre_limit=5.0, + single_sided=False, + ): + """ + Approximates the present polygon to have a valid polygon shape. + For more, check: https://shapely.readthedocs.io/en/stable/manual.html#object.buffer + """ + buffered_polygon = self.multipolygon.buffer( + distance=distance, + resolution=resolution, + quadsegs=quadsegs, + cap_style=cap_style, + join_style=join_style, + mitre_limit=mitre_limit, + single_sided=single_sided, + ) + shapely_annotation = ShapelyAnnotation(MultiPolygon([buffered_polygon])) + return shapely_annotation + + def get_intersection(self, polygon: Polygon): + """ + Accepts shapely polygon object and returns the intersection in ShapelyAnnotation format + """ + # convert intersection polygon to list of tuples + intersection = self.multipolygon.intersection(polygon) + # if polygon is box then set slice_box property + if ( + len(polygon.exterior.xy[0]) == 5 + and polygon.exterior.xy[0][0] == polygon.exterior.xy[0][1] + and polygon.exterior.xy[0][2] == polygon.exterior.xy[0][3] + ): + coco_bbox, voc_bbox = get_bbox_from_shapely(polygon) + slice_bbox = coco_bbox + else: + slice_bbox = None + # convert intersection to multipolygon + if intersection.geom_type == "Polygon": + intersection_multipolygon = MultiPolygon([intersection]) + elif intersection.geom_type == "MultiPolygon": + intersection_multipolygon = intersection + else: + intersection_multipolygon = MultiPolygon([]) + # create shapely annotation from intersection multipolygon + intersection_shapely_annotation = ShapelyAnnotation(intersection_multipolygon, slice_bbox) + + return intersection_shapely_annotation diff --git a/sahi/utils/torch.py b/sahi/utils/torch.py new file mode 100644 index 0000000..74a6027 --- /dev/null +++ b/sahi/utils/torch.py @@ -0,0 +1,55 @@ +# OBSS SAHI Tool +# Code written by Fatih C Akyon, 2020. + + +from sahi.utils.import_utils import check_requirements, is_available + + +@check_requirements(["torch"]) +def empty_cuda_cache(): + if is_torch_cuda_available(): + import torch + + return torch.cuda.empty_cache() + else: + raise RuntimeError("CUDA not available.") + + +@check_requirements(["torch"]) +def to_float_tensor(img): + """ + Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W). 
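A small geometry sketch with the shapely helpers above, using made-up coordinates: an annotation box is intersected with a slice window, and the intersection is re-expressed in slice-local coordinates.

annotation = ShapelyAnnotation.from_coco_bbox([100, 100, 200, 150])   # [xmin, ymin, width, height]
slice_polygon = get_shapely_box(x=0, y=0, width=256, height=256)      # slice window as a shapely box

intersection = annotation.get_intersection(slice_polygon)
print(intersection.area)                   # overlap area in pixels
print(intersection.to_coco_bbox())         # bbox shifted into slice-local coordinates
print(intersection.to_coco_segmentation())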
+ Args: + img: np.ndarray + Returns: + torch.tensor + """ + import torch + + img = img.transpose((2, 0, 1)) + img = torch.from_numpy(img).float() + if img.max() > 1: + img /= 255 + + return img + + +@check_requirements(["torch"]) +def torch_to_numpy(img): + import torch + + img = img.numpy() + if img.max() > 1: + img /= 255 + return img.transpose((1, 2, 0)) + + +@check_requirements(["torch"]) +def is_torch_cuda_available(): + if is_available("torch"): + import torch + + return torch.cuda.is_available() + else: + return False diff --git a/sahi/utils/torchvision.py b/sahi/utils/torchvision.py new file mode 100644 index 0000000..29a9567 --- /dev/null +++ b/sahi/utils/torchvision.py @@ -0,0 +1,126 @@ +# OBSS SAHI Tool +# Code written by Kadir Nar, 2022. + + +from packaging import version + +from sahi.utils.import_utils import get_package_info + + +class TorchVisionTestConstants: + FASTERRCNN_CONFIG_PATH = "tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml" + SSD300_CONFIG_PATH = "tests/data/models/torchvision/ssd300_vgg16.yaml" + + +_torchvision_available, _torchvision_version = get_package_info("torchvision", verbose=False) + +if _torchvision_available: + import torchvision + + MODEL_NAME_TO_CONSTRUCTOR = { + "fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn, + "fasterrcnn_mobilenet_v3_large_fpn": torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn, + "fasterrcnn_mobilenet_v3_large_320_fpn": torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn, + "retinanet_resnet50_fpn": torchvision.models.detection.retinanet_resnet50_fpn, + "ssd300_vgg16": torchvision.models.detection.ssd300_vgg16, + "ssdlite320_mobilenet_v3_large": torchvision.models.detection.ssdlite320_mobilenet_v3_large, + } + + # fcos requires torchvision >= 0.12.0 + if version.parse(_torchvision_version) >= version.parse("0.12.0"): + MODEL_NAME_TO_CONSTRUCTOR["fcos_resnet50_fpn"] = (torchvision.models.detection.fcos_resnet50_fpn,) + + +COCO_CLASSES = [ + "__background__", + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "N/A", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "N/A", + "backpack", + "umbrella", + "N/A", + "N/A", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "N/A", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "N/A", + "dining table", + "N/A", + "N/A", + "toilet", + "N/A", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "N/A", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +] diff --git a/sahi/utils/versions.py b/sahi/utils/versions.py new file mode 100644 index 0000000..b0fd7a0 --- /dev/null +++ b/sahi/utils/versions.py @@ -0,0 +1,7 @@ +import sys + +# The package importlib_metadata is in a different place, depending on the python version. 
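A round-trip sketch for the torch helpers above (requires torch): an HWC uint8 image becomes a normalized CHW float tensor and is converted back; the random image stands in for real data.

import numpy as np

image = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)

tensor = to_float_tensor(image)          # shape (3, 480, 640), values scaled into [0, 1]
print(tensor.shape, float(tensor.max()))

restored = torch_to_numpy(tensor)        # back to (480, 640, 3)
print(restored.shape)

print(is_torch_cuda_available())         # False on CPU-only machines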
+if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata diff --git a/sahi/utils/yolov5.py b/sahi/utils/yolov5.py new file mode 100644 index 0000000..9582f11 --- /dev/null +++ b/sahi/utils/yolov5.py @@ -0,0 +1,43 @@ +import urllib.request +from os import path +from pathlib import Path +from typing import Optional + + +class Yolov5TestConstants: + YOLOV5N_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n.pt" + YOLOV5N_MODEL_PATH = "tests/data/models/yolov5/yolov5n.pt" + + YOLOV5S6_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s6.pt" + YOLOV5S6_MODEL_PATH = "tests/data/models/yolov5/yolov5s6.pt" + + YOLOV5M6_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m6.pt" + YOLOV5M6_MODEL_PATH = "tests/data/models/yolov5/yolov5m6.pt" + + +def download_yolov5n_model(destination_path: Optional[str] = None): + + if destination_path is None: + destination_path = Yolov5TestConstants.YOLOV5N_MODEL_PATH + + Path(destination_path).parent.mkdir(parents=True, exist_ok=True) + + if not path.exists(destination_path): + urllib.request.urlretrieve( + Yolov5TestConstants.YOLOV5N_MODEL_URL, + destination_path, + ) + + +def download_yolov5s6_model(destination_path: Optional[str] = None): + + if destination_path is None: + destination_path = Yolov5TestConstants.YOLOV5S6_MODEL_PATH + + Path(destination_path).parent.mkdir(parents=True, exist_ok=True) + + if not path.exists(destination_path): + urllib.request.urlretrieve( + Yolov5TestConstants.YOLOV5S6_MODEL_URL, + destination_path, + ) diff --git a/sahi/yolo6.py b/sahi/yolo6.py new file mode 100644 index 0000000..f0fff5c --- /dev/null +++ b/sahi/yolo6.py @@ -0,0 +1,45 @@ +import numpy as np +import torch +from YOLOv6.yolov6.data.data_augment import letterbox +import math +import cv2 + +def precess_image(img_src, img_size, stride): + '''Process image before image inference.''' + image = letterbox(img_src, img_size, stride=stride)[0] + # Convert + image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + image = torch.from_numpy(np.ascontiguousarray(image)) + image = image.float() # uint8 to fp16/32 + image /= 255 # 0 - 255 to 0.0 - 1.0 + image = image.unsqueeze(0) # add batch dimension + + img_shape = image.shape[2:] + img_src_shape = img_src.shape[:2] + + return image, img_shape, img_src_shape + +def check_img_size(img_size, s=32, floor=0): + """Make sure image size is a multiple of stride s in each dimension, and return a new shape list of image.""" + if isinstance(img_size, int): # integer i.e. img_size=640 + new_size = max(math.ceil(img_size / int(s)) * int(s), floor) + elif isinstance(img_size, list): # list i.e. 
img_size=[640, 480]
+        new_size = [max(math.ceil(x / int(s)) * int(s), floor) for x in img_size]
+    else:
+        raise Exception(f"Unsupported type of img_size: {type(img_size)}")
+
+    if new_size != img_size:
+        print(f'WARNING: --img-size {img_size} must be multiple of max stride {s}, updating to {new_size}')
+    return new_size if isinstance(img_size, list) else [new_size] * 2
+
+
+
+COCO_CLASSES = [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
+                'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+                'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+                'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+                'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+                'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+                'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+                'hair drier', 'toothbrush' ]
\ No newline at end of file
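Finally, a hedged sketch of the YOLOv6 helpers above; it assumes the YOLOv6 repository is importable (as required by the letterbox import at the top of this module), and the image path is a placeholder.

import cv2

img_size = check_img_size([1280, 720], s=32)     # pads each side up to a multiple of the stride -> [1280, 736]
print(img_size)

img_src = cv2.imread("demo/demo_data/sample.jpg")    # hypothetical local image
image, img_shape, img_src_shape = precess_image(img_src, img_size, stride=32)
print(image.shape)                               # torch tensor shaped (1, 3, H, W), values in [0, 1]
print(img_shape, img_src_shape)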