Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename function to flag internal use #8

Merged
merged 5 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Tests

on:
pull_request:
branches: [main]
push:
branches: [main]

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- uses: pre-commit/[email protected]

tests:
name: Run tests for Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
needs:
- pre-commit
strategy:
matrix:
python-version: ['3.10']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dev environment
run: |
python -m pip install --upgrade pip
pip install .[test]
- name: Run the tests with pytest
run: |
python -m pytest --cov=thermox --cov-report term-missing
9 changes: 1 addition & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,7 @@ classifiers = [
dependencies = ["jax>=0.4.0", "jaxlib>=0.4.0"]

[project.optional-dependencies]
dev = [
'pytest-cov',
'pytest',
'optax',
'mypy',
'pre-commit',
'ruff',
]
test = ["pre-commit", "pytest-cov", "ruff", "optax", "mypy"]

[tool.setuptools]
packages = ["thermox"]
Expand Down
2 changes: 1 addition & 1 deletion thermox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from thermox import linalg
from thermox.sampler import sample
from thermox.log_prob import log_prob
from thermox.prob import log_prob
from thermox.utils import preprocess
from thermox.utils import ProcessedDriftMatrix
from thermox.utils import ProcessedDiffusionMatrix
19 changes: 12 additions & 7 deletions thermox/log_prob.py → thermox/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def log_prob_identity_diffusion(
xs: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
A_spd: bool = False,
) -> float:
"""Calculates log probability of samples from the Ornstein-Uhlenbeck process,
defined as:
Expand All @@ -32,12 +33,17 @@ def log_prob_identity_diffusion(
- xs: initial state of the process.
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
- b: drift displacement vector.
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
If false uses jax.scipy.linalg.eig.
jax.linalg.eigh supports gradients but assumes A is Hermitian
(i.e. real symmetric).
See https://github.com/google/jax/issues/2748

Returns:
- log probability of given xs.
"""
if isinstance(A, Array):
A = preprocess_drift_matrix(A)
A = preprocess_drift_matrix(A, A_spd)

def expm_vp(v, dt):
out = A.eigvecs_inv @ v
Expand Down Expand Up @@ -104,9 +110,11 @@ def log_prob(
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
- b: drift displacement vector.
- D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
- A_spd: bool, whether A is symmetric positive definite.
gradients (via jax.linalg.eigh) only supported if A is
symmetric positive definite.
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
If false uses jax.scipy.linalg.eig.
jax.linalg.eigh supports gradients but assumes A is Hermitian
(i.e. real symmetric).
See https://github.com/google/jax/issues/2748

Returns:
- log probability of given xs.
Expand All @@ -122,8 +130,5 @@ def log_prob(
b_y = D.sqrt_inv @ b
log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, b_y)

# ys = vmap(lambda x: D.sqrt_inv @ (x - b))(xs)
# log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, jnp.zeros_like(b))

D_sqrt_inv_log_det = jnp.log(jnp.linalg.det(D.sqrt_inv))
return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1)
8 changes: 7 additions & 1 deletion thermox/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def sample_identity_diffusion(
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
A_spd: bool = False,
) -> Array:
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:

Expand All @@ -33,14 +34,19 @@ def sample_identity_diffusion(
- x0: initial state of the process.
- A: drift matrix (Array or thermox.ProcessedDriftMatrix).
- b: drift displacement vector.
- A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A.
If false uses jax.scipy.linalg.eig.
jax.linalg.eigh supports gradients but assumes A is Hermitian
(i.e. real symmetric).
See https://github.com/google/jax/issues/2748

Returns:
- samples: array-like, desired samples.
shape: (len(ts), ) + x0.shape
"""

if isinstance(A, Array):
A = preprocess_drift_matrix(A)
A = preprocess_drift_matrix(A, A_spd)

def expm_vp(v, dt):
out = A.eigvecs_inv @ v
Expand Down