Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Cholesky decomp instead of inverting kernel #1688

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,28 @@ def run_inference(model, args, rng_key, X, Y):


# do GP prediction for a given set of hyperparameters. this makes use of the well-known
# formula for gaussian process predictions
def predict(rng_key, X, Y, X_test, var, length, noise):
# formula for Gaussian process predictions
def predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True):
# compute kernels between train and test data, etc.
k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
k_XX = kernel(X, X, var, length, noise, include_noise=True)
K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)
K = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))

# since K_xx is symmetric positive-definite, we can use the more efficient and
# stable Cholesky decomposition instead of matrix inversion
if use_cholesky:
K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)
K = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))
mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, Y))
else:
K_xx_inv = jnp.linalg.inv(k_XX)
K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))

sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
rng_key, X_test.shape[:1]
)
mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, Y))

# we return both the mean function and a sample from the posterior predictive for the
# given set of hyperparameters
return mean, mean + sigma_noise
Expand Down Expand Up @@ -148,7 +158,7 @@ def main(args):
)
means, predictions = vmap(
lambda rng_key, var, length, noise: predict(
rng_key, X, Y, X_test, var, length, noise
rng_key, X, Y, X_test, var, length, noise, use_cholesky=args.use_cholesky
)
)(*vmap_args)

Expand Down Expand Up @@ -184,6 +194,7 @@ def main(args):
type=str,
choices=["median", "feasible", "value", "uniform", "sample"],
)
parser.add_argument("--no_cholesky", dest="use_cholesky", action="store_false")
DanWaxman marked this conversation as resolved.
Show resolved Hide resolved
args = parser.parse_args()

numpyro.set_platform(args.device)
Expand Down