-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcaps_net.py
70 lines (59 loc) · 2.91 KB
/
caps_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from caps_layers import PrimaryCap, CapsuleLayer, length, Mask, squash
class CapsNet(gluon.HybridBlock):
    """Capsule network (CapsNet) as a Gluon HybridBlock.

    Pipeline: Conv2D(256, 9x9) -> PrimaryCap (32 channels of 8-D capsules)
    -> CapsuleLayer (one 16-D capsule per class, dynamic routing).  The
    capsule lengths are returned as class scores; optionally a Dense
    decoder reconstructs the input image from the capsule selected by the
    true label (the "reconstruction as regularization" branch).

    NOTE(review): the batch size N from `input_shape` is baked into the
    reshapes in `hybrid_forward`, so the block only works at exactly that
    batch size — confirm before reusing with a different loader.
    """
    def __init__(self, n_class, num_routing, input_shape, outmask=True, **kwargs):
        # n_class: number of output classes (one digit capsule each).
        # num_routing: dynamic-routing iterations, forwarded to CapsuleLayer.
        # input_shape: (N, C, W, H) of the input batch.
        # outmask: when True, hybrid_forward also returns a reconstruction.
        super(CapsNet, self).__init__(**kwargs)
        N, C, W, H = input_shape
        self.n_class = n_class
        self.batch_size = N
        self.outmask = outmask
        self.input_shape = input_shape
        with self.name_scope():
            self.net = nn.HybridSequential(prefix='')
            self.net.add(nn.Conv2D(256, kernel_size=9, strides=1, padding=0,
                                   activation='relu'))
            self.net.add(PrimaryCap(dim_capsule=8, n_channels=32, kernel_size=9, strides=2,
                                    padding=0))
            # NOTE(review): 1152 = 32 channels * 6 * 6 spatial positions,
            # which assumes a 28x28 input (MNIST).  Other W/H would need a
            # different primary-capsule count here — confirm.
            caps_in_shape = (N, 1152, 8)
            weight_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3)
            self.net.add(CapsuleLayer(num_capsule=n_class,
                                      dim_capsule=16,
                                      num_routing=num_routing,
                                      in_shape=caps_in_shape,
                                      weight_initializer=weight_initializer))
            # Decoder MLP: reconstructs the flattened single-channel image
            # (W*H sigmoid outputs) from one masked 16-D digit capsule.
            self.decoder = nn.HybridSequential(prefix='')
            self.decoder.add(nn.Dense(512, activation='relu'))
            self.decoder.add(nn.Dense(1024, activation='relu'))
            self.decoder.add(nn.Dense(W*H, activation='sigmoid'))
    def hybrid_forward(self, F, x, y=None):
        """Forward pass.

        :param F: mx.nd or mx.sym backend namespace (Gluon hybrid contract).
        :param x: input batch, shape `input_shape` = (N, C, W, H).
        :param y: one-hot labels (N, n_class); required when `outmask` is True.
        :return: `out_caps` (capsule lengths per class), and additionally
                 `out_mask` (reconstructed input) when `outmask` is True.
        """
        # digitcaps: one 16-D capsule per class, presumably (N, n_class, 16).
        digitcaps = self.net(x)
        #print "digitcaps", digitcaps.shape
        # Capsule length acts as the class probability (helper from caps_layers).
        out_caps = length(F, digitcaps)
        if self.outmask:
            # MXNet special reshape code -4 splits a dim into (n_class, -1):
            # (N, n_class) one-hot labels -> (N, n_class, 1).
            y_reshaped = F.reshape(y, (self.batch_size, -4, self.n_class, -1))
            #print "y_reshaped", y_reshaped.shape
            # decode network
            #masked_by_y = Mask(F, [digitcaps, y])
            # Batched matmul selects the true class's capsule:
            # (N, 1, n_class) @ (N, n_class, 16) -> (N, 1, 16).
            masked_by_y = F.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True)
            # -3 merges the first two dims, 0 keeps the last: -> (N, 16).
            masked_by_y = F.reshape(data=masked_by_y, shape=(-3, 0))
            out_mask = self.decoder(masked_by_y)
            out_mask = F.reshape(out_mask, self.input_shape)
            return out_caps, out_mask
        else:
            return out_caps
def margin_loss(F, y_true, y_pred):
    """Separation margin loss, Eq. (4) of the CapsNet paper.

    Should also work when y_true[i, :] contains more than one `1`
    (multi-label), though that case is untested.

    :param F: backend namespace providing square/maximum/sum/mean.
    :param y_true: [None, n_classes] one-hot (or multi-hot) labels.
    :param y_pred: [None, num_capsule] capsule lengths in [0, 1].
    :return: a scalar loss value.
    """
    # Paper constants: m+ = 0.9, m- = 0.1, lambda = 0.5.
    present_cost = F.square(F.maximum(0., 0.9 - y_pred))
    absent_cost = F.square(F.maximum(0., y_pred - 0.1))
    per_class = y_true * present_cost + 0.5 * (1 - y_true) * absent_cost
    # Sum over classes, then average over the batch.
    return F.mean(F.sum(per_class, 1))
def mask_mse_loss(F, mask_true, mask_pred):
    """Mean squared reconstruction error between predicted and true masks.

    :param F: backend namespace providing square/mean.
    :param mask_true: ground-truth reconstruction target.
    :param mask_pred: decoder output, same shape as `mask_true`.
    :return: a scalar mean-squared-error value.
    """
    diff = mask_pred - mask_true
    return F.mean(F.square(diff))