diff --git a/examples/bug-64.py b/examples/bug-64.py index 6913e0b..9bbf2d3 100644 --- a/examples/bug-64.py +++ b/examples/bug-64.py @@ -28,7 +28,7 @@ def get_ts(params, dt=0.1, T=50.0, G=0.0, sigma=0.1): _, loop = vb.make_sde(dt, dfun=network, gfun=sigma) par = par._replace(G=G, omega=omega) nt = int(T / dt) - zs = vb.rand(nt, nn) * 2 * jnp.pi + zs = vb.randn(nt, nn) * 2 * jnp.pi xs = loop(zs[0], zs[1:], (weights, par)) ts = jnp.linspace(0, nt * dt, len(xs)) return xs, ts @@ -41,10 +41,10 @@ def get_ts(params, dt=0.1, T=50.0, G=0.0, sigma=0.1): omega = jnp.abs(vb.randn(nn) * 1.0) print('omega values are', omega) -# plt.figure(figsize=(10, 3)) -for i, sigma in enumerate([0.0, 0.1, 0.2]): - xs, ts = get_ts((omega, weights, km_default_theta), dt=dt, G=0.9, sigma=sigma) - #plt.subplot(1, 3, i + 1) - #plt.plot(ts[:-1], jnp.sin(xs)) - print(i, 'sigma=', sigma, jnp.sum(jnp.abs(jnp.diff(xs,axis=0))) ) -# plt.show() +plt.figure() +for i, sigma in enumerate([0.0, 0.01, 0.1]): + xs, ts = get_ts((omega, weights, km_default_theta), dt=dt, G=0.1, sigma=sigma) + plt.subplot(3, 1, i + 1) + plt.plot(ts, jnp.sin(xs)) + print(i, 'sigma=', sigma, jnp.sum(jnp.diff(xs,axis=0)) ) +plt.show()