Skip to content

Commit 91016d7

Browse files
authored
rename function to flag internal use (#8)
* rename function to flag internal use * revert function name change and add `A_spd` kwarg * rename `log_prob.py` to `prob.py` * unify docstring for `A_spd` * add workflow for pre-commit and test (#11) * add workflow for pre-commit and test * split dependencies into test and dev * install dev depencies in the workflow * simplify pyproject.toml
1 parent 422f159 commit 91016d7

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

thermox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from thermox import linalg
22
from thermox.sampler import sample
3-
from thermox.log_prob import log_prob
3+
from thermox.prob import log_prob
44
from thermox.utils import preprocess
55
from thermox.utils import ProcessedDriftMatrix
66
from thermox.utils import ProcessedDiffusionMatrix

thermox/log_prob.py renamed to thermox/prob.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def log_prob_identity_diffusion(
1515
xs: Array,
1616
A: Array | ProcessedDriftMatrix,
1717
b: Array,
18+
A_spd: bool = False,
1819
) -> float:
1920
"""Calculates log probability of samples from the Ornstein-Uhlenbeck process,
2021
defined as:
@@ -32,12 +33,17 @@ def log_prob_identity_diffusion(
3233
- xs: initial state of the process.
3334
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
3435
- b: drift displacement vector.
36+
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
37+
If false uses jax.scipy.linalg.eig.
38+
jax.linalg.eigh supports gradients but assumes A is Hermitian
39+
(i.e. real symmetric).
40+
See https://github.com/google/jax/issues/2748
3541
3642
Returns:
3743
- log probability of given xs.
3844
"""
3945
if isinstance(A, Array):
40-
A = preprocess_drift_matrix(A)
46+
A = preprocess_drift_matrix(A, A_spd)
4147

4248
def expm_vp(v, dt):
4349
out = A.eigvecs_inv @ v
@@ -104,9 +110,11 @@ def log_prob(
104110
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
105111
- b: drift displacement vector.
106112
- D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
107-
- A_spd: bool, whether A is symmetric positive definite.
108-
gradients (via jax.linalg.eigh) only supported if A is
109-
symmetric positive definite.
113+
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
114+
If false uses jax.scipy.linalg.eig.
115+
jax.linalg.eigh supports gradients but assumes A is Hermitian
116+
(i.e. real symmetric).
117+
See https://github.com/google/jax/issues/2748
110118
111119
Returns:
112120
- log probability of given xs.
@@ -122,8 +130,5 @@ def log_prob(
122130
b_y = D.sqrt_inv @ b
123131
log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, b_y)
124132

125-
# ys = vmap(lambda x: D.sqrt_inv @ (x - b))(xs)
126-
# log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, jnp.zeros_like(b))
127-
128133
D_sqrt_inv_log_det = jnp.log(jnp.linalg.det(D.sqrt_inv))
129134
return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1)

thermox/sampler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def sample_identity_diffusion(
1717
x0: Array,
1818
A: Array | ProcessedDriftMatrix,
1919
b: Array,
20+
A_spd: bool = False,
2021
) -> Array:
2122
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:
2223
@@ -33,14 +34,19 @@ def sample_identity_diffusion(
3334
- x0: initial state of the process.
3435
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
3536
- b: drift displacement vector.
37+
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
38+
If false uses jax.scipy.linalg.eig.
39+
jax.linalg.eigh supports gradients but assumes A is Hermitian
40+
(i.e. real symmetric).
41+
See https://github.com/google/jax/issues/2748
3642
3743
Returns:
3844
- samples: array-like, desired samples.
3945
shape: (len(ts), ) + x0.shape
4046
"""
4147

4248
if isinstance(A, Array):
43-
A = preprocess_drift_matrix(A)
49+
A = preprocess_drift_matrix(A, A_spd)
4450

4551
def expm_vp(v, dt):
4652
out = A.eigvecs_inv @ v

0 commit comments

Comments
 (0)