1
+ from memo import memo , domain , make_module
2
+ import jax
3
+ import jax .numpy as np
4
+ from enum import IntEnum
5
+
6
+ from matplotlib import pyplot as plt
7
+
8
+ """
9
+ This example shows how to use memo to compute an agent's empowerment in a gridworld. The particular example is inspired by Figure 3a in Klyubin et al (2005).
10
+
11
+ Klyubin, A. S., Polani, D., & Nehaniv, C. L. (2005, September). All else being equal be empowered. In European Conference on Artificial Life (pp. 744-753). Berlin, Heidelberg: Springer Berlin Heidelberg.
12
+ """
13
+
14
+ # See: https://www.comm.utoronto.ca/~weiyu/ab_isit04.pdf
15
+ def make_blahut_arimoto (X , Y , Z , p_Y_given_X ):
16
+ m = make_module ('blahut_arimoto' )
17
+ m .X = X
18
+ m .Y = Y
19
+ m .Z = Z
20
+ m .p_Y_given_X = p_Y_given_X
21
+
22
+ @memo (install_module = m .install )
23
+ def q [x : X , z : Z ](t ):
24
+ alice : knows (z )
25
+ alice : chooses (x in X , wpp = imagine [
26
+ bob : knows (x , z ),
27
+ bob : chooses (y in Y , wpp = p_Y_given_X (y , x , z )),
28
+ # exp(E[ log(Q[x, bob.y, z](t - 1) if t > 0 else 1) ])
29
+ bob : thinks [
30
+ charlie : knows (y , z ),
31
+ charlie : chooses (x in X , wpp = Q [x , y , z ](t - 1 ) if t > 0 else 1 )
32
+ ],
33
+ exp (E [bob [H [charlie .x ]]])
34
+ ])
35
+ return Pr [alice .x == x ]
36
+
37
+ @memo (install_module = m .install )
38
+ def Q [x : X , y : Y , z : Z ](t ):
39
+ alice : knows (x , y , z )
40
+ alice : thinks [
41
+ bob : knows (x , z ),
42
+ bob : chooses (x in X , wpp = q [x , z ](t )),
43
+ bob : chooses (y in Y , wpp = p_Y_given_X (y , x , z ))
44
+ ]
45
+ alice : observes [bob .y ] is y
46
+ return alice [Pr [bob .x == x ]]
47
+
48
+ @memo (install_module = m .install )
49
+ def C [z : Z ](t ):
50
+ alice : knows (z )
51
+ alice : chooses (x in X , wpp = q [x , z ](t ))
52
+ alice : chooses (y in Y , wpp = p_Y_given_X (y , x , z ))
53
+ return (H [alice .x ] + H [alice .y ] - H [alice .x , alice .y ]) / log (2 ) # convert to bits
54
+
55
+ return m
56
+
57
+ # # Example: channel that drops message with probability 0.1 has capacity 0.9
58
+ # X = [0, 1]
59
+ # Y = [0, 1, 2]
60
+ # @jax.jit
61
+ # def p_Y_given_X(y, x, z):
62
+ # return np.array([
63
+ # [0.9, 0.1, 1e-10],
64
+ # [1e-10, 0.1, 0.9]
65
+ # ])[x, y]
66
+ # m = make_blahut_arimoto(X, Y, np.array([0]), p_Y_given_X)
67
+ # print(m.q(10))
68
+ # print(m.C(10))
69
+
70
+
71
+
72
+ world = np .array ([
73
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
74
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
75
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
76
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
77
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
78
+ [0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 ],
79
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
80
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
81
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
82
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
83
+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
84
+ ])
85
+
86
+ X = np .arange (world .shape [0 ])
87
+ Y = np .arange (world .shape [1 ])
88
+ S = domain (x = len (X ), y = len (Y ))
89
+
90
+ class A (IntEnum ):
91
+ N = 0
92
+ S = 1
93
+ W = 2
94
+ E = 3
95
+ O = 4
96
+ Ax = domain (
97
+ a1 = len (A ),
98
+ a2 = len (A ),
99
+ a3 = len (A ),
100
+ a4 = len (A ),
101
+ )
102
+
103
+ @jax .jit
104
+ def Tr1 (s , a ):
105
+ x = S .x (s )
106
+ y = S .y (s )
107
+ z = np .array ([
108
+ [x , y - 1 ],
109
+ [x , y + 1 ],
110
+ [x - 1 , y ],
111
+ [x + 1 , y ],
112
+ [x , y ]
113
+ ])[a ]
114
+ x_ = np .clip (z [0 ], 0 , len (X ) - 1 )
115
+ y_ = np .clip (z [1 ], 0 , len (Y ) - 1 )
116
+ return np .where (world [x_ , y_ ], s , S (x_ , y_ ))
117
+
118
+
119
+ @jax .jit
120
+ def Tr (s_ , ax , s ):
121
+ for a in Ax ._tuple (ax ):
122
+ s = Tr1 (s , a )
123
+ return s == s_
124
+
125
+ m = make_blahut_arimoto (X = Ax , Y = S , Z = S , p_Y_given_X = Tr )
126
+ m .Z = S
127
+ @memo (install_module = m .install , debug_trace = True )
128
+ def empowerment [s : Z ](t ):
129
+ return C [s ](t )
130
+
131
+ emp = m .empowerment (5 ).block_until_ready ()
132
+ emp = emp .reshape (len (X ), len (Y ))
133
+ emp = emp * (1 - world )
134
+ plt .colorbar (plt .imshow (emp .reshape (len (X ), len (Y )) * (1 - world ), cmap = 'gray' ))
135
+ plt .savefig ('out.png' )
0 commit comments