Skip to content

Commit 48dcbf9

Browse files
committed
add empowerment demo
1 parent 7775709 commit 48dcbf9

File tree

1 file changed

+135
-0
lines changed

1 file changed

+135
-0
lines changed

demo/demo-empowerment.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from memo import memo, domain, make_module
2+
import jax
3+
import jax.numpy as np
4+
from enum import IntEnum
5+
6+
from matplotlib import pyplot as plt
7+
8+
"""
9+
This example shows how to use memo to compute an agent's empowerment in a gridworld. The particular example is inspired by Figure 3a in Klyubin et al (2005).
10+
11+
Klyubin, A. S., Polani, D., & Nehaniv, C. L. (2005, September). All else being equal be empowered. In European Conference on Artificial Life (pp. 744-753). Berlin, Heidelberg: Springer Berlin Heidelberg.
12+
"""
13+
14+
# See: https://www.comm.utoronto.ca/~weiyu/ab_isit04.pdf
15+
def make_blahut_arimoto(X, Y, Z, p_Y_given_X):
16+
m = make_module('blahut_arimoto')
17+
m.X = X
18+
m.Y = Y
19+
m.Z = Z
20+
m.p_Y_given_X = p_Y_given_X
21+
22+
@memo(install_module=m.install)
23+
def q[x: X, z: Z](t):
24+
alice: knows(z)
25+
alice: chooses(x in X, wpp=imagine[
26+
bob: knows(x, z),
27+
bob: chooses(y in Y, wpp=p_Y_given_X(y, x, z)),
28+
# exp(E[ log(Q[x, bob.y, z](t - 1) if t > 0 else 1) ])
29+
bob: thinks[
30+
charlie: knows(y, z),
31+
charlie: chooses(x in X, wpp=Q[x, y, z](t - 1) if t > 0 else 1)
32+
],
33+
exp(E[bob[H[charlie.x]]])
34+
])
35+
return Pr[alice.x == x]
36+
37+
@memo(install_module=m.install)
38+
def Q[x: X, y: Y, z: Z](t):
39+
alice: knows(x, y, z)
40+
alice: thinks[
41+
bob: knows(x, z),
42+
bob: chooses(x in X, wpp=q[x, z](t)),
43+
bob: chooses(y in Y, wpp=p_Y_given_X(y, x, z))
44+
]
45+
alice: observes [bob.y] is y
46+
return alice[Pr[bob.x == x]]
47+
48+
@memo(install_module=m.install)
49+
def C[z: Z](t):
50+
alice: knows(z)
51+
alice: chooses(x in X, wpp=q[x, z](t))
52+
alice: chooses(y in Y, wpp=p_Y_given_X(y, x, z))
53+
return (H[alice.x] + H[alice.y] - H[alice.x, alice.y]) / log(2) # convert to bits
54+
55+
return m
56+
57+
# # Example: channel that drops message with probability 0.1 has capacity 0.9
58+
# X = [0, 1]
59+
# Y = [0, 1, 2]
60+
# @jax.jit
61+
# def p_Y_given_X(y, x, z):
62+
# return np.array([
63+
# [0.9, 0.1, 1e-10],
64+
# [1e-10, 0.1, 0.9]
65+
# ])[x, y]
66+
# m = make_blahut_arimoto(X, Y, np.array([0]), p_Y_given_X)
67+
# print(m.q(10))
68+
# print(m.C(10))
69+
70+
71+
72+
world = np.array([
73+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
74+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
75+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
76+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
77+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
78+
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
79+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
80+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
81+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
82+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
83+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
84+
])
85+
86+
X = np.arange(world.shape[0])
87+
Y = np.arange(world.shape[1])
88+
S = domain(x=len(X), y=len(Y))
89+
90+
class A(IntEnum):
91+
N = 0
92+
S = 1
93+
W = 2
94+
E = 3
95+
O = 4
96+
Ax = domain(
97+
a1=len(A),
98+
a2=len(A),
99+
a3=len(A),
100+
a4=len(A),
101+
)
102+
103+
@jax.jit
104+
def Tr1(s, a):
105+
x = S.x(s)
106+
y = S.y(s)
107+
z = np.array([
108+
[x, y - 1],
109+
[x, y + 1],
110+
[x - 1, y],
111+
[x + 1, y],
112+
[x, y]
113+
])[a]
114+
x_ = np.clip(z[0], 0, len(X) - 1)
115+
y_ = np.clip(z[1], 0, len(Y) - 1)
116+
return np.where(world[x_, y_], s, S(x_, y_))
117+
118+
119+
@jax.jit
120+
def Tr(s_, ax, s):
121+
for a in Ax._tuple(ax):
122+
s = Tr1(s, a)
123+
return s == s_
124+
125+
m = make_blahut_arimoto(X=Ax, Y=S, Z=S, p_Y_given_X=Tr)
126+
m.Z = S
127+
@memo(install_module=m.install, debug_trace=True)
128+
def empowerment[s: Z](t):
129+
return C[s](t)
130+
131+
emp = m.empowerment(5).block_until_ready()
132+
emp = emp.reshape(len(X), len(Y))
133+
emp = emp * (1 - world)
134+
plt.colorbar(plt.imshow(emp.reshape(len(X), len(Y)) * (1 - world), cmap='gray'))
135+
plt.savefig('out.png')

0 commit comments

Comments
 (0)