forked from pytorch/serve
-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_segmenter.py
32 lines (28 loc) · 1.01 KB
/
image_segmenter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
Module for image segmentation default handler
"""
import torch
from torchvision import transforms as T
from .vision_handler import VisionHandler
class ImageSegmenter(VisionHandler):
    """
    ImageSegmenter handler class. This handler takes a batch of images and,
    for every pixel, returns the most likely class together with its
    softmax probability.

    The response has shape [N, H, W, 2], where N - batch size, H - height
    and W - width. The last axis holds (class_index, probability); the class
    index is stored as a float so both values share a single dtype.
    """

    # Per-image preprocessing: resize the shorter side to 256 px, take a
    # 224x224 center crop, convert to a tensor and normalize each channel
    # with the given mean/std (the standard ImageNet statistics).
    image_processing = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )])

    def postprocess(self, data):
        """
        Reduce raw segmentation logits to per-pixel (class, probability) pairs.

        Args:
            data: Model output. Either a mapping that holds the logits under
                the ``'out'`` key (e.g. torchvision segmentation models) or
                the logits tensor itself. The logits are expected to be
                shaped [N, K, H, W], with classes on dim 1.

        Returns:
            A nested Python list of shape [N, H, W, 2]: for every pixel, the
            index of the highest-probability class (as a float) followed by
            its softmax probability.
        """
        # Returning the score of every class for every pixel makes the
        # response size too big (> 24mb). Instead, only the top class for
        # each pixel (and its probability) is returned.
        logits = data if isinstance(data, torch.Tensor) else data['out']
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        top = torch.max(probabilities, dim=1)
        # Cast the integer indices to the probability dtype so both can be
        # stacked into one tensor along a new trailing axis.
        class_and_prob = torch.stack(
            [top.indices.type(top.values.dtype), top.values], dim=-1
        )
        return class_and_prob.tolist()