diff --git a/tests/test_conditional.py b/tests/test_conditional.py new file mode 100644 index 0000000..07fba84 --- /dev/null +++ b/tests/test_conditional.py @@ -0,0 +1,32 @@ +import jax +from jax import numpy as jnp + +import thermox + + +def test_mean_and_cov(): + jax.config.update("jax_enable_x64", True) + dim = 2 + t = 1.0 + + A = jnp.array([[3, 2.5], [2, 4.0]]) + b = jax.random.normal(jax.random.PRNGKey(1), (dim,)) + x0 = jax.random.normal(jax.random.PRNGKey(2), (dim,)) + D = 2 * jnp.eye(dim) + + mean = thermox.conditional.mean(t, x0, A, b, D) + samples = jax.vmap( + lambda k: thermox.sample(k, jnp.array([0.0, t]), x0, A, b, D)[-1] + )(jax.random.split(jax.random.PRNGKey(0), 1000000)) + assert mean.shape == (dim,) + assert jnp.allclose(mean, jnp.mean(samples, axis=0), atol=1e-2) + + cov = thermox.conditional.covariance(t, A, D) + assert cov.shape == (dim, dim) + assert jnp.allclose(cov, jnp.cov(samples.T), atol=1e-3) + + mean_and_cov = thermox.conditional.mean_and_covariance(t, x0, A, b, D) + assert mean_and_cov[0].shape == (dim,) + assert mean_and_cov[1].shape == (dim, dim) + assert jnp.allclose(mean_and_cov[0], mean, atol=1e-5) + assert jnp.allclose(mean_and_cov[1], cov, atol=1e-5) diff --git a/tests/test_log_prob.py b/tests/test_log_prob.py index f8508d8..03b9747 100644 --- a/tests/test_log_prob.py +++ b/tests/test_log_prob.py @@ -91,33 +91,31 @@ def test_MLE(): D_true = jnp.array([[1, 0.3, -0.1], [0.3, 1, 0.2], [-0.1, 0.2, 1.0]]) nts = 300 - ts = jnp.linspace(0, 10, nts) + ts = jnp.linspace(0, 100, nts) x0 = jnp.zeros_like(b_true) - n_trajecs = 3 + n_trajecs = 5 rks = jax.random.split(jax.random.PRNGKey(0), n_trajecs) samps = jax.vmap(lambda key: thermox.sample(key, ts, x0, A_true, b_true, D_true))( rks ) - A_sqrt_init = jnp.tril(jnp.eye(3) + jax.random.normal(rks[0], (3, 3)) * 1e-1) + A_init = jnp.eye(3) + jax.random.normal(rks[0], (3, 3)) * 1e-1 b_init = jnp.zeros(3) D_sqrt_init = jnp.eye(3) log_prob_true = thermox.log_prob(ts, samps[0], A_true, b_true, D_true) log_prob_init = thermox.log_prob( - ts, samps[0], A_sqrt_init @ A_sqrt_init.T, b_init, D_sqrt_init @ D_sqrt_init.T + ts, samps[0], A_init, b_init, D_sqrt_init @ D_sqrt_init.T ) assert log_prob_true > log_prob_init # Gradient descent def loss(params): - A_sqrt, b, D_sqrt = params - A_sqrt = jnp.tril(A_sqrt) + A, b, D_sqrt = params D_sqrt = jnp.tril(D_sqrt) - A = A_sqrt @ A_sqrt.T D = D_sqrt @ D_sqrt.T return -jax.vmap(lambda s: thermox.log_prob(ts, s, A, b, D))( samps @@ -125,8 +123,8 @@ def loss(params): val_and_g = jax.jit(jax.value_and_grad(loss)) - ps = (A_sqrt_init, b_init, D_sqrt_init) - ps_true = (jnp.linalg.cholesky(A_true), b_true, jnp.linalg.cholesky(D_true)) + ps = (A_init, b_init, D_sqrt_init) + ps_true = (A_true, b_true, jnp.linalg.cholesky(D_true)) v, g = val_and_g(ps) v_true, g_true = val_and_g(ps_true) @@ -138,7 +136,7 @@ def loss(params): n_steps = 20000 neg_log_probs = jnp.zeros(n_steps) - optimizer = optax.adam(1e-2) + optimizer = optax.adam(1e-3) opt_state = optimizer.init(ps) for i in range(n_steps): @@ -149,7 +147,7 @@ def loss(params): ps = optax.apply_updates(ps, updates) neg_log_probs = neg_log_probs.at[i].set(neg_log_prob) - A_recover = ps[0] @ ps[0].T + A_recover = ps[0] b_recover = ps[1] D_recover = ps[2] @ ps[2].T diff --git a/thermox/__init__.py b/thermox/__init__.py index 64a3b84..175d672 100644 --- a/thermox/__init__.py +++ b/thermox/__init__.py @@ -1,4 +1,5 @@ from thermox import linalg +from thermox import conditional from thermox.sampler import sample from thermox.prob import log_prob from thermox.utils import preprocess diff --git a/thermox/conditional.py b/thermox/conditional.py new file mode 100644 index 0000000..330a2c0 --- /dev/null +++ b/thermox/conditional.py @@ -0,0 +1,98 @@ +from jax import numpy as jnp +from jax import Array + +from thermox.utils import ( + ProcessedDriftMatrix, + ProcessedDiffusionMatrix, + handle_matrix_inputs, +) +from thermox.sampler import expm_vp + + +def mean( + t: float, + x0: Array, + A: Array | ProcessedDriftMatrix, + b: Array, + D: Array | ProcessedDiffusionMatrix, +) -> Array: + """Computes the mean of p(x_t | x_0) + + For x_t evolving according to the SDE: + + dx = - A * (x - b) dt + sqrt(D) dW + + Args: + ts: Times at which samples are collected. Includes time for x0. + x0: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + Note: If a thermox.ProcessedDriftMatrix instance is used as input, + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils.preprocess_drift_matrix. + b: Drift displacement vector. + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + + """ + A_y, D = handle_matrix_inputs(A, D) + + y0 = D.sqrt_inv @ (x0 - b) + return b + D.sqrt @ expm_vp(A_y, y0, t) + + +def covariance( + t: float, + A: Array | ProcessedDriftMatrix, + D: Array | ProcessedDiffusionMatrix, +) -> Array: + """Computes the covariance of p(x_t | x_0) + + For x evolving according to the SDE: + + dx = - A * (x - b) dt + sqrt(D) dW + + Args: + ts: Times at which samples are collected. Includes time for x0. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + Note: If a thermox.ProcessedDriftMatrix instance is used as input, + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils.preprocess_drift_matrix. + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + """ + A_y, D = handle_matrix_inputs(A, D) + + identity_diffusion_cov = ( + A_y.sym_eigvecs + @ jnp.diag((1 - jnp.exp(-2 * A_y.sym_eigvals * t)) / (2 * A_y.sym_eigvals)) + @ A_y.sym_eigvecs.T + ) + return D.sqrt @ identity_diffusion_cov @ D.sqrt.T + + +def mean_and_covariance( + t: float, + x0: Array, + A: Array | ProcessedDriftMatrix, + b: Array, + D: Array | ProcessedDiffusionMatrix, +) -> tuple[Array, Array]: + """Computes the mean and covariance of p(x_t | x_0) + + For x evolving according to the SDE: + + dx = - A * (x - b) dt + sqrt(D) dW + + Args: + ts: Times at which samples are collected. Includes time for x0. + x0: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + Note: If a thermox.ProcessedDriftMatrix instance is used as input, + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils.preprocess_drift_matrix. + b: Drift displacement vector. + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + + """ + A, D = handle_matrix_inputs(A, D) + mean_val = mean(t, x0, A, b, D) + covariance_val = covariance(t, A, D) + return mean_val, covariance_val diff --git a/thermox/prob.py b/thermox/prob.py index 5fca143..b0364e9 100644 --- a/thermox/prob.py +++ b/thermox/prob.py @@ -36,7 +36,7 @@ def log_prob( Args: ts: Times at which samples are collected. Includes time for x0. - xs: Initial state of the process. + xs: States of the process. A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note: If a thermox.ProcessedDriftMatrix instance is used as input, must be the transformed drift matrix, A_y, given by thermox.preprocess,