Skip to content

Commit

Permalink
Pass associative scan arg through linalg funcs (#34)
Browse files Browse the repository at this point in the history
* Pass associative scan arg through linalg funcs

* Fix style

* Add docstring for assoc scan arg
  • Loading branch information
denismelanson authored Jul 11, 2024
1 parent 3446bd4 commit 7494f72
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions thermox/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def solve(
dt: float = 1.0,
burnin: int = 0,
key: Array = None,
associative_scan: bool = True,
) -> Array:
"""
Obtain the solution of the linear system
Expand All @@ -29,6 +30,7 @@ def solve(
burnin: Time-step index corresponding to the end of the burn-in period.
Samples before this step are not collected.
key: JAX random key
associative_scan: If True, uses jax.lax.associative_scan.
Returns:
Approximate solution, x, of the linear system.
Expand All @@ -37,7 +39,9 @@ def solve(
key = random.PRNGKey(0)
ts = jnp.arange(burnin, burnin + num_samples) * dt
x0 = jnp.zeros_like(b)
samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b))
samples = sample_identity_diffusion(
key, ts, x0, A, jnp.linalg.solve(A, b), associative_scan
)
return jnp.mean(samples, axis=0)


Expand All @@ -47,6 +51,7 @@ def inv(
dt: float = 1.0,
burnin: int = 0,
key: Array = None,
associative_scan: bool = True,
) -> Array:
"""
Obtain the inverse of a matrix A by
Expand All @@ -60,6 +65,7 @@ def inv(
burnin: Time-step index corresponding to the end of the burn-in period.
Samples before this step are not collected.
key: JAX random key
associative_scan: If True, uses jax.lax.associative_scan.
Returns:
Approximate inverse of A.
Expand All @@ -69,7 +75,7 @@ def inv(
ts = jnp.arange(burnin, burnin + num_samples) * dt
b = jnp.zeros(A.shape[0])
x0 = jnp.zeros_like(b)
samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]))
samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]), associative_scan)
return jnp.cov(samples.T)


Expand All @@ -80,6 +86,7 @@ def expnegm(
burnin: int = 0,
key: Array = None,
alpha: float = 0.0,
associative_scan: bool = True,
) -> Array:
"""
Obtain the negative exponential of a matrix A by
Expand All @@ -95,6 +102,7 @@ def expnegm(
key: JAX random key
alpha: Regularization parameter to ensure diffusion matrix
is symmetric positive definite.
associative_scan: If True, uses jax.lax.associative_scan.
Returns:
Approximate negative matrix exponential, exp(-A).
Expand All @@ -108,7 +116,7 @@ def expnegm(
ts = jnp.arange(burnin, burnin + num_samples) * dt
b = jnp.zeros(A.shape[0])
x0 = jnp.zeros_like(b)
samples = sample(key, ts, x0, A_shifted, b, B)
samples = sample(key, ts, x0, A_shifted, b, B, associative_scan)
return autocovariance(samples) * jnp.exp(alpha)


Expand All @@ -119,6 +127,7 @@ def expm(
burnin: int = 0,
key: Array = None,
alpha: float = 1.0,
associative_scan: bool = True,
) -> Array:
"""
Obtain the exponential of a matrix A by
Expand All @@ -134,11 +143,12 @@ def expm(
key: JAX random key
alpha: Regularization parameter to ensure diffusion matrix
is symmetric positive definite.
associative_scan: If True, uses jax.lax.associative_scan.
Returns:
Approximate matrix exponential, exp(A).
"""
return expnegm(-A, num_samples, dt, burnin, key, alpha)
return expnegm(-A, num_samples, dt, burnin, key, alpha, associative_scan)


def autocovariance(samples: Array) -> Array:
Expand Down

0 comments on commit 7494f72

Please sign in to comment.