1
1
import tensorflow as tf
2
- from tensorflow .python .training .tracking .data_structures import NoDependency
3
- import torch
4
-
2
+ from tools import CyclicPadding2D
5
3
6
4
class ContinuousGameOfLife (tf .keras .layers .Layer ):
7
5
8
- def __init__ (self , ):
6
+ def __init__ (self , game_function ):
9
7
super (ContinuousGameOfLife , self ).__init__ ()
10
- self .flat = tf .keras .layers .Flatten ()
8
+ self .forward_game = game_function
9
+
11
10
self .add_padding = CyclicPadding2D ()
12
11
13
- def build (self , input_shape ):
14
- self .k1 = tf .constant ([[1 ,1 ,1 ],[1 ,0 ,1 ],[1 ,1 ,1 ]], dtype = 'float32' )
15
- self .k1 = tf .reshape (self .k1 , shape = (3 ,3 ,1 ,1 ))
16
- self .k2 = tf .constant ([[0 ,0 ,0 ],[0 ,1 ,0 ],[0 ,0 ,0 ]], dtype = 'float32' )
17
- self .k2 = tf .reshape (self .k2 , shape = (3 ,3 ,1 ,1 ))
18
- super (ContinuousGameOfLife , self ).build (input_shape )
12
+ self .k1 = tf .constant ([[1 ,1 ,1 ],[1 ,0 ,1 ],[1 ,1 ,1 ]], shape = (3 ,3 ,1 ,1 ), dtype = 'float32' )
13
+ self .k2 = tf .constant ([[0 ,0 ,0 ],[0 ,1 ,0 ],[0 ,0 ,0 ]], shape = (3 ,3 ,1 ,1 ), dtype = 'float32' )
19
14
20
15
def call (self , inputs ):
21
16
batch_size , d1 , d2 = inputs .shape
@@ -24,32 +19,36 @@ def call(self, inputs):
24
19
cell = tf .nn .conv2d (x , filters = self .k2 , strides = 1 , padding = 'VALID' )
25
20
around_cell = tf .nn .conv2d (x , filters = self .k1 , strides = 1 , padding = 'VALID' )
26
21
27
- x1 = tf .math .maximum (4 - around_cell ,0 )
28
- x2 = tf .math .maximum ((around_cell + cell )- 2 ,0 )
29
- x3 = tf .math .minimum (x1 , x2 )
30
- x4 = tf .math .minimum (x3 ,1 )
31
-
32
- return tf .reshape (x4 , shape = (batch_size ,d1 ,d2 ))
22
+ xx = self .forward_game (cell , around_cell )
23
+
24
+ return tf .reshape (xx , shape = (batch_size ,d1 ,d2 ))
33
25
34
26
35
- class ContinuousGameOfLife3x3 (tf .keras .layers . Layer ):
27
+ class ContinuousReverseGame (tf .keras .models . Model ):
36
28
37
- def __init__ (self , ):
38
- super (ContinuousGameOfLife3x3 , self ).__init__ ()
29
+ def __init__ (self , game_function , min_v , max_v , grid_len ):
30
+ super (ContinuousReverseGame , self ).__init__ ()
31
+ self .forward_game = game_function
32
+ self .min_v = min_v
33
+ self .max_v = max_v
34
+ self .l = grid_len
35
+ self .k1 = tf .constant ([[1 ,1 ,1 ],[1 ,0 ,1 ],[1 ,1 ,1 ]], shape = (3 ,3 ,1 ,1 ), dtype = 'float32' )
36
+ self .k2 = tf .constant ([[0 ,0 ,0 ],[0 ,1 ,0 ],[0 ,0 ,0 ]], shape = (3 ,3 ,1 ,1 ), dtype = 'float32' )
39
37
40
- def build (self , input_shape ):
41
- self .k1 = tf .constant ([[1 ,1 ,1 ],[1 ,0 ,1 ],[1 ,1 ,1 ]], dtype = 'float32' )
42
- self .k2 = tf .constant ([[0 ,0 ,0 ],[0 ,1 ,0 ],[0 ,0 ,0 ]], dtype = 'float32' )
43
- super (ContinuousGameOfLife3x3 , self ).build (input_shape )
44
-
45
- def call (self , inputs ):
46
- cell = tf .tensordot (inputs , self .k2 , axes = ([1 ,2 ], [0 ,1 ]))
47
- around_cell = tf .tensordot (inputs , self .k1 , axes = ([1 ,2 ], [0 ,1 ]))
48
-
49
- x1 = tf .math .maximum (4 - around_cell ,0 )
50
- x2 = tf .math .maximum ((around_cell + cell )- 2 ,0 )
51
- x3 = tf .math .minimum (x1 , x2 )
52
- x4 = tf .math .minimum (x3 ,1 )
53
-
54
- return tf .reshape (x4 , shape = (- 1 ,1 ,1 ))
38
+ self .input_img = tf .Variable (tf .random .uniform (shape = (1 ,self .l + 2 ,self .l + 2 ), minval = self .min_v , maxval = self .max_v ), trainable = True , validate_shape = True ) #constraint=tf.keras.constraints.min_max_norm(0,1))
55
39
40
+
41
+ def call (self , target ):
42
+ self .input_img [:,0 ,:].assign (self .input_img [:,- 2 ,:])
43
+ self .input_img [:,- 1 ,:].assign (self .input_img [:,1 ,:])
44
+ self .input_img [:,:,0 ].assign (self .input_img [:,:,- 2 ])
45
+ self .input_img [:,:,- 1 ].assign (self .input_img [:,:,1 ])
46
+
47
+
48
+ input_img = tf .reshape (self .input_img , shape = (1 , self .l + 2 , self .l + 2 , 1 ))
49
+ cell = tf .nn .conv2d (input_img , filters = self .k2 , strides = 1 , padding = 'VALID' )
50
+ around_cell = tf .nn .conv2d (input_img , filters = self .k1 , strides = 1 , padding = 'VALID' )
51
+
52
+ xx = self .forward_game (cell , around_cell )
53
+ xx = tf .reshape (xx , shape = (self .l ,self .l ))
54
+ return xx
0 commit comments