add retinanet resnet-FPN backbone + focal loss
deisler134 committed Nov 15, 2018
1 parent 1f65084 commit 0633363
Showing 16 changed files with 145,713 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytorch-retinanet/README.md
@@ -0,0 +1,5 @@
# PyTorch-RetinaNet
Train _RetinaNet_ with _Focal Loss_ in PyTorch.

Reference:
[1] [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
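
For reference, here is a minimal sketch of the focal loss from [1], written to match the label convention used by `encoder.py` below (class indices shifted by 1 so that 0 is background and -1 marks ignored anchors). This is an illustrative snippet, not the loss implementation from this commit (the loss file is not shown in this diff), and `focal_loss` is a hypothetical name:

```python
# Illustrative focal-loss sketch (hypothetical; not the implementation in this commit).
# FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), from [1].
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, num_classes, alpha=0.25, gamma=2.0):
    # logits:  [N, num_classes] raw class predictions
    # targets: [N] class indices in [0, num_classes)
    t = F.one_hot(targets, num_classes).float()
    p = logits.sigmoid()
    pt = p * t + (1 - p) * (1 - t)          # p_t: p where t==1, else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)   # alpha_t
    w = w * (1 - pt).pow(gamma)             # focal modulating factor
    return F.binary_cross_entropy_with_logits(logits, t, weight=w, reduction='sum')
```

Before applying something like this to the encoder's cls_targets, one would mask out anchors marked -1 and treat class 0 as background.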
117,266 changes: 117,266 additions & 0 deletions pytorch-retinanet/data/coco17_train.txt

Large diffs are not rendered by default.

4,952 changes: 4,952 additions & 0 deletions pytorch-retinanet/data/coco17_val.txt

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions pytorch-retinanet/data/test.txt
@@ -0,0 +1,20 @@
000001.jpg 48 240 195 371 11 8 12 352 498 14
000002.jpg 139 200 207 301 18
000003.jpg 123 155 215 195 17 239 156 307 205 8
000004.jpg 13 311 84 362 6 362 330 500 389 6 235 328 334 375 6 175 327 252 364 6 139 320 189 359 6 108 325 150 353 6 84 323 121 350 6
000005.jpg 263 211 324 339 8 165 264 253 372 8 5 244 67 374 8 241 194 295 299 8 277 186 312 220 8
000006.jpg 187 135 282 242 15 154 209 369 375 10 255 207 366 375 8 298 195 332 247 8 279 190 308 231 8 137 192 151 199 8 137 198 156 212 8 138 211 249 375 8
000007.jpg 141 50 500 330 6
000008.jpg 192 16 364 249 8
000009.jpg 69 172 270 330 12 150 141 229 284 14 285 201 327 331 14 258 198 297 329 14
000010.jpg 87 97 258 427 12 133 72 245 284 14
000011.jpg 126 51 330 308 7
000012.jpg 156 97 351 270 6
000013.jpg 299 160 446 252 9
000014.jpg 72 163 302 228 5 185 194 500 316 6 416 180 500 222 6 314 8 344 65 14 331 4 361 61 14 357 8 401 61 14 163 197 267 244 6
000015.jpg 77 136 360 358 1
000016.jpg 92 72 305 473 1
000017.jpg 185 62 279 199 14 90 78 403 336 12
000018.jpg 31 30 358 279 11
000019.jpg 231 88 483 256 7 11 113 266 259 7
000020.jpg 33 148 371 416 6
17,125 changes: 17,125 additions & 0 deletions pytorch-retinanet/data/voc12_train.txt

Large diffs are not rendered by default.

5,138 changes: 5,138 additions & 0 deletions pytorch-retinanet/data/voc12_val.txt

Large diffs are not rendered by default.

148 changes: 148 additions & 0 deletions pytorch-retinanet/datagen.py
@@ -0,0 +1,148 @@
'''Load image/labels/boxes from an annotation file.
Each line of the list file has the form:
img.jpg xmin ymin xmax ymax label [xmin ymin xmax ymax label ...]
'''
from __future__ import print_function

import os
import sys
import random

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image
from encoder import DataEncoder
from transform import resize, random_flip, random_crop, center_crop


class ListDataset(data.Dataset):
def __init__(self, root, list_file, train, transform, input_size):
'''
Args:
root: (str) directory to images.
list_file: (str) path to index file.
train: (boolean) train or test.
transform: ([transforms]) image transforms.
input_size: (int) model input size.
'''
self.root = root
self.train = train
self.transform = transform
self.input_size = input_size

self.fnames = []
self.boxes = []
self.labels = []

self.encoder = DataEncoder()

with open(list_file) as f:
lines = f.readlines()
self.num_samples = len(lines)

for line in lines:
tokens = line.strip().split()
self.fnames.append(tokens[0])
num_boxes = (len(tokens) - 1) // 5
box = []
label = []
for i in range(num_boxes):
xmin = tokens[1+5*i]
ymin = tokens[2+5*i]
xmax = tokens[3+5*i]
ymax = tokens[4+5*i]
c = tokens[5+5*i]
box.append([float(xmin),float(ymin),float(xmax),float(ymax)])
label.append(int(c))
self.boxes.append(torch.Tensor(box))
self.labels.append(torch.LongTensor(label))

def __getitem__(self, idx):
'''Load image.
Args:
idx: (int) image index.
Returns:
img: (tensor) transformed image tensor.
boxes: (tensor) object boxes after augmentation, sized [#obj, 4].
labels: (tensor) class labels, sized [#obj,].
'''
# Load image and boxes.
fname = self.fnames[idx]
img = Image.open(os.path.join(self.root, fname))
if img.mode != 'RGB':
img = img.convert('RGB')

boxes = self.boxes[idx].clone()
labels = self.labels[idx]
size = self.input_size

# Data augmentation.
if self.train:
img, boxes = random_flip(img, boxes)
img, boxes = random_crop(img, boxes)
img, boxes = resize(img, boxes, (size,size))
else:
img, boxes = resize(img, boxes, size)
img, boxes = center_crop(img, boxes, (size,size))

img = self.transform(img)
return img, boxes, labels

def collate_fn(self, batch):
'''Pad images and encode targets.
Since images in a batch can be of different sizes, we pad them all to the same input size.
Args:
batch: (list) of (image, boxes, labels) tuples from __getitem__.
Returns:
padded images, stacked loc_targets, stacked cls_targets.
'''
imgs = [x[0] for x in batch]
boxes = [x[1] for x in batch]
labels = [x[2] for x in batch]

h = w = self.input_size
num_imgs = len(imgs)
inputs = torch.zeros(num_imgs, 3, h, w)
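# every image was already resized/cropped to (input_size, input_size) in __getitem__, so a fixed-size buffer fits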

loc_targets = []
cls_targets = []
for i in range(num_imgs):
inputs[i] = imgs[i]
loc_target, cls_target = self.encoder.encode(boxes[i], labels[i], input_size=(w,h))
loc_targets.append(loc_target)
cls_targets.append(cls_target)
return inputs, torch.stack(loc_targets), torch.stack(cls_targets)

def __len__(self):
return self.num_samples


def test():
import torchvision

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
dataset = ListDataset(root='/mnt/hgfs/D/download/PASCAL_VOC/voc_all_images',
list_file='./data/voc12_train.txt', train=True, transform=transform, input_size=600)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=1, collate_fn=dataset.collate_fn)

for images, loc_targets, cls_targets in dataloader:
print(images.size())
print(loc_targets.size())
print(cls_targets.size())
grid = torchvision.utils.make_grid(images, 1)
torchvision.utils.save_image(grid, 'a.jpg')
break

# test()
125 changes: 125 additions & 0 deletions pytorch-retinanet/encoder.py
@@ -0,0 +1,125 @@
'''Encode object boxes and labels.'''
import math
import torch

from utils import meshgrid, box_iou, box_nms, change_box_order


class DataEncoder:
def __init__(self):
self.anchor_areas = [32*32., 64*64., 128*128., 256*256., 512*512.] # p3 -> p7
self.aspect_ratios = [1/2., 1/1., 2/1.]
self.scale_ratios = [1., pow(2,1/3.), pow(2,2/3.)]
self.anchor_wh = self._get_anchor_wh()

def _get_anchor_wh(self):
'''Compute anchor width and height for each feature map.
Returns:
anchor_wh: (tensor) anchor wh, sized [#fm, #anchors_per_cell, 2].
'''
anchor_wh = []
for s in self.anchor_areas:
for ar in self.aspect_ratios: # w/h = ar
h = math.sqrt(s/ar)
w = ar * h
for sr in self.scale_ratios: # scale
anchor_h = h*sr
anchor_w = w*sr
anchor_wh.append([anchor_w, anchor_h])
num_fms = len(self.anchor_areas)
return torch.Tensor(anchor_wh).view(num_fms, -1, 2)

def _get_anchor_boxes(self, input_size):
'''Compute anchor boxes for each feature map.
Args:
input_size: (tensor) model input size of (w,h).
Returns:
boxes: (tensor) anchor boxes for all feature maps, concatenated into size [#anchors, 4],
where #anchors = sum over feature maps of fm_w * fm_h * #anchors_per_cell.
'''
num_fms = len(self.anchor_areas)
fm_sizes = [(input_size/pow(2.,i+3)).ceil() for i in range(num_fms)] # p3 -> p7 feature map sizes

boxes = []
for i in range(num_fms):
fm_size = fm_sizes[i]
grid_size = input_size / fm_size
fm_w, fm_h = int(fm_size[0]), int(fm_size[1])
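# 9 anchors per cell: 3 aspect ratios x 3 scale ratios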
xy = meshgrid(fm_w,fm_h) + 0.5 # [fm_h*fm_w, 2]
xy = (xy*grid_size).view(fm_h,fm_w,1,2).expand(fm_h,fm_w,9,2)
wh = self.anchor_wh[i].view(1,1,9,2).expand(fm_h,fm_w,9,2)
box = torch.cat([xy,wh], 3) # [x,y,w,h]
boxes.append(box.view(-1,4))
return torch.cat(boxes, 0)

def encode(self, boxes, labels, input_size):
'''Encode target bounding boxes and class labels.
We follow the Faster R-CNN box coder:
tx = (x - anchor_x) / anchor_w
ty = (y - anchor_y) / anchor_h
tw = log(w / anchor_w)
th = log(h / anchor_h)
Args:
boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [#obj, 4].
labels: (tensor) object class labels, sized [#obj,].
input_size: (int/tuple) model input size of (w,h).
Returns:
loc_targets: (tensor) encoded bounding boxes, sized [#anchors,4].
cls_targets: (tensor) encoded class labels, sized [#anchors,].
'''
input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) \
else torch.Tensor(input_size)
anchor_boxes = self._get_anchor_boxes(input_size)
boxes = change_box_order(boxes, 'xyxy2xywh')

ious = box_iou(anchor_boxes, boxes, order='xywh')
max_ious, max_ids = ious.max(1)
boxes = boxes[max_ids]

loc_xy = (boxes[:,:2]-anchor_boxes[:,:2]) / anchor_boxes[:,2:]
loc_wh = torch.log(boxes[:,2:]/anchor_boxes[:,2:])
loc_targets = torch.cat([loc_xy,loc_wh], 1)
cls_targets = 1 + labels[max_ids]
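# shift class indices up by 1 so that 0 can denote background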

cls_targets[max_ious<0.5] = 0
ignore = (max_ious>0.4) & (max_ious<0.5)  # ignore anchors with iou in (0.4, 0.5)
cls_targets[ignore] = -1  # mark ignored anchors with -1
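# the loss is presumably expected to skip anchors marked -1 (the loss file is not shown in this diff)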
return loc_targets, cls_targets

def decode(self, loc_preds, cls_preds, input_size):
'''Decode outputs back to bounding box locations and class labels.
Args:
loc_preds: (tensor) predicted locations, sized [#anchors, 4].
cls_preds: (tensor) predicted class labels, sized [#anchors, #classes].
input_size: (int/tuple) model input size of (w,h).
Returns:
boxes: (tensor) decoded box locations, sized [#obj,4].
labels: (tensor) class labels for each box, sized [#obj,].
'''
CLS_THRESH = 0.5
NMS_THRESH = 0.5

input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) \
else torch.Tensor(input_size)
anchor_boxes = self._get_anchor_boxes(input_size)

loc_xy = loc_preds[:,:2]
loc_wh = loc_preds[:,2:]

xy = loc_xy * anchor_boxes[:,2:] + anchor_boxes[:,:2]
wh = loc_wh.exp() * anchor_boxes[:,2:]
boxes = torch.cat([xy-wh/2, xy+wh/2], 1) # [#anchors,4]

score, labels = cls_preds.sigmoid().max(1) # [#anchors,]
ids = score > CLS_THRESH
ids = ids.nonzero().squeeze() # [#obj,]
keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)
return boxes[ids][keep], labels[ids][keep]
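
As a quick sanity check, the encoder can be exercised on a single box. This is a hypothetical snippet, not part of this commit; it assumes it is run from the repo root so that encoder.py is importable:

```python
# Hypothetical smoke test for DataEncoder (not part of this commit).
import torch
from encoder import DataEncoder

encoder = DataEncoder()
boxes = torch.Tensor([[100., 100., 300., 400.]])  # one box in (xmin, ymin, xmax, ymax)
labels = torch.LongTensor([7])
loc_targets, cls_targets = encoder.encode(boxes, labels, input_size=600)
print(loc_targets.size(), cls_targets.size())  # [#anchors, 4] and [#anchors]
# Anchors with IoU >= 0.5 against the box get class 7+1; the rest are 0 (background) or -1 (ignored).
```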
