slds.py

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse

import torch

import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist


def main(args):
    funsor.set_backend("torch")

    # Declare parameters.
    trans_probs = funsor.Tensor(
        torch.tensor([[0.9, 0.1], [0.1, 0.9]], requires_grad=True)
    )
    trans_noise = funsor.Tensor(
        torch.tensor(
            [
                0.1,  # low noise component
                1.0,  # high noise component
            ],
            requires_grad=True,
        )
    )
    emit_noise = funsor.Tensor(torch.tensor(0.5, requires_grad=True))
    params = [trans_probs.data, trans_noise.data, emit_noise.data]
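    # Note: .data exposes the underlying torch tensors, so the Adam optimizer
    # below updates the same storage that the funsor parameters wrap.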

    # A Gaussian HMM model.
    @funsor.interpretation(funsor.terms.moment_matching)
    def model(data):
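        # Interpretation note: under moment_matching, marginalizing the
        # discrete switching state projects the resulting mixture of
        # Gaussians back onto a single Gaussian, so the lazy log_prob
        # expression stays compact instead of growing exponentially in time.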
        log_prob = funsor.Number(0.)

        # s is the discrete latent state,
        # x is the continuous latent state,
        # y is the observed state.
        s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
        x_curr = funsor.Tensor(torch.tensor(0.))
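        # Both initial states are concrete (non-lazy) tensors, so the first
        # transition below conditions on known values rather than on free
        # variables.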
        for t, y in enumerate(data):
            s_prev = s_curr
            x_prev = x_curr

            # A delayed sample statement.
            s_curr = funsor.Variable('s_{}'.format(t), funsor.Bint[2])
            log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.Real)
            log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

            # Marginalize out previous delayed sample statements.
            if t > 0:
                log_prob = log_prob.reduce(ops.logaddexp, {s_prev.name, x_prev.name})
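            # This reduce is a single variable-elimination step: it sums out
            # the discrete state and integrates out the continuous state of
            # time t - 1, which no later factor depends on.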

            # An observe statement.
            log_prob += dist.Normal(x_curr, emit_noise, value=y)
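
        # After the loop, the final time step's s and x variables are still
        # free; the last reduce marginalizes them, leaving a scalar
        # log-probability.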
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    torch.manual_seed(0)
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        log_prob = model(data)
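        # model() returns a fully reduced funsor, so .data below is a plain
        # torch scalar through which gradients flow back to the parameters.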
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()
        if args.verbose and step % 10 == 0:
            print('step {} loss = {}'.format(step, loss.item()))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Switching linear dynamical system")
    parser.add_argument("-t", "--time-steps", default=10, type=int)
    parser.add_argument("-n", "--train-steps", default=101, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
    parser.add_argument("--filter", action='store_true')
parser.add_argument("-v", "--verbose", action="store_true")
args = parser.parse_args()
main(args)
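
# Example invocation:
#   python slds.py --time-steps 20 --train-steps 201 --verbose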