Skip to content

Commit 23d395c

Browse files
authored
Merge pull request #13 from normal-computing/seed
Change seed to random key for linalg
2 parents 91016d7 + 4c0c5bb commit 23d395c

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

tests/test_linalg.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,30 @@ def test_linear_system():
88
A = jnp.array([[3, 2], [2, 4.0]])
99
b = jnp.array([1, 2.0])
1010

11-
x = thermox.linalg.solve(A, b, num_samples=10000, dt=0.1, burnin=0, seed=0)
11+
x = thermox.linalg.solve(A, b, num_samples=10000, dt=0.1, burnin=0)
1212

1313
assert jnp.allclose(A @ x, b, atol=1e-1)
1414

1515

1616
def test_inv():
1717
A = jnp.array([[3, 2], [2, 4.0]])
1818

19-
A_inv = thermox.linalg.inv(A, num_samples=10000, dt=0.1, burnin=0, seed=0)
19+
A_inv = thermox.linalg.inv(A, num_samples=10000, dt=0.1, burnin=0)
2020

2121
assert jnp.allclose(A @ A_inv, jnp.eye(2), atol=1e-1)
2222

2323

2424
def test_expnegm():
2525
A = jnp.array([[3, 2], [2, 4.0]])
2626

27-
expnegA = thermox.linalg.expnegm(A, num_samples=10000, dt=0.1, burnin=0, seed=0)
27+
expnegA = thermox.linalg.expnegm(A, num_samples=10000, dt=0.1, burnin=0)
2828

2929
assert jnp.allclose(expnegA, jax.scipy.linalg.expm(-A), atol=1e-1)
3030

3131

3232
def test_expm():
3333
A = jnp.array([[-0.4, 0.1], [0.5, -0.3]])
3434

35-
expA = thermox.linalg.expm(
36-
A, num_samples=100000, dt=0.1, burnin=0, seed=0, alpha=1.0
37-
)
35+
expA = thermox.linalg.expm(A, num_samples=100000, dt=0.1, burnin=0, alpha=1.0)
3836

3937
assert jnp.allclose(expA, jax.scipy.linalg.expm(A), atol=1e-1)

thermox/linalg.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def solve(
1111
num_samples: int = 10000,
1212
dt: float = 1.0,
1313
burnin: int = 0,
14-
seed: int = 0,
14+
key: Array = None,
1515
) -> Array:
1616
"""
1717
Obtain the solution of the linear system
@@ -27,12 +27,13 @@ def solve(
2727
- num_samples: float, number of samples to be collected.
2828
- dt: float, time step.
2929
- burnin: burn-in, steps before which samples are not collected.
30-
- seed: random seed
30+
- key: JAX random key
3131
3232
Returns:
3333
- approximate solution, x, of the linear system.
3434
"""
35-
key = jax.random.PRNGKey(seed)
35+
if key is None:
36+
key = jax.random.PRNGKey(0)
3637
ts = jnp.arange(burnin, burnin + num_samples) * dt
3738
x0 = jnp.zeros_like(b)
3839
samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b))
@@ -44,7 +45,7 @@ def inv(
4445
num_samples: int = 10000,
4546
dt: float = 1.0,
4647
burnin: int = 0,
47-
seed: int = 0,
48+
key: Array = None,
4849
) -> Array:
4950
"""
5051
Obtain the inverse of a matrix A by
@@ -56,12 +57,13 @@ def inv(
5657
- num_samples: float, number of samples to be collected.
5758
- dt: float, time step.
5859
- burnin: burn-in, steps before which samples are not collected.
59-
- seed: random seed
60+
- key: JAX random key
6061
6162
Returns:
6263
- approximate inverse of A.
6364
"""
64-
key = jax.random.PRNGKey(seed)
65+
if key is None:
66+
key = jax.random.PRNGKey(0)
6567
ts = jnp.arange(burnin, burnin + num_samples) * dt
6668
b = jnp.zeros(A.shape[0])
6769
x0 = jnp.zeros_like(b)
@@ -74,7 +76,7 @@ def expnegm(
7476
num_samples: int = 10000,
7577
dt: float = 1.0,
7678
burnin: int = 0,
77-
seed: int = 0,
79+
key: Array = None,
7880
alpha: float = 0.0,
7981
) -> Array:
8082
"""
@@ -87,18 +89,19 @@ def expnegm(
8789
- num_samples: float, number of samples to be collected.
8890
- dt: float, time step.
8991
- burnin: burn-in, steps before which samples are not collected.
90-
- seed: random seed
92+
- key: JAX random key
9193
- alpha: float, regularization parameter to ensure diffusion matrix
9294
is symmetric positive definite.
9395
9496
Returns:
9597
- approximate negative matrix exponential, exp(-A).
9698
"""
99+
if key is None:
100+
key = jax.random.PRNGKey(0)
101+
97102
A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt
98103
B = A_shifted + A_shifted.T
99104

100-
key = jax.random.PRNGKey(seed)
101-
102105
ts = jnp.arange(burnin, burnin + num_samples) * dt
103106
b = jnp.zeros(A.shape[0])
104107
x0 = jnp.zeros_like(b)
@@ -111,7 +114,7 @@ def expm(
111114
num_samples: int = 10000,
112115
dt: float = 1.0,
113116
burnin: int = 0,
114-
seed: int = 0,
117+
key: Array = None,
115118
alpha: float = 1.0,
116119
) -> Array:
117120
"""
@@ -124,14 +127,14 @@ def expm(
124127
- num_samples: float, number of samples to be collected.
125128
- dt: float, time step.
126129
- burnin: burn-in, steps before which samples are not collected.
127-
- seed: random seed
130+
- key: JAX random key
128131
- alpha: float, regularization parameter to ensure diffusion matrix
129132
is symmetric positive definite.
130133
131134
Returns:
132135
- approximate matrix exponential, exp(A).
133136
"""
134-
return expnegm(-A, num_samples, dt, burnin, seed, alpha)
137+
return expnegm(-A, num_samples, dt, burnin, key, alpha)
135138

136139

137140
def autocovariance(samples: Array) -> Array:

0 commit comments

Comments
 (0)