forked from pytorch/serve
-
Notifications
You must be signed in to change notification settings - Fork 0
/
object_detector.py
49 lines (40 loc) · 1.71 KB
/
object_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Module for object detection default handler
"""
import torch
from torchvision import transforms
from torchvision import __version__ as torchvision_version
from packaging import version
from .vision_handler import VisionHandler
from ..utils.util import map_class_to_label
class ObjectDetector(VisionHandler):
    """
    Default TorchServe handler for object detection models.

    Each input image yields a list of detections. Every detection pairs a
    class label with its bounding box and carries the model's confidence
    score. Detections scoring below ``threshold`` are dropped.
    """

    # Pre-processing pipeline: convert the incoming image to a tensor.
    image_processing = transforms.Compose([transforms.ToTensor()])
    # Minimum score (inclusive) a detection needs in order to be reported.
    threshold = 0.5

    def initialize(self, context):
        """
        Initialize the handler via the parent class, then re-place the model
        on a device when running on torchvision older than 0.6.0.

        :param context: TorchServe context object, forwarded to the parent
            ``initialize``.
        """
        super().initialize(context)

        # Torchvision breaks with object detector models before 0.6.0
        # NOTE(review): original indentation was lost in this copy; this
        # assumes the whole device/model setup below lives inside the version
        # guard (the initialized False -> True bracketing suggests so) --
        # confirm against upstream.
        torchvision_too_old = version.parse(torchvision_version) < version.parse("0.6.0")
        if not torchvision_too_old:
            return

        # Mark the handler as not ready while the model is being moved.
        self.initialized = False
        has_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_gpu else "cpu")
        self.model.to(self.device)
        self.model.eval()
        self.initialized = True

    def postprocess(self, data):
        """
        Filter raw detections by score and label them.

        :param data: sequence with one entry per image. Each entry is indexed
            with the keys 'boxes', 'labels' and 'scores'. Each value must support
            boolean-mask indexing and ``.tolist()`` (presumably torch tensors;
            not verified here).
        :return: list with one inner list per image. Each inner list holds one
            dict per kept detection: the first element returned by
            ``map_class_to_label`` for that box/label (presumably
            ``{label_name: box}``), plus a ``'score'`` key.
        """
        # Boolean keep-mask per image; detections at or above threshold survive.
        keep_masks = [detection["scores"] >= self.threshold for detection in data]

        def _select(field):
            # Apply each image's mask to one output field and convert to lists.
            return [
                detection[field][mask].tolist()
                for detection, mask in zip(data, keep_masks)
            ]

        per_image_boxes = _select("boxes")
        per_image_labels = _select("labels")
        per_image_scores = _select("scores")

        def _labelled(box, label_id, score):
            # map_class_to_label expects nested lists; only the first result is used.
            entry = map_class_to_label([[box]], self.mapping, [[label_id]])[0]
            entry["score"] = score
            return entry

        return [
            [
                _labelled(box, label_id, score)
                for label_id, box, score in zip(labels, boxes, scores)
            ]
            for labels, boxes, scores in zip(
                per_image_labels, per_image_boxes, per_image_scores
            )
        ]