Implementing the most-likely heteroskedastic GP described in #180 #250

Open · wants to merge 4 commits into base: main
134 changes: 133 additions & 1 deletion botorch/models/utils.py
@@ -7,14 +7,24 @@
"""

import warnings
from typing import List, Optional, Tuple

import torch
from gpytorch.constraints import GreaterThan
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.module import Module
from gpytorch.utils.broadcasting import _mul_broadcast_shape
from torch import Tensor

from ..exceptions import InputDataError, InputDataWarning
from ..sampling import IIDNormalSampler


def _make_X_full(X: Tensor, output_indices: List[int], tf: int) -> Tensor:
r"""Helper to construct input tensor with task indices.
@@ -179,3 +189,125 @@ def check_standardization(
if raise_on_fail:
raise InputDataError(msg)
warnings.warn(msg, InputDataWarning)


def fit_most_likely_HeteroskedasticGP(
    train_X: Tensor,
    train_Y: Tensor,
    covar_module: Optional[Module] = None,
    num_var_samples: int = 100,
    max_iter: int = 10,
    atol_mean: float = 1e-04,
    atol_var: float = 1e-04,
) -> ExactMarginalLogLikelihood:
r"""Fit the Most Likely Heteroskedastic GP.

The original algorithm is described in
http://people.csail.mit.edu/kersting/papers/kersting07icml_mlHetGP.pdf

Args:
train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of training
features.
train_Y: A `n x m` or `batch_shape x n x m` (batch mode) tensor of
training observations.
covar_module: The covariance (kernel) matrix for the initial homoskedastic GP.
If omitted, use the RBFKernel.
num_var_samples: Number of samples to draw from posterior when estimating noise.
max_iter: Maximum number of iterations used when fitting the model.
atol_mean: The tolerance for the mean check.
atol_std: The tolerance for the var check.
Returns:
HeteroskedasticSingleTaskGP Model fit using the "most-likely" procedure.
"""

    # Imported inside the function to avoid a circular import at module load
    # time (gp_regression itself imports helpers from this module).
    from ..fit import fit_gpytorch_model
    from .gp_regression import HeteroskedasticSingleTaskGP, SingleTaskGP

    if covar_module is None:
        covar_module = ScaleKernel(RBFKernel())

    # TODO: Check that the inputs are normalized and standardized. The checks
    # are skipped for now because they require first adding a batch dimension.
    # check_min_max_scaling(train_X)
    # check_standardization(train_Y)

# fit initial homoskedastic model used to estimate noise levels
    homo_model = SingleTaskGP(
        train_X=train_X, train_Y=train_Y, covar_module=covar_module
    )
    # bound the noise away from zero to keep the initial fit numerically stable
    homo_model.likelihood.noise_covar.register_constraint(
        "raw_noise", GreaterThan(1e-5)
    )
    homo_mll = ExactMarginalLogLikelihood(homo_model.likelihood, homo_model)
    fit_gpytorch_model(homo_mll)

# get estimates of noise
homo_mll.eval()
with torch.no_grad():
homo_posterior = homo_mll.model.posterior(train_X.clone())
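        # observation_noise=True includes the learned homoskedastic noise in
        # the variance, giving the predictive posterior over the observations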
homo_predictive_posterior = homo_mll.model.posterior(
train_X.clone(), observation_noise=True
)
    sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True)
    predictive_samples = sampler(homo_predictive_posterior)
    # Noise estimate from the paper: half the mean squared deviation of the
    # predictive samples from the observed targets (assumes a single output,
    # i.e. m = 1, which the reshape below relies on).
    observed_var = 0.5 * (
        (predictive_samples - train_Y.reshape(-1, 1)) ** 2
    ).mean(dim=0)

# save mean and variance to check if they change later
saved_mean = homo_posterior.mean
saved_var = homo_posterior.variance

for i in range(max_iter):

# now train hetero model using computed noise
hetero_model = HeteroskedasticSingleTaskGP(
train_X=train_X, train_Y=train_Y, train_Yvar=observed_var
)
        hetero_mll = ExactMarginalLogLikelihood(
            hetero_model.likelihood, hetero_model
        )
        try:
            fit_gpytorch_model(hetero_mll)
        except Exception as e:
            if i == 0:
                raise  # no previously fitted MLL to fall back on
            msg = (
                f"Fitting failed on iteration {i} ({e}). Returning the MLL "
                "from the previous iteration."
            )
            warnings.warn(msg, RuntimeWarning)
            return saved_hetero_mll

hetero_mll.eval()
with torch.no_grad():
hetero_posterior = hetero_mll.model.posterior(train_X.clone())
hetero_predictive_posterior = hetero_mll.model.posterior(
train_X.clone(), observation_noise=True
)

new_mean = hetero_posterior.mean
new_var = hetero_posterior.variance

        # stop once neither the posterior mean nor the posterior variance
        # changes by more than the given tolerance at any training point
        mean_converged = torch.all((saved_mean - new_mean).abs() < atol_mean)
        var_converged = torch.all((saved_var - new_var).abs() < atol_var)

        if mean_converged and var_converged:
            return hetero_mll
        else:
            saved_hetero_mll = hetero_mll

saved_mean = new_mean
saved_var = new_var

# get new noise estimate
sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True)
predictive_samples = sampler(hetero_predictive_posterior)
observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1, 1)) ** 2).mean(
dim=0
)

msg = f"Did not reach convergence after {max_iter} iterations. Returning the current MLL."
warnings.warn(msg)
return hetero_mll
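For reviewers who want to exercise the new helper, here is a minimal usage sketch on a toy 1-D problem whose noise grows with the input. The function name, its defaults, and the MLL return type come from the diff above; the toy data and the final posterior query are illustrative assumptions, not part of this PR.

import torch

from botorch.models.utils import fit_most_likely_HeteroskedasticGP

# toy 1-D data with input-dependent noise (illustrative assumption)
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 50).unsqueeze(-1)
noise_sd = 0.05 + 0.5 * train_X  # noise standard deviation grows with x
train_Y = torch.sin(6 * train_X) + noise_sd * torch.randn_like(train_X)

# returns the ExactMarginalLogLikelihood of the fitted
# HeteroskedasticSingleTaskGP; the model itself is hetero_mll.model
hetero_mll = fit_most_likely_HeteroskedasticGP(train_X=train_X, train_Y=train_Y)
model = hetero_mll.model

# query the posterior, including the learned input-dependent noise
test_X = torch.linspace(0, 1, 101).unsqueeze(-1)
with torch.no_grad():
    posterior = model.posterior(test_X, observation_noise=True)
    mean, variance = posterior.mean, posterior.variance

Since the helper returns the MLL rather than the model (matching its fallback warnings), callers retrieve the fitted model via hetero_mll.model.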