Skip to content

Commit 7494f72

Browse files
Pass associative scan arg through linalg funcs (#34)
* Pass associative scan arg through linalg funcs * Fix style * Add docstring for assoc scan arg
1 parent 3446bd4 commit 7494f72

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

thermox/linalg.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def solve(
1212
dt: float = 1.0,
1313
burnin: int = 0,
1414
key: Array = None,
15+
associative_scan: bool = True,
1516
) -> Array:
1617
"""
1718
Obtain the solution of the linear system
@@ -29,6 +30,7 @@ def solve(
2930
burnin: Time-step index corresponding to the end of the burn-in period.
3031
Samples before this step are not collected.
3132
key: JAX random key
33+
associative_scan: If True, uses jax.lax.associative_scan.
3234
3335
Returns:
3436
Approximate solution, x, of the linear system.
@@ -37,7 +39,9 @@ def solve(
3739
key = random.PRNGKey(0)
3840
ts = jnp.arange(burnin, burnin + num_samples) * dt
3941
x0 = jnp.zeros_like(b)
40-
samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b))
42+
samples = sample_identity_diffusion(
43+
key, ts, x0, A, jnp.linalg.solve(A, b), associative_scan
44+
)
4145
return jnp.mean(samples, axis=0)
4246

4347

@@ -47,6 +51,7 @@ def inv(
4751
dt: float = 1.0,
4852
burnin: int = 0,
4953
key: Array = None,
54+
associative_scan: bool = True,
5055
) -> Array:
5156
"""
5257
Obtain the inverse of a matrix A by
@@ -60,6 +65,7 @@ def inv(
6065
burnin: Time-step index corresponding to the end of the burn-in period.
6166
Samples before this step are not collected.
6267
key: JAX random key
68+
associative_scan: If True, uses jax.lax.associative_scan.
6369
6470
Returns:
6571
Approximate inverse of A.
@@ -69,7 +75,7 @@ def inv(
6975
ts = jnp.arange(burnin, burnin + num_samples) * dt
7076
b = jnp.zeros(A.shape[0])
7177
x0 = jnp.zeros_like(b)
72-
samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]))
78+
samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]), associative_scan)
7379
return jnp.cov(samples.T)
7480

7581

@@ -80,6 +86,7 @@ def expnegm(
8086
burnin: int = 0,
8187
key: Array = None,
8288
alpha: float = 0.0,
89+
associative_scan: bool = True,
8390
) -> Array:
8491
"""
8592
Obtain the negative exponential of a matrix A by
@@ -95,6 +102,7 @@ def expnegm(
95102
key: JAX random key
96103
alpha: Regularization parameter to ensure diffusion matrix
97104
is symmetric positive definite.
105+
associative_scan: If True, uses jax.lax.associative_scan.
98106
99107
Returns:
100108
Approximate negative matrix exponential, exp(-A).
@@ -108,7 +116,7 @@ def expnegm(
108116
ts = jnp.arange(burnin, burnin + num_samples) * dt
109117
b = jnp.zeros(A.shape[0])
110118
x0 = jnp.zeros_like(b)
111-
samples = sample(key, ts, x0, A_shifted, b, B)
119+
samples = sample(key, ts, x0, A_shifted, b, B, associative_scan)
112120
return autocovariance(samples) * jnp.exp(alpha)
113121

114122

@@ -119,6 +127,7 @@ def expm(
119127
burnin: int = 0,
120128
key: Array = None,
121129
alpha: float = 1.0,
130+
associative_scan: bool = True,
122131
) -> Array:
123132
"""
124133
Obtain the exponential of a matrix A by
@@ -134,11 +143,12 @@ def expm(
134143
key: JAX random key
135144
alpha: Regularization parameter to ensure diffusion matrix
136145
is symmetric positive definite.
146+
associative_scan: If True, uses jax.lax.associative_scan.
137147
138148
Returns:
139149
Approximate matrix exponential, exp(A).
140150
"""
141-
return expnegm(-A, num_samples, dt, burnin, key, alpha)
151+
return expnegm(-A, num_samples, dt, burnin, key, alpha, associative_scan)
142152

143153

144154
def autocovariance(samples: Array) -> Array:

0 commit comments

Comments
 (0)