Commit 0633363 (1 parent: 1f65084)
add retinanet resnet-FPN backbone + focal loss
Showing 16 changed files with 145,713 additions and 0 deletions.
@@ -0,0 +1,5 @@
# PyTorch-RetinaNet
Train _RetinaNet_ with _Focal Loss_ in PyTorch.

Reference:
[1] [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
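The commit message mentions focal loss, and [1] defines it as FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). As a hedged illustration only (not necessarily this commit's own loss module, which is not shown here), a minimal sigmoid focal-loss sketch with the paper's defaults alpha=0.25, gamma=2 could look like the following; the function name and signature are made up for exposition, and it assumes a PyTorch version that supports reduction='none'.

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    '''Illustrative sketch of FL(p_t) = -alpha_t * (1-p_t)^gamma * log(p_t).

    Args:
      logits: (tensor) raw class scores, sized [N, #classes].
      targets: (tensor) one-hot float labels in {0,1}, sized [N, #classes].
    '''
    p = logits.sigmoid()
    pt = p*targets + (1-p)*(1-targets)               # p_t: probability of the true class
    alpha_t = alpha*targets + (1-alpha)*(1-targets)  # class-balancing weight
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    return (alpha_t * (1-pt).pow(gamma) * ce).sum()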
Large diffs are not rendered by default.
@@ -0,0 +1,20 @@
000001.jpg 48 240 195 371 11 8 12 352 498 14
000002.jpg 139 200 207 301 18
000003.jpg 123 155 215 195 17 239 156 307 205 8
000004.jpg 13 311 84 362 6 362 330 500 389 6 235 328 334 375 6 175 327 252 364 6 139 320 189 359 6 108 325 150 353 6 84 323 121 350 6
000005.jpg 263 211 324 339 8 165 264 253 372 8 5 244 67 374 8 241 194 295 299 8 277 186 312 220 8
000006.jpg 187 135 282 242 15 154 209 369 375 10 255 207 366 375 8 298 195 332 247 8 279 190 308 231 8 137 192 151 199 8 137 198 156 212 8 138 211 249 375 8
000007.jpg 141 50 500 330 6
000008.jpg 192 16 364 249 8
000009.jpg 69 172 270 330 12 150 141 229 284 14 285 201 327 331 14 258 198 297 329 14
000010.jpg 87 97 258 427 12 133 72 245 284 14
000011.jpg 126 51 330 308 7
000012.jpg 156 97 351 270 6
000013.jpg 299 160 446 252 9
000014.jpg 72 163 302 228 5 185 194 500 316 6 416 180 500 222 6 314 8 344 65 14 331 4 361 61 14 357 8 401 61 14 163 197 267 244 6
000015.jpg 77 136 360 358 1
000016.jpg 92 72 305 473 1
000017.jpg 185 62 279 199 14 90 78 403 336 12
000018.jpg 31 30 358 279 11
000019.jpg 231 88 483 256 7 11 113 266 259 7
000020.jpg 33 148 371 416 6
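Each line above is an image filename followed by one or more groups of five numbers: xmin ymin xmax ymax label. As a small illustration only (the real parsing lives in the dataset class further below, and the variable names here are just for exposition), one line splits into boxes and labels like this:

line = '000002.jpg 139 200 207 301 18'
tokens = line.split()
fname, vals = tokens[0], tokens[1:]
boxes = [[float(v) for v in vals[i:i+4]] for i in range(0, len(vals), 5)]
labels = [int(vals[i+4]) for i in range(0, len(vals), 5)]
print(fname, boxes, labels)  # 000002.jpg [[139.0, 200.0, 207.0, 301.0]] [18]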
Large diffs are not rendered by default.
@@ -0,0 +1,148 @@
'''Load image/labels/boxes from an annotation file.

The list file looks like:
    img.jpg xmin ymin xmax ymax label xmin ymin xmax ymax label ...
'''
from __future__ import print_function

import os
import sys
import random

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image
from encoder import DataEncoder
from transform import resize, random_flip, random_crop, center_crop


class ListDataset(data.Dataset):
    def __init__(self, root, list_file, train, transform, input_size):
        '''
        Args:
          root: (str) directory of the images.
          list_file: (str) path to the index file.
          train: (boolean) train or test.
          transform: ([transforms]) image transforms.
          input_size: (int) model input size.
        '''
        self.root = root
        self.train = train
        self.transform = transform
        self.input_size = input_size

        self.fnames = []
        self.boxes = []
        self.labels = []

        self.encoder = DataEncoder()

        with open(list_file) as f:
            lines = f.readlines()
            self.num_samples = len(lines)

        for line in lines:
            splited = line.strip().split()
            self.fnames.append(splited[0])
            num_boxes = (len(splited) - 1) // 5
            box = []
            label = []
            for i in range(num_boxes):
                xmin = splited[1+5*i]
                ymin = splited[2+5*i]
                xmax = splited[3+5*i]
                ymax = splited[4+5*i]
                c = splited[5+5*i]
                box.append([float(xmin), float(ymin), float(xmax), float(ymax)])
                label.append(int(c))
            self.boxes.append(torch.Tensor(box))
            self.labels.append(torch.LongTensor(label))

    def __getitem__(self, idx):
        '''Load an image and its boxes, and apply data augmentation.

        Args:
          idx: (int) image index.

        Returns:
          img: (tensor) image tensor.
          boxes: (tensor) bounding boxes, sized [#obj, 4].
          labels: (tensor) class labels, sized [#obj,].
        '''
        # Load image and boxes.
        fname = self.fnames[idx]
        img = Image.open(os.path.join(self.root, fname))
        if img.mode != 'RGB':
            img = img.convert('RGB')

        boxes = self.boxes[idx].clone()
        labels = self.labels[idx]
        size = self.input_size

        # Data augmentation.
        if self.train:
            img, boxes = random_flip(img, boxes)
            img, boxes = random_crop(img, boxes)
            img, boxes = resize(img, boxes, (size, size))
        else:
            img, boxes = resize(img, boxes, size)
            img, boxes = center_crop(img, boxes, (size, size))

        img = self.transform(img)
        return img, boxes, labels

    def collate_fn(self, batch):
        '''Pad images and encode targets.

        Since images may be of different sizes, we pad them to the same size
        and encode the per-image boxes/labels into per-anchor targets.

        Args:
          batch: (list) of (img, boxes, labels) tuples from __getitem__.

        Returns:
          padded images, stacked loc_targets, stacked cls_targets.
        '''
        imgs = [x[0] for x in batch]
        boxes = [x[1] for x in batch]
        labels = [x[2] for x in batch]

        h = w = self.input_size
        num_imgs = len(imgs)
        inputs = torch.zeros(num_imgs, 3, h, w)

        loc_targets = []
        cls_targets = []
        for i in range(num_imgs):
            inputs[i] = imgs[i]
            loc_target, cls_target = self.encoder.encode(boxes[i], labels[i], input_size=(w, h))
            loc_targets.append(loc_target)
            cls_targets.append(cls_target)
        return inputs, torch.stack(loc_targets), torch.stack(cls_targets)

    def __len__(self):
        return self.num_samples


def test():
    import torchvision

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    dataset = ListDataset(root='/mnt/hgfs/D/download/PASCAL_VOC/voc_all_images',
                          list_file='./data/voc12_train.txt', train=True, transform=transform, input_size=600)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=1, collate_fn=dataset.collate_fn)

    for images, loc_targets, cls_targets in dataloader:
        print(images.size())
        print(loc_targets.size())
        print(cls_targets.size())
        grid = torchvision.utils.make_grid(images, 1)
        torchvision.utils.save_image(grid, 'a.jpg')
        break

# test()
@@ -0,0 +1,125 @@
'''Encode object boxes and labels.'''
import math
import torch

from utils import meshgrid, box_iou, box_nms, change_box_order


class DataEncoder:
    def __init__(self):
        self.anchor_areas = [32*32., 64*64., 128*128., 256*256., 512*512.]  # p3 -> p7
        self.aspect_ratios = [1/2., 1/1., 2/1.]
        self.scale_ratios = [1., pow(2, 1/3.), pow(2, 2/3.)]
        self.anchor_wh = self._get_anchor_wh()

    def _get_anchor_wh(self):
        '''Compute anchor width and height for each feature map.

        Returns:
          anchor_wh: (tensor) anchor wh, sized [#fm, #anchors_per_cell, 2].
        '''
        anchor_wh = []
        for s in self.anchor_areas:
            for ar in self.aspect_ratios:  # w/h = ar
                h = math.sqrt(s/ar)
                w = ar * h
                for sr in self.scale_ratios:  # scale
                    anchor_h = h*sr
                    anchor_w = w*sr
                    anchor_wh.append([anchor_w, anchor_h])
        num_fms = len(self.anchor_areas)
        return torch.Tensor(anchor_wh).view(num_fms, -1, 2)

    def _get_anchor_boxes(self, input_size):
        '''Compute anchor boxes for each feature map.

        Args:
          input_size: (tensor) model input size of (w,h).

        Returns:
          boxes: (tensor) anchor boxes for all feature maps, concatenated into [#anchors, 4],
                 where #anchors = sum over feature maps of fm_w * fm_h * #anchors_per_cell.
        '''
        num_fms = len(self.anchor_areas)
        fm_sizes = [(input_size/pow(2., i+3)).ceil() for i in range(num_fms)]  # p3 -> p7 feature map sizes

        boxes = []
        for i in range(num_fms):
            fm_size = fm_sizes[i]
            grid_size = input_size / fm_size
            fm_w, fm_h = int(fm_size[0]), int(fm_size[1])
            xy = meshgrid(fm_w, fm_h) + 0.5  # [fm_h*fm_w, 2]
            xy = (xy*grid_size).view(fm_h, fm_w, 1, 2).expand(fm_h, fm_w, 9, 2)
            wh = self.anchor_wh[i].view(1, 1, 9, 2).expand(fm_h, fm_w, 9, 2)
            box = torch.cat([xy, wh], 3)  # [x,y,w,h]
            boxes.append(box.view(-1, 4))
        return torch.cat(boxes, 0)

    def encode(self, boxes, labels, input_size):
        '''Encode target bounding boxes and class labels.

        We obey the Faster RCNN box coder:
          tx = (x - anchor_x) / anchor_w
          ty = (y - anchor_y) / anchor_h
          tw = log(w / anchor_w)
          th = log(h / anchor_h)

        Args:
          boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [#obj, 4].
          labels: (tensor) object class labels, sized [#obj,].
          input_size: (int/tuple) model input size of (w,h).

        Returns:
          loc_targets: (tensor) encoded bounding boxes, sized [#anchors, 4].
          cls_targets: (tensor) encoded class labels, sized [#anchors,].
        '''
        input_size = torch.Tensor([input_size, input_size]) if isinstance(input_size, int) \
                     else torch.Tensor(input_size)
        anchor_boxes = self._get_anchor_boxes(input_size)
        boxes = change_box_order(boxes, 'xyxy2xywh')

        ious = box_iou(anchor_boxes, boxes, order='xywh')
        max_ious, max_ids = ious.max(1)
        boxes = boxes[max_ids]

        loc_xy = (boxes[:, :2]-anchor_boxes[:, :2]) / anchor_boxes[:, 2:]
        loc_wh = torch.log(boxes[:, 2:]/anchor_boxes[:, 2:])
        loc_targets = torch.cat([loc_xy, loc_wh], 1)
        cls_targets = 1 + labels[max_ids]

        cls_targets[max_ious < 0.5] = 0
        ignore = (max_ious > 0.4) & (max_ious < 0.5)  # ignore ious between [0.4, 0.5]
        cls_targets[ignore] = -1  # for now just mark ignored to -1
        return loc_targets, cls_targets

    def decode(self, loc_preds, cls_preds, input_size):
        '''Decode outputs back to bounding box locations and class labels.

        Args:
          loc_preds: (tensor) predicted locations, sized [#anchors, 4].
          cls_preds: (tensor) predicted class labels, sized [#anchors, #classes].
          input_size: (int/tuple) model input size of (w,h).

        Returns:
          boxes: (tensor) decoded box locations, sized [#obj, 4].
          labels: (tensor) class labels for each box, sized [#obj,].
        '''
        CLS_THRESH = 0.5
        NMS_THRESH = 0.5

        input_size = torch.Tensor([input_size, input_size]) if isinstance(input_size, int) \
                     else torch.Tensor(input_size)
        anchor_boxes = self._get_anchor_boxes(input_size)

        loc_xy = loc_preds[:, :2]
        loc_wh = loc_preds[:, 2:]

        xy = loc_xy * anchor_boxes[:, 2:] + anchor_boxes[:, :2]
        wh = loc_wh.exp() * anchor_boxes[:, 2:]
        boxes = torch.cat([xy-wh/2, xy+wh/2], 1)  # [#anchors, 4]

        score, labels = cls_preds.sigmoid().max(1)  # [#anchors,]
        ids = score > CLS_THRESH
        ids = ids.nonzero().squeeze()  # [#obj,]
        keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)
        return boxes[ids][keep], labels[ids][keep]
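As a hedged usage sketch (not part of this commit), the encoder can be exercised round-trip for a 600x600 input, assuming encoder.py and its utils helpers are importable from the repository root; the ground-truth box and the random class logits below are made up for illustration.

import torch
from encoder import DataEncoder

encoder = DataEncoder()
boxes = torch.Tensor([[48., 240., 195., 371.]])  # one ground-truth box in xyxy order
labels = torch.LongTensor([11])

# Encode: one loc/cls target per anchor across the p3-p7 feature maps.
loc_targets, cls_targets = encoder.encode(boxes, labels, input_size=600)
print(loc_targets.size(), cls_targets.size())  # [#anchors, 4], [#anchors]

# Decode expects per-anchor class logits, e.g. from the RetinaNet head;
# random scores here only demonstrate the call signature.
cls_preds = torch.randn(cls_targets.size(0), 20)
det_boxes, det_labels = encoder.decode(loc_targets, cls_preds, input_size=600)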