diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..dfd8712 --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,39 @@ +name: Tests + +on: + pull_request: + branches: [main] + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + - uses: pre-commit/action@v3.0.1 + + tests: + name: Run tests for Python ${{ matrix.python-version }} + runs-on: ubuntu-latest + needs: + - pre-commit + strategy: + matrix: + python-version: ['3.10'] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dev environment + run: | + python -m pip install --upgrade pip + pip install .[test] + - name: Run the tests with pytest + run: | + python -m pytest --cov=thermox --cov-report term-missing \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index eb59414..32a55f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,14 +18,7 @@ classifiers = [ dependencies = ["jax>=0.4.0", "jaxlib>=0.4.0"] [project.optional-dependencies] -dev = [ - 'pytest-cov', - 'pytest', - 'optax', - 'mypy', - 'pre-commit', - 'ruff', -] +test = ["pre-commit", "pytest-cov", "ruff", "optax", "mypy"] [tool.setuptools] packages = ["thermox"] diff --git a/thermox/__init__.py b/thermox/__init__.py index 59122d0..64a3b84 100644 --- a/thermox/__init__.py +++ b/thermox/__init__.py @@ -1,6 +1,6 @@ from thermox import linalg from thermox.sampler import sample -from thermox.log_prob import log_prob +from thermox.prob import log_prob from thermox.utils import preprocess from thermox.utils import ProcessedDriftMatrix from thermox.utils import ProcessedDiffusionMatrix diff --git a/thermox/log_prob.py b/thermox/prob.py similarity index 84% rename from thermox/log_prob.py rename to thermox/prob.py index a9ee351..3508127 100644 --- a/thermox/log_prob.py +++ b/thermox/prob.py @@ -15,6 +15,7 @@ def log_prob_identity_diffusion( xs: Array, A: Array | ProcessedDriftMatrix, b: Array, + A_spd: bool = False, ) -> float: """Calculates log probability of samples from the Ornstein-Uhlenbeck process, defined as: @@ -32,12 +33,17 @@ def log_prob_identity_diffusion( - xs: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + - A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A. + If false uses jax.scipy.linalg.eig. + jax.linalg.eigh supports gradients but assumes A is Hermitian + (i.e. real symmetric). + See https://github.com/google/jax/issues/2748 Returns: - log probability of given xs. """ if isinstance(A, Array): - A = preprocess_drift_matrix(A) + A = preprocess_drift_matrix(A, A_spd) def expm_vp(v, dt): out = A.eigvecs_inv @ v @@ -104,9 +110,11 @@ def log_prob( - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. - D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). - - A_spd: bool, whether A is symmetric positive definite. - gradients (via jax.linalg.eigh) only supported if A is - symmetric positive definite. + - A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A. + If false uses jax.scipy.linalg.eig. + jax.linalg.eigh supports gradients but assumes A is Hermitian + (i.e. real symmetric). + See https://github.com/google/jax/issues/2748 Returns: - log probability of given xs. @@ -122,8 +130,5 @@ def log_prob( b_y = D.sqrt_inv @ b log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, b_y) - # ys = vmap(lambda x: D.sqrt_inv @ (x - b))(xs) - # log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, jnp.zeros_like(b)) - D_sqrt_inv_log_det = jnp.log(jnp.linalg.det(D.sqrt_inv)) return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1) diff --git a/thermox/sampler.py b/thermox/sampler.py index 29004da..d445773 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -17,6 +17,7 @@ def sample_identity_diffusion( x0: Array, A: Array | ProcessedDriftMatrix, b: Array, + A_spd: bool = False, ) -> Array: """Collects samples from the Ornstein-Uhlenbeck process, defined as: @@ -33,6 +34,11 @@ def sample_identity_diffusion( - x0: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + - A_spd: if true uses jax.linalg.eigh to calculate eigendecomposition of A. + If false uses jax.scipy.linalg.eig. + jax.linalg.eigh supports gradients but assumes A is Hermitian + (i.e. real symmetric). + See https://github.com/google/jax/issues/2748 Returns: - samples: array-like, desired samples. @@ -40,7 +46,7 @@ def sample_identity_diffusion( """ if isinstance(A, Array): - A = preprocess_drift_matrix(A) + A = preprocess_drift_matrix(A, A_spd) def expm_vp(v, dt): out = A.eigvecs_inv @ v