-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdemo-rsa.py
88 lines (76 loc) · 2.19 KB
/
demo-rsa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from memo import memo
import jax
import jax.numpy as np
from enum import IntEnum
## Boilerplate - define U, R, and denotes.
class U(IntEnum):
    """Utterance space: one feature bit per word the speaker may say."""
    GREEN  = 1 << 0
    PINK   = 1 << 1
    SQUARE = 1 << 2
    ROUND  = 1 << 3

class R(IntEnum):
    """Referent space: each object is the union of its two feature bits."""
    GREEN_SQUARE = U.GREEN | U.SQUARE
    GREEN_CIRCLE = U.GREEN | U.ROUND
    PINK_CIRCLE  = U.PINK  | U.ROUND
# Truth-conditional semantics: an utterance applies to a referent whenever
# they share at least one feature bit.
@jax.jit
def denotes(u, r):
    """Return (a traced boolean) whether utterance u truthfully describes referent r."""
    shared_features = u & r
    return shared_features != 0
## Recursive RSA model
# NOTE: the body below is `memo` DSL syntax (thinks/given/chooses/observes),
# compiled by the @memo decorator — it is not plain Python semantics.
@memo
def L[u: U, r: R](beta, t):
    # Depth-t pragmatic listener: distribution over referents r given utterance u.
    # The listener imagines a speaker who picks a referent uniformly, then
    # chooses an utterance: literally (t == 0, any true utterance) or
    # pragmatically (t > 0, softmax with temperature beta over the
    # depth-(t-1) listener's probability of recovering r).
    listener: thinks[
        speaker: given(r in R, wpp=1),
        speaker: chooses(u in U, wpp=
            denotes(u, r) * (1 if t == 0 else exp(beta * L[u, r](beta, t - 1))))
    ]
    # Condition on having heard utterance u, then infer the speaker's referent.
    listener: observes [speaker.u] is u
    listener: chooses(r in R, wpp=Pr[speaker.r == r])
    return Pr[listener.r == r]
# Sanity check: print the literal (depth-0) and pragmatic (depth-1)
# listener distributions at a fixed temperature.
beta = 1.0
for depth in (0, 1):
    print(L(beta, depth))
## Fitting the model to data...
# Human referent-choice proportions (n = 180), data from Qing & Franke 2015.
Y = np.array([65, 115, 0]) / 180

@jax.jit
def loss(beta):
    """Mean squared error between the depth-1 model and the human data.

    Indexes row 0 of L's output — presumably the first utterance (GREEN);
    TODO(review): confirm against the experiment's stimulus.
    """
    residual = L(beta, 1)[0] - Y
    return np.mean(residual ** 2)
from matplotlib import pyplot as plt

plt.figure(figsize=(5, 4))

## Best model fit vs. data
beta = 1.74  # roughly the argmin of the loss curve computed below
plt.subplot(2, 1, 2)
X = np.array([0, 1, 2])  # one bar group per referent
w = 0.25                 # bar width, also used as the horizontal offset
# Human error bars: 2 standard errors of a binomial proportion with n = 180.
yerr_humans = 2 * np.sqrt(Y * (1 - Y) / 180)
plt.bar(X - w, Y, width=w, yerr=yerr_humans, capsize=2, label='humans')
plt.bar(X + 0.00, L(beta, 1)[0], width=w, label='model, ℓ=1')
plt.bar(X + w, L(beta, 0)[0], width=w, label='model, ℓ=0')
plt.xticks([0, 1, 2], ['green\nsquare', 'green\ncircle', 'pink\ncircle'])
plt.xlabel('Inferred referent r')
plt.ylabel("Probability")
plt.ylim(0, 1)
plt.legend()
plt.title('Final model fit')
## Fitting by grid search!
plt.subplot(2, 2, 1)
# Evaluate the loss over a dense grid of temperatures, vectorized with vmap.
betas = np.linspace(0, 3, 100)
plt.plot(betas, jax.vmap(loss)(betas))
plt.xlabel('beta')
plt.ylabel('MSE (%)')
plt.yticks([0, 0.02], [0, 2])  # relabel the 0.02-MSE tick as 2%
plt.xticks([0, 1, 2, 3])
plt.title('Grid search')
## Fitting by gradient descent!
# vg returns (loss value, d loss / d beta) in a single call.
vg = jax.value_and_grad(loss)
plt.subplot(2, 2, 2)
losses = []
beta = 0.
lr = 12.0  # gradient-descent step size
for _ in range(26):
    value, grad = vg(beta)
    losses.append(value)
    beta = beta - lr * grad
plt.plot(np.arange(len(losses)), losses)
plt.ylabel('MSE (%)')
plt.xlabel('Step #')
plt.yticks([0, 0.02], [0, 2])  # relabel the 0.02-MSE tick as 2%
plt.title('Gradient descent')
plt.tight_layout()
plt.savefig('../paper/fig/rsa-fit.pdf')