Skip to content

Add Deformable ConvNets #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deformable_convnets/mnist/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

45 changes: 45 additions & 0 deletions deformable_convnets/mnist/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import chainer
import chainer.functions as F
import chainer.links as L


class Convnet(chainer.Chain):
    """Plain three-layer CNN feature extractor with a linear classifier head.

    After every forward pass the last convolutional feature map is kept on
    ``self.feat`` so callers can inspect it (e.g. for gradient
    visualization in ``test_mnist.py``).
    """

    def __init__(self, n_out):
        # Passing ``None`` as the input channel count lets chainer infer
        # the size of the inputs to each layer on the first forward pass.
        super(Convnet, self).__init__(
            l1=L.Convolution2D(None, 32, 3, 1, 1),
            l2=L.Convolution2D(None, 32, 3, 1, 1),
            l3=L.Convolution2D(None, 32, 3, 1, 1),
            fc=L.Linear(None, n_out),
        )

    def __call__(self, x):
        batch_size = x.shape[0]
        h = F.relu(self.l1(x))  # spatial size unchanged (28 for MNIST)
        h = F.relu(self.l2(h))
        # Overlapping 4x4 pooling with stride 2 halves the resolution.
        h = F.max_pooling_2d(h, ksize=4, stride=2, pad=1)
        h = F.relu(self.l3(h))
        self.feat = h  # retained for later introspection
        return self.fc(h.reshape(batch_size, -1))


class DeformableConvnet(chainer.Chain):
    """Three-layer CNN whose last two convolutions are deformable.

    Mirrors :class:`Convnet` exactly, except ``l2`` and ``l3`` are
    ``DeformableConvolution2D`` links. The last feature map is kept on
    ``self.feat`` after each forward pass for later inspection.
    """

    def __init__(self, n_out):
        # ``None`` input channels are inferred lazily by chainer.
        super(DeformableConvnet, self).__init__(
            l1=L.Convolution2D(None, 32, 3, 1, 1),
            l2=L.DeformableConvolution2D(None, 32, 3, 1, 1),
            l3=L.DeformableConvolution2D(None, 32, 3, 1, 1),
            fc=L.Linear(None, n_out),
        )

    def __call__(self, x):
        batch_size = x.shape[0]
        h = F.relu(self.l1(x))  # spatial size unchanged (28 for MNIST)
        h = F.relu(self.l2(h))
        # Overlapping 4x4 pooling with stride 2 halves the resolution.
        h = F.max_pooling_2d(h, ksize=4, stride=2, pad=1)
        h = F.relu(self.l3(h))
        self.feat = h  # retained for later introspection
        return self.fc(h.reshape(batch_size, -1))
8 changes: 8 additions & 0 deletions deformable_convnets/mnist/scale_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from chainercv import transforms


def transform(in_data):
    """Apply random scale augmentation to one ``(image, label)`` pair.

    The image is randomly placed on an expanded canvas (up to 3x its
    size) and then resized back to 28x28, which shrinks and shifts the
    digit — the scale variation the deformable layers should handle.
    """
    img, label = in_data
    expanded = transforms.random_expand(img, max_ratio=3)
    resized = transforms.resize(expanded, (28, 28))
    return resized, label
67 changes: 67 additions & 0 deletions deformable_convnets/mnist/test_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import argparse
import matplotlib.pyplot as plt
import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from chainer.dataset.convert import concat_examples

from chainercv.datasets import TransformDataset
from chainercv import transforms

from models import Convnet
from models import DeformableConvnet
from scale_transform import transform


def main():
    """Visualize which input pixels the center feature unit responds to.

    Loads a trained (deformable) convnet snapshot, runs one scale-augmented
    test image through it, backpropagates from the spatial center of the
    last feature map, and overlays the input-gradient magnitude on the
    image as a scatter of red points.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--resume', '-r', default='result/model_iter_9',
                        help='Resume the training from snapshot')
    parser.add_argument('--deformable', '-d', type=int, default=1,
                        help='use deformable convolutions')
    args = parser.parse_args()

    if args.deformable == 1:
        model = DeformableConvnet(10)
    else:
        model = Convnet(10)
    chainer.serializers.load_npz(args.resume, model)

    train, test = chainer.datasets.get_mnist(ndim=3)
    test = TransformDataset(test, transform)

    test_iter = chainer.iterators.SerialIterator(test, batch_size=1,
                                                 repeat=False, shuffle=False)

    threshold = 1
    for _ in range(1):
        batch = test_iter.next()
        in_arrays = concat_examples(batch, device=None)
        in_vars = tuple(chainer.Variable(x) for x in in_arrays)
        img, label = in_vars
        model(img)
        feat = model.feat
        H, W = feat.shape[2:]
        # BUGFIX: use floor division -- ``H / 2`` is a float under
        # Python 3, and float array indices raise TypeError.
        center = F.sum(feat[:, :, H // 2, W // 2])
        # Seed the backward pass with a gradient of ones at ``center``.
        center.grad = np.ones_like(center.data)
        model.zerograds()
        img.zerograd()
        center.backward(retain_grad=True)

        img_grad = img.grad[0]  # (1, 28, 28)

        # Normalize absolute gradients to [0, 255] before thresholding.
        img_grad_abs = (np.abs(img_grad) / np.max(np.abs(img_grad)) * 255)[0]
        # Guard against an all-zero gradient producing NaNs from 0/0.
        img_grad_abs[np.isnan(img_grad_abs)] = 0
        y_indices, x_indices = np.where(img_grad_abs > threshold)
        plt.scatter(x_indices, y_indices, c='red')
        vis_img = transforms.chw_to_pil_image(255 * img.data[0])[:, :, 0]
        plt.imshow(vis_img, interpolation='nearest', cmap='gray')
        plt.show()


if __name__ == '__main__':
    main()
109 changes: 109 additions & 0 deletions deformable_convnets/mnist/train_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import argparse

import chainer
import chainer.links as L
from chainer import training
from chainer.training import extensions

from chainercv.datasets import TransformDataset
from chainercv import transforms

from models import Convnet
from models import DeformableConvnet
from scale_transform import transform


def main():
    """Train a (deformable) convnet classifier on scale-augmented MNIST.

    Builds either :class:`DeformableConvnet` or :class:`Convnet` depending
    on ``--deformable``, wraps it in ``L.Classifier``, and runs a standard
    chainer trainer loop with per-epoch evaluation, logging, plots, and
    model snapshots.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--deformable', '-d', type=int, default=1,
                        help='use deformable convolutions')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('deformable: {}'.format(args.deformable))
    print('')

    # Set up a neural network to train.
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    use_deformable = args.deformable == 1
    if use_deformable:
        feat_extractor = DeformableConvnet(10)
    else:
        feat_extractor = Convnet(10)
    model = L.Classifier(feat_extractor)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset and apply the random-scale augmentation to
    # both splits so train/eval see the same input distribution.
    train, test = chainer.datasets.get_mnist(ndim=3)
    train = TransformDataset(train, transform)
    test = TransformDataset(test, transform)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'))

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot of the bare feature extractor each epoch.
    # BUGFIX: use the same ``== 1`` test as the model selection above so a
    # value like ``-d 2`` cannot train Convnet yet save under the
    # deformable snapshot name.
    if use_deformable:
        snapshot_fn = 'model_iter_{.updater.epoch}'
    else:
        snapshot_fn = 'non_deformable_{.updater.epoch}'
    trainer.extend(extensions.snapshot_object(feat_extractor, snapshot_fn),
                   trigger=(1, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()


if __name__ == '__main__':
    main()