Skip to content

Commit

Permalink
rename function to flag internal use (#8)
Browse files Browse the repository at this point in the history
* rename function to flag internal use

* revert function name change and add `A_spd` kwarg

* rename `log_prob.py` to `prob.py`

* unify docstring for `A_spd`

* add workflow for pre-commit and test (#11)

* add workflow for pre-commit and test

* split dependencies into test and dev

* install dev depencies in the workflow

* simplify pyproject.toml
  • Loading branch information
KaelanDt authored Apr 9, 2024
1 parent 422f159 commit 91016d7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion thermox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from thermox import linalg
from thermox.sampler import sample
from thermox.log_prob import log_prob
from thermox.prob import log_prob
from thermox.utils import preprocess
from thermox.utils import ProcessedDriftMatrix
from thermox.utils import ProcessedDiffusionMatrix
19 changes: 12 additions & 7 deletions thermox/log_prob.py → thermox/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def log_prob_identity_diffusion(
xs: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
A_spd: bool = False,
) -> float:
"""Calculates log probability of samples from the Ornstein-Uhlenbeck process,
defined as:
Expand All @@ -32,12 +33,17 @@ def log_prob_identity_diffusion(
- xs: initial state of the process.
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
- b: drift displacement vector.
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
If false uses jax.scipy.linalg.eig.
jax.linalg.eigh supports gradients but assumes A is Hermitian
(i.e. real symmetric).
See https://github.com/google/jax/issues/2748
Returns:
- log probability of given xs.
"""
if isinstance(A, Array):
A = preprocess_drift_matrix(A)
A = preprocess_drift_matrix(A, A_spd)

def expm_vp(v, dt):
out = A.eigvecs_inv @ v
Expand Down Expand Up @@ -104,9 +110,11 @@ def log_prob(
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
- b: drift displacement vector.
- D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
- A_spd: bool, whether A is symmetric positive definite.
gradients (via jax.linalg.eigh) only supported if A is
symmetric positive definite.
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
If false uses jax.scipy.linalg.eig.
jax.linalg.eigh supports gradients but assumes A is Hermitian
(i.e. real symmetric).
See https://github.com/google/jax/issues/2748
Returns:
- log probability of given xs.
Expand All @@ -122,8 +130,5 @@ def log_prob(
b_y = D.sqrt_inv @ b
log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, b_y)

# ys = vmap(lambda x: D.sqrt_inv @ (x - b))(xs)
# log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, jnp.zeros_like(b))

D_sqrt_inv_log_det = jnp.log(jnp.linalg.det(D.sqrt_inv))
return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1)
8 changes: 7 additions & 1 deletion thermox/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def sample_identity_diffusion(
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
A_spd: bool = False,
) -> Array:
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:
Expand All @@ -33,14 +34,19 @@ def sample_identity_diffusion(
- x0: initial state of the process.
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
- b: drift displacement vector.
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
If false uses jax.scipy.linalg.eig.
jax.linalg.eigh supports gradients but assumes A is Hermitian
(i.e. real symmetric).
See https://github.com/google/jax/issues/2748
Returns:
- samples: array-like, desired samples.
shape: (len(ts), ) + x0.shape
"""

if isinstance(A, Array):
A = preprocess_drift_matrix(A)
A = preprocess_drift_matrix(A, A_spd)

def expm_vp(v, dt):
out = A.eigvecs_inv @ v
Expand Down

0 comments on commit 91016d7

Please sign in to comment.