3
3
4
4
5
5
class CyclicPadding2D (keras .layers .Layer ):
6
+ """
7
+ It adds cyclic padding around the two last dimensions of the tensor. No weights to train.
8
+ """
6
9
7
10
def __init__ (self ,):
8
11
super (CyclicPadding2D , self ).__init__ ()
@@ -13,6 +16,13 @@ def build(self, input_shape):
13
16
super (CyclicPadding2D , self ).build (input_shape )
14
17
15
18
def call (self , inputs ):
19
+ """
20
+ Args:
21
+ inputs: a 3D tensor of shape (batch_size, d1, d2)
22
+
23
+ Returns:
24
+ The padded 3D tensor of shape (batch_size, d1+2, d+2)
25
+ """
16
26
17
27
self .grid [:,1 :- 1 , 1 :- 1 ].assign (inputs )
18
28
self .grid [:,0 ,0 ].assign (inputs [:,- 1 ,- 1 ])
@@ -27,11 +37,14 @@ def call(self, inputs):
27
37
return self .grid
28
38
29
39
class DenseSymmetric2D (tf .keras .layers .Layer ):
40
+ """
41
+ It creates a dense layer where the weight matrix is symmetric along the two axes.
42
+ """
30
43
31
44
def __init__ (self ,):
32
45
super (DenseSymmetric2D , self ).__init__ ()
33
46
34
- def __call__ (self , input_shape ):
47
+ def build (self , input_shape ):
35
48
36
49
w1 = tf .constant (tf .keras .initializers .RandomUniform (minval = 0.01 , maxval = 0.09 ),
37
50
shape = (input_shape [0 ], input_shape [1 ], input_shape [2 ]))
@@ -44,23 +57,13 @@ def __call__(self, input_shape):
44
57
w2 = tf .transpose (w1 )
45
58
self .W = w1 + w2
46
59
60
+ def call (self , x ):
61
+ # TODO: to do it all
47
62
48
- def __init__ (self ,):
49
-
50
- def build (self , input_shape ):
51
- self .grid = tf .Variable (tf .zeros (shape = (input_shape [0 ], input_shape [1 ]+ 2 , input_shape [2 ]+ 2 ), dtype = tf .float32 ),
52
- trainable = False , validate_shape = True )
53
- super (CyclicPadding2D , self ).build (input_shape )
54
-
55
- def call (self , inputs ):
56
-
57
- self .grid [:,1 :- 1 , 1 :- 1 ].assign (inputs )
58
- self .grid [:,0 ,0 ].assign (inputs [:,- 1 ,- 1 ])
59
- self .grid [:,0 ,- 1 ].assign ( inputs [:,- 1 ,0 ])
60
- self .grid [:,- 1 ,0 ].assign (inputs [:,0 ,- 1 ])
61
- self .grid [:,- 1 ,- 1 ].assign (inputs [:,0 ,0 ])
62
-
63
63
class LocallyDense (keras .layers .Layer ):
64
+ """
65
+ Warning: to be used after an instance of CyclicPadding2D.
66
+ """
64
67
65
68
def __init__ (self , ):
66
69
super (LocallyDense , self ).__init__ ()
@@ -81,6 +84,14 @@ def build(self, input_shape):
81
84
self .b = self .add_weight (name = "b" , shape = (m ,n ), initializer = 'zeros' , trainable = True )
82
85
83
86
def call (self , padded_input ):
87
+ """
88
+ Args:
89
+ padded_input (3D tensor): A tensor with shape (batch_size, d1, d2). A list of grids with cyclic padding.
90
+
91
+ Returns:
92
+ 3D tensor of shape (batch_size, d1-2, d2-2). Weighted sum of the elements in the 3x3 grid around each cell,
93
+ with bias.
94
+ """
84
95
p00 = padded_input [:,:- 2 ,:- 2 ]
85
96
p01 = padded_input [:,:- 2 ,1 :- 1 ]
86
97
p02 = padded_input [:,:- 2 ,2 :]
@@ -91,17 +102,20 @@ def call(self, padded_input):
91
102
p21 = padded_input [:,2 :,1 :- 1 ]
92
103
p22 = padded_input [:,2 :,2 :]
93
104
94
- return tf .matmul (p00 , self .w00 ) + tf .matmul (p01 , self .w01 ) + tf .matmul (p02 , self .w02 ) +
95
- tf .matmul (p10 , self .w10 ) + tf .matmul (p11 , self .w11 ) + tf .matmul (p12 , self .w12 ) +
96
- tf .matmul (p20 , self .w20 ) + tf .matmul (p21 , self .w21 ) + tf .matmul (p22 , self .w22 ) + self .b
105
+ return tf .matmul (p00 , self .w00 ) + tf .matmul (p01 , self .w01 ) + tf .matmul (p02 , self .w02 ) + \
106
+ tf .matmul (p10 , self .w10 ) + tf .matmul (p11 , self .w11 ) + tf .matmul (p12 , self .w12 ) + \
107
+ tf .matmul (p20 , self .w20 ) + tf .matmul (p21 , self .w21 ) + tf .matmul (p22 , self .w22 ) + self .b
97
108
98
109
class Conv2D (keras .layers .Layer ):
99
-
110
+ """
111
+ Just a function, no weights to train.
112
+ """
113
+ # TODO: check if I can just use tf.nn.conv2d in the model
114
+
100
115
def __init__ (self ,kernel ):
101
116
super (Conv2D , self ).__init__ ()
102
117
self .kernel = kernel
103
118
104
119
def call (self , x ):
105
- print (x .shape )
106
120
x = tf .nn .conv2d (x , self .kernel , strides = 1 , padding = 'VALID' )
107
121
return x
0 commit comments