Skip to content

Commit

Permalink
Merge pull request #13 from normal-computing/seed
Browse files Browse the repository at this point in the history
Change seed to random key for linalg
  • Loading branch information
SamDuffield authored Apr 19, 2024
2 parents 91016d7 + 4c0c5bb commit 23d395c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
10 changes: 4 additions & 6 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,30 @@ def test_linear_system():
A = jnp.array([[3, 2], [2, 4.0]])
b = jnp.array([1, 2.0])

x = thermox.linalg.solve(A, b, num_samples=10000, dt=0.1, burnin=0, seed=0)
x = thermox.linalg.solve(A, b, num_samples=10000, dt=0.1, burnin=0)

assert jnp.allclose(A @ x, b, atol=1e-1)


def test_inv():
A = jnp.array([[3, 2], [2, 4.0]])

A_inv = thermox.linalg.inv(A, num_samples=10000, dt=0.1, burnin=0, seed=0)
A_inv = thermox.linalg.inv(A, num_samples=10000, dt=0.1, burnin=0)

assert jnp.allclose(A @ A_inv, jnp.eye(2), atol=1e-1)


def test_expnegm():
A = jnp.array([[3, 2], [2, 4.0]])

expnegA = thermox.linalg.expnegm(A, num_samples=10000, dt=0.1, burnin=0, seed=0)
expnegA = thermox.linalg.expnegm(A, num_samples=10000, dt=0.1, burnin=0)

assert jnp.allclose(expnegA, jax.scipy.linalg.expm(-A), atol=1e-1)


def test_expm():
A = jnp.array([[-0.4, 0.1], [0.5, -0.3]])

expA = thermox.linalg.expm(
A, num_samples=100000, dt=0.1, burnin=0, seed=0, alpha=1.0
)
expA = thermox.linalg.expm(A, num_samples=100000, dt=0.1, burnin=0, alpha=1.0)

assert jnp.allclose(expA, jax.scipy.linalg.expm(A), atol=1e-1)
29 changes: 16 additions & 13 deletions thermox/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def solve(
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
seed: int = 0,
key: Array = None,
) -> Array:
"""
Obtain the solution of the linear system
Expand All @@ -27,12 +27,13 @@ def solve(
- num_samples: float, number of samples to be collected.
- dt: float, time step.
- burnin: burn-in, steps before which samples are not collected.
- seed: random seed
- key: JAX random key
Returns:
- approximate solution, x, of the linear system.
"""
key = jax.random.PRNGKey(seed)
if key is None:
key = jax.random.PRNGKey(0)
ts = jnp.arange(burnin, burnin + num_samples) * dt
x0 = jnp.zeros_like(b)
samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b))
Expand All @@ -44,7 +45,7 @@ def inv(
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
seed: int = 0,
key: Array = None,
) -> Array:
"""
Obtain the inverse of a matrix A by
Expand All @@ -56,12 +57,13 @@ def inv(
- num_samples: float, number of samples to be collected.
- dt: float, time step.
- burnin: burn-in, steps before which samples are not collected.
- seed: random seed
- key: JAX random key
Returns:
- approximate inverse of A.
"""
key = jax.random.PRNGKey(seed)
if key is None:
key = jax.random.PRNGKey(0)
ts = jnp.arange(burnin, burnin + num_samples) * dt
b = jnp.zeros(A.shape[0])
x0 = jnp.zeros_like(b)
Expand All @@ -74,7 +76,7 @@ def expnegm(
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
seed: int = 0,
key: Array = None,
alpha: float = 0.0,
) -> Array:
"""
Expand All @@ -87,18 +89,19 @@ def expnegm(
- num_samples: float, number of samples to be collected.
- dt: float, time step.
- burnin: burn-in, steps before which samples are not collected.
- seed: random seed
- key: JAX random key
- alpha: float, regularization parameter to ensure diffusion matrix
is symmetric positive definite.
Returns:
- approximate negative matrix exponential, exp(-A).
"""
if key is None:
key = jax.random.PRNGKey(0)

A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt
B = A_shifted + A_shifted.T

key = jax.random.PRNGKey(seed)

ts = jnp.arange(burnin, burnin + num_samples) * dt
b = jnp.zeros(A.shape[0])
x0 = jnp.zeros_like(b)
Expand All @@ -111,7 +114,7 @@ def expm(
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
seed: int = 0,
key: Array = None,
alpha: float = 1.0,
) -> Array:
"""
Expand All @@ -124,14 +127,14 @@ def expm(
- num_samples: float, number of samples to be collected.
- dt: float, time step.
- burnin: burn-in, steps before which samples are not collected.
- seed: random seed
- key: JAX random key
- alpha: float, regularization parameter to ensure diffusion matrix
is symmetric positive definite.
Returns:
- approximate matrix exponential, exp(A).
"""
return expnegm(-A, num_samples, dt, burnin, seed, alpha)
return expnegm(-A, num_samples, dt, burnin, key, alpha)


def autocovariance(samples: Array) -> Array:
Expand Down

0 comments on commit 23d395c

Please sign in to comment.