From 7494f72efb0f90102a3534d3d20019f603226b39 Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Thu, 11 Jul 2024 13:27:39 -0400 Subject: [PATCH] Pass associative scan arg through linalg funcs (#34) * Pass associative scan arg through linalg funcs * Fix style * Add docstring for assoc scan arg --- thermox/linalg.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/thermox/linalg.py b/thermox/linalg.py index a544aea..dba4a72 100644 --- a/thermox/linalg.py +++ b/thermox/linalg.py @@ -12,6 +12,7 @@ def solve( dt: float = 1.0, burnin: int = 0, key: Array = None, + associative_scan: bool = True, ) -> Array: """ Obtain the solution of the linear system @@ -29,6 +30,7 @@ def solve( burnin: Time-step index corresponding to the end of the burn-in period. Samples before this step are not collected. key: JAX random key + associative_scan: If True, uses jax.lax.associative_scan. Returns: Approximate solution, x, of the linear system. @@ -37,7 +39,9 @@ def solve( key = random.PRNGKey(0) ts = jnp.arange(burnin, burnin + num_samples) * dt x0 = jnp.zeros_like(b) - samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b)) + samples = sample_identity_diffusion( + key, ts, x0, A, jnp.linalg.solve(A, b), associative_scan + ) return jnp.mean(samples, axis=0) @@ -47,6 +51,7 @@ def inv( dt: float = 1.0, burnin: int = 0, key: Array = None, + associative_scan: bool = True, ) -> Array: """ Obtain the inverse of a matrix A by @@ -60,6 +65,7 @@ def inv( burnin: Time-step index corresponding to the end of the burn-in period. Samples before this step are not collected. key: JAX random key + associative_scan: If True, uses jax.lax.associative_scan. Returns: Approximate inverse of A. @@ -69,7 +75,7 @@ def inv( ts = jnp.arange(burnin, burnin + num_samples) * dt b = jnp.zeros(A.shape[0]) x0 = jnp.zeros_like(b) - samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0])) + samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]), associative_scan) return jnp.cov(samples.T) @@ -80,6 +86,7 @@ def expnegm( burnin: int = 0, key: Array = None, alpha: float = 0.0, + associative_scan: bool = True, ) -> Array: """ Obtain the negative exponential of a matrix A by @@ -95,6 +102,7 @@ def expnegm( key: JAX random key alpha: Regularization parameter to ensure diffusion matrix is symmetric positive definite. + associative_scan: If True, uses jax.lax.associative_scan. Returns: Approximate negative matrix exponential, exp(-A). @@ -108,7 +116,7 @@ def expnegm( ts = jnp.arange(burnin, burnin + num_samples) * dt b = jnp.zeros(A.shape[0]) x0 = jnp.zeros_like(b) - samples = sample(key, ts, x0, A_shifted, b, B) + samples = sample(key, ts, x0, A_shifted, b, B, associative_scan) return autocovariance(samples) * jnp.exp(alpha) @@ -119,6 +127,7 @@ def expm( burnin: int = 0, key: Array = None, alpha: float = 1.0, + associative_scan: bool = True, ) -> Array: """ Obtain the exponential of a matrix A by @@ -134,11 +143,12 @@ def expm( key: JAX random key alpha: Regularization parameter to ensure diffusion matrix is symmetric positive definite. + associative_scan: If True, uses jax.lax.associative_scan. Returns: Approximate matrix exponential, exp(A). """ - return expnegm(-A, num_samples, dt, burnin, key, alpha) + return expnegm(-A, num_samples, dt, burnin, key, alpha, associative_scan) def autocovariance(samples: Array) -> Array: