Skip to content

Commit

Permalink
fix bug in example
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Mar 12, 2024
1 parent fd8f990 commit 7e943ae
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions examples/bug-64.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_ts(params, dt=0.1, T=50.0, G=0.0, sigma=0.1):
_, loop = vb.make_sde(dt, dfun=network, gfun=sigma)
par = par._replace(G=G, omega=omega)
nt = int(T / dt)
zs = vb.rand(nt, nn) * 2 * jnp.pi
zs = vb.randn(nt, nn) * 2 * jnp.pi
xs = loop(zs[0], zs[1:], (weights, par))
ts = jnp.linspace(0, nt * dt, len(xs))
return xs, ts
Expand All @@ -41,10 +41,10 @@ def get_ts(params, dt=0.1, T=50.0, G=0.0, sigma=0.1):
omega = jnp.abs(vb.randn(nn) * 1.0)
print('omega values are', omega)

# plt.figure(figsize=(10, 3))
for i, sigma in enumerate([0.0, 0.1, 0.2]):
xs, ts = get_ts((omega, weights, km_default_theta), dt=dt, G=0.9, sigma=sigma)
#plt.subplot(1, 3, i + 1)
#plt.plot(ts[:-1], jnp.sin(xs))
print(i, 'sigma=', sigma, jnp.sum(jnp.abs(jnp.diff(xs,axis=0))) )
# plt.show()
plt.figure()
for i, sigma in enumerate([0.0, 0.01, 0.1]):
xs, ts = get_ts((omega, weights, km_default_theta), dt=dt, G=0.1, sigma=sigma)
plt.subplot(3, 1, i + 1)
plt.plot(ts, jnp.sin(xs))
print(i, 'sigma=', sigma, jnp.sum(jnp.diff(xs,axis=0)) )
plt.show()

0 comments on commit 7e943ae

Please sign in to comment.