diff --git a/thermox/linalg.py b/thermox/linalg.py index 800301e..cfcd5bc 100644 --- a/thermox/linalg.py +++ b/thermox/linalg.py @@ -38,7 +38,9 @@ def solve( key = random.PRNGKey(0) ts = jnp.arange(burnin, burnin + num_samples) * dt x0 = jnp.zeros_like(b) - samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b), associative_scan) + samples = sample_identity_diffusion( + key, ts, x0, A, jnp.linalg.solve(A, b), associative_scan + ) return jnp.mean(samples, axis=0)