From 1fc956ca9398f55079162fff5f6dad47aa1cc0ce Mon Sep 17 00:00:00 2001 From: KaelanDt Date: Tue, 9 Apr 2024 07:13:10 +0000 Subject: [PATCH 1/5] rename function to flag internal use --- thermox/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thermox/sampler.py b/thermox/sampler.py index 29004da..37715a9 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -11,7 +11,7 @@ ) -def sample_identity_diffusion( +def _sample_identity_diffusion( key: Array, ts: Array, x0: Array, @@ -117,5 +117,5 @@ def sample( y0 = D.sqrt_inv @ x0 b_y = D.sqrt_inv @ b - ys = sample_identity_diffusion(key, ts, y0, A_y, b_y) + ys = _sample_identity_diffusion(key, ts, y0, A_y, b_y) return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys) From 39bb6068bcff3bc8a38fe3b1d7fb436a0c0d9b4a Mon Sep 17 00:00:00 2001 From: KaelanDt Date: Tue, 9 Apr 2024 08:37:05 +0000 Subject: [PATCH 2/5] revert function name change and add `A_spd` kwarg --- thermox/log_prob.py | 11 +++++++---- thermox/sampler.py | 12 +++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/thermox/log_prob.py b/thermox/log_prob.py index a9ee351..c764a63 100644 --- a/thermox/log_prob.py +++ b/thermox/log_prob.py @@ -15,6 +15,7 @@ def log_prob_identity_diffusion( xs: Array, A: Array | ProcessedDriftMatrix, b: Array, + A_spd: bool = False, ) -> float: """Calculates log probability of samples from the Ornstein-Uhlenbeck process, defined as: @@ -32,12 +33,17 @@ def log_prob_identity_diffusion( - xs: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + - A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A. + If false uses jax.scipy.linalg.eig. + jax.linalg.eigh supports gradients but assumes A is Hermitian + (i.e. real symmetric). + See https://github.com/google/jax/issues/2748 Returns: - log probability of given xs. """ if isinstance(A, Array): - A = preprocess_drift_matrix(A) + A = preprocess_drift_matrix(A, A_spd) def expm_vp(v, dt): out = A.eigvecs_inv @ v @@ -122,8 +128,5 @@ def log_prob( b_y = D.sqrt_inv @ b log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, b_y) - # ys = vmap(lambda x: D.sqrt_inv @ (x - b))(xs) - # log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, jnp.zeros_like(b)) - D_sqrt_inv_log_det = jnp.log(jnp.linalg.det(D.sqrt_inv)) return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1) diff --git a/thermox/sampler.py b/thermox/sampler.py index 37715a9..d445773 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -11,12 +11,13 @@ ) -def _sample_identity_diffusion( +def sample_identity_diffusion( key: Array, ts: Array, x0: Array, A: Array | ProcessedDriftMatrix, b: Array, + A_spd: bool = False, ) -> Array: """Collects samples from the Ornstein-Uhlenbeck process, defined as: @@ -33,6 +34,11 @@ def _sample_identity_diffusion( - x0: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + - A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A. + If false uses jax.scipy.linalg.eig. + jax.linalg.eigh supports gradients but assumes A is Hermitian + (i.e. real symmetric). + See https://github.com/google/jax/issues/2748 Returns: - samples: array-like, desired samples. @@ -40,7 +46,7 @@ def _sample_identity_diffusion( """ if isinstance(A, Array): - A = preprocess_drift_matrix(A) + A = preprocess_drift_matrix(A, A_spd) def expm_vp(v, dt): out = A.eigvecs_inv @ v @@ -117,5 +123,5 @@ def sample( y0 = D.sqrt_inv @ x0 b_y = D.sqrt_inv @ b - ys = _sample_identity_diffusion(key, ts, y0, A_y, b_y) + ys = sample_identity_diffusion(key, ts, y0, A_y, b_y) return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys) From f2adf36438626d3c111c99e66d0af67d47e62294 Mon Sep 17 00:00:00 2001 From: KaelanDt Date: Tue, 9 Apr 2024 09:19:01 +0000 Subject: [PATCH 3/5] rename `log_prob.py` to `prob.py` --- thermox/__init__.py | 2 +- thermox/{log_prob.py => prob.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename thermox/{log_prob.py => prob.py} (100%) diff --git a/thermox/__init__.py b/thermox/__init__.py index 59122d0..64a3b84 100644 --- a/thermox/__init__.py +++ b/thermox/__init__.py @@ -1,6 +1,6 @@ from thermox import linalg from thermox.sampler import sample -from thermox.log_prob import log_prob +from thermox.prob import log_prob from thermox.utils import preprocess from thermox.utils import ProcessedDriftMatrix from thermox.utils import ProcessedDiffusionMatrix diff --git a/thermox/log_prob.py b/thermox/prob.py similarity index 100% rename from thermox/log_prob.py rename to thermox/prob.py From d6accd1971fe81242254d5a2692509b3b015ca98 Mon Sep 17 00:00:00 2001 From: KaelanDt Date: Tue, 9 Apr 2024 09:20:31 +0000 Subject: [PATCH 4/5] unify docstring for `A_spd` --- thermox/prob.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/thermox/prob.py b/thermox/prob.py index c764a63..3508127 100644 --- a/thermox/prob.py +++ b/thermox/prob.py @@ -110,9 +110,11 @@ def log_prob( - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. - D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). - - A_spd: bool, whether A is symmetric positive definite. - gradients (via jax.linalg.eigh) only supported if A is - symmetric positive definite. + - A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A. + If false uses jax.scipy.linalg.eig. + jax.linalg.eigh supports gradients but assumes A is Hermitian + (i.e. real symmetric). + See https://github.com/google/jax/issues/2748 Returns: - log probability of given xs. From 6ff3748342c6a8a23707ba7cbb5f85061e116497 Mon Sep 17 00:00:00 2001 From: KaelanDt Date: Tue, 9 Apr 2024 11:49:17 +0200 Subject: [PATCH 5/5] add workflow for pre-commit and test (#11) * add workflow for pre-commit and test * split dependencies into test and dev * install dev depencies in the workflow * simplify pyproject.toml --- .github/workflows/tests.yaml | 39 ++++++++++++++++++++++++++++++++++++ pyproject.toml | 9 +-------- 2 files changed, 40 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/tests.yaml diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..dfd8712 --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,39 @@ +name: Tests + +on: + pull_request: + branches: [main] + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + - uses: pre-commit/action@v3.0.1 + + tests: + name: Run tests for Python ${{ matrix.python-version }} + runs-on: ubuntu-latest + needs: + - pre-commit + strategy: + matrix: + python-version: ['3.10'] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dev environment + run: | + python -m pip install --upgrade pip + pip install .[test] + - name: Run the tests with pytest + run: | + python -m pytest --cov=thermox --cov-report term-missing \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index eb59414..32a55f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,14 +18,7 @@ classifiers = [ dependencies = ["jax>=0.4.0", "jaxlib>=0.4.0"] [project.optional-dependencies] -dev = [ - 'pytest-cov', - 'pytest', - 'optax', - 'mypy', - 'pre-commit', - 'ruff', -] +test = ["pre-commit", "pytest-cov", "ruff", "optax", "mypy"] [tool.setuptools] packages = ["thermox"]