-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdemo-empowerment.py
131 lines (113 loc) · 3.45 KB
/
demo-empowerment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from memo import memo, domain, make_module
import jax
import jax.numpy as np
from enum import IntEnum
from matplotlib import pyplot as plt
"""
**Inspired by:** Klyubin, A. S., Polani, D., & Nehaniv, C. L. (2005, September).
All else being equal be empowered. In European Conference on Artificial Life
(pp. 744-753). Berlin, Heidelberg: Springer Berlin Heidelberg.
This example shows how to use memo to compute an agent's empowerment in a gridworld.
The particular example is inspired by Figure 3a in Klyubin et al (2005).
"""
## This is a little memo module that implements the Blahut-Arimoto algorithm for empowerment
# See: https://www.comm.utoronto.ca/~weiyu/ab_isit04.pdf
def make_blahut_arimoto(X, Y, Z, p_Y_given_X):
    """Build a memo module running Blahut-Arimoto on the channel p(y | x, z).

    X is the channel-input domain (here: action sequences), Y the output
    domain (here: resulting states), and Z a conditioning domain (here: the
    starting state). p_Y_given_X(y, x, z) gives the channel probability.
    The returned module exposes three memo queries:
      - q[x, z](t): input distribution after t Blahut-Arimoto iterations,
      - Q[x, y, z](t): posterior over inputs x given output y (Bayes inverse),
      - C[z](t): mutual information I(X; Y) under q, in bits — the capacity
        estimate, i.e. the empowerment of z.
    """
    m = make_module('blahut_arimoto')
    m.X = X
    m.Y = Y
    m.Z = Z
    m.p_Y_given_X = p_Y_given_X
    # Blahut-Arimoto input-distribution update: q(x|z) at iteration t is
    # proportional to exp( sum_y p(y|x,z) log Q(x|y,z) at iteration t-1 ).
    # Base case t == 0 uses a uniform Q (wpp=1).
    @memo(install_module=m.install)
    def q[x: X, z: Z](t):
        alice: knows(z)
        alice: chooses(x in X, wpp=imagine[
            bob: knows(x, z),
            bob: chooses(y in Y, wpp=p_Y_given_X(y, x, z)),
            # exp(E[ log(Q[x, bob.y, z](t - 1) if t > 0 else 1) ])
            # NOTE(review): the H[...] expectation below appears to encode the
            # log term of the comment above via memo's entropy/surprisal
            # machinery — confirm against memo's documentation.
            bob: thinks[
                charlie: knows(y, z),
                charlie: chooses(x in X, wpp=Q[x, y, z](t - 1) if t > 0 else 1)
            ],
            exp(E[bob[H[charlie.x]]])
        ])
        return Pr[alice.x == x]
    # Bayes inversion of q through the channel: Q(x | y, z) ∝ q(x|z) p(y|x,z),
    # obtained here by conditioning alice's model of bob on the observed y.
    @memo(install_module=m.install)
    def Q[x: X, y: Y, z: Z](t):
        alice: knows(x, y, z)
        alice: thinks[
            bob: knows(z),
            bob: chooses(x in X, wpp=q[x, z](t)),
            bob: chooses(y in Y, wpp=p_Y_given_X(y, x, z))
        ]
        alice: observes [bob.y] is y
        return alice[Pr[bob.x == x]]
    # Capacity estimate: mutual information I(X; Y) = H(X) + H(Y) - H(X, Y)
    # under the optimized input distribution q, converted from nats to bits.
    @memo(install_module=m.install)
    def C[z: Z](t):
        alice: knows(z)
        alice: chooses(x in X, wpp=q[x, z](t))
        alice: chooses(y in Y, wpp=p_Y_given_X(y, x, z))
        return (H[alice.x] + H[alice.y] - H[alice.x, alice.y]) / log(2) # convert to bits
    return m
# # Sanity check: a channel that drops messages with probability 0.1 should have capacity 0.9 bits.
# X = [0, 1]
# Y = [0, 1, 2]
# @jax.jit
# def p_Y_given_X(y, x, z):
# return np.array([
# [0.9, 0.1, 1e-10],
# [1e-10, 0.1, 0.9]
# ])[x, y]
# m = make_blahut_arimoto(X, Y, np.array([0]), p_Y_given_X)
# print(m.q(10))
# print(m.C(10))
## Setting up a gridworld...
N = 13  # side length of the square grid
world = np.zeros((N, N))  # occupancy grid: 0 = free cell, 1 = wall
world = world.at[N // 2, N // 2].set(1)  # single wall cell in the center (Klyubin et al. Fig 3a)
X = np.arange(world.shape[0])  # row coordinates (note: shadows the X channel-input role above)
Y = np.arange(world.shape[1])  # column coordinates
S = domain(x=len(X), y=len(Y))  # memo state domain: flattened (x, y) grid positions
class A(IntEnum):
    # Primitive actions. The integer values index the move table in Tr1:
    # N -> (x, y-1), S -> (x, y+1), W -> (x-1, y), E -> (x+1, y), O -> stay.
    N = 0
    S = 1
    W = 2
    E = 3
    O = 4  # no-op / stay in place
# Macro-action domain: all sequences of 5 primitive actions (len(A)^5 entries),
# used as the channel input for 5-step empowerment.
Ax = domain(
    a1=len(A),
    a2=len(A),
    a3=len(A),
    a4=len(A),
    a5=len(A),
)
@jax.jit
def Tr1(s, a):
    """Apply one primitive action `a` to flattened state `s`.

    Moves are clipped at the grid boundary; a move into a wall cell
    leaves the state unchanged.
    """
    row = S.x(s)
    col = S.y(s)
    # Coordinate deltas per action, indexed by the action id (A.N..A.O).
    step = np.array([
        [0, -1],  # A.N
        [0, 1],   # A.S
        [-1, 0],  # A.W
        [1, 0],   # A.E
        [0, 0],   # A.O (stay)
    ])[a]
    nx = np.clip(row + step[0], 0, len(X) - 1)
    ny = np.clip(col + step[1], 0, len(Y) - 1)
    # Blocked by a wall? Stay put; otherwise move to the new flattened state.
    return np.where(world[nx, ny], s, S(nx, ny))
@jax.jit
def Tr(s_, ax, s):
    """Deterministic channel p(s_ | ax, s): 0/1 indicator that executing the
    macro-action `ax` from state `s` ends in state `s_`."""
    state = s
    # Unroll the primitive actions of the macro-action in order.
    for act in Ax._tuple(ax):
        state = Tr1(state, act)
    return s_ == state
# ...and computing 5-step empowerment in the gridworld!
m = make_blahut_arimoto(X=Ax, Y=S, Z=S, p_Y_given_X=Tr)
m.Z = S
@memo(install_module=m.install, debug_trace=True)
def empowerment[s: Z](t):
return C[s](t)
emp = m.empowerment(5).block_until_ready()
emp = emp.reshape(len(X), len(Y))
emp = emp * (1 - world)
plt.colorbar(plt.imshow(emp.reshape(len(X), len(Y)) * (1 - world), cmap='gray'))
plt.savefig('out.png')