diff --git a/botorch/models/utils.py b/botorch/models/utils.py index e1bec1d9ed..ef55795a44 100644 --- a/botorch/models/utils.py +++ b/botorch/models/utils.py @@ -7,7 +7,7 @@ """ import warnings -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Any import torch from gpytorch.utils.broadcasting import _mul_broadcast_shape @@ -15,6 +15,16 @@ from ..exceptions import InputDataError, InputDataWarning +from .models import SingleTaskGP +from .models import HeteroskedasticSingleTaskGP +from ..sampling import IIDNormalSampler +from gpytorch.constraints import GreaterThan +from gpytorch.mlls import ExactMarginalLogLikelihood +from gpytorch.module import Module + +from gpytorch.kernels.scale_kernel import ScaleKernel +from gpytorch.kernels.rbf_kernel import RBFKernel + def _make_X_full(X: Tensor, output_indices: List[int], tf: int) -> Tensor: r"""Helper to construct input tensor with task indices. @@ -179,3 +189,125 @@ def check_standardization( if raise_on_fail: raise InputDataError(msg) warnings.warn(msg, InputDataWarning) + + +def fit_most_likely_HeteroskedasticGP( + train_X: Tensor, + train_Y: Tensor, + covar_module: Optional[Module] = None, + num_var_samples: int = 100, + max_iter: int = 10, + atol_mean: float = 1e-04, + atol_var: float = 1e-04, +) -> HeteroskedasticSingleTaskGP: + r"""Fit the Most Likely Heteroskedastic GP. + + The original algorithm is described in + http://people.csail.mit.edu/kersting/papers/kersting07icml_mlHetGP.pdf + + Args: + train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of training + features. + train_Y: A `n x m` or `batch_shape x n x m` (batch mode) tensor of + training observations. + covar_module: The covariance (kernel) matrix for the initial homoskedastic GP. + If omitted, use the RBFKernel. + num_var_samples: Number of samples to draw from posterior when estimating noise. + max_iter: Maximum number of iterations used when fitting the model. + atol_mean: The tolerance for the mean check. + atol_std: The tolerance for the var check. + Returns: + HeteroskedasticSingleTaskGP Model fit using the "most-likely" procedure. + """ + + if covar_module is None: + covar_module = ScaleKernel(RBFKernel()) + + # CANNOT CHECK RIGHT NOW BECAUSE NEED TO FIRST ADD BATCH DIMENSION + # check to see if input Tensors are normalized and standardized + # check_min_max_scaling(train_X) + # check_standardization(train_Y) + + # fit initial homoskedastic model used to estimate noise levels + homo_model = SingleTaskGP( + train_X=train_X, train_Y=train_Y, covar_module=covar_module + ) + homo_model.likelihood.noise_covar.register_constraint( + "raw_noise", GreaterThan(1e-5) + ) + homo_mll = gpytorch.mlls.ExactMarginalLogLikelihood( + homo_model.likelihood, homo_model + ) + botorch.fit.fit_gpytorch_model(homo_mll) + + # get estimates of noise + homo_mll.eval() + with torch.no_grad(): + homo_posterior = homo_mll.model.posterior(train_X.clone()) + homo_predictive_posterior = homo_mll.model.posterior( + train_X.clone(), observation_noise=True + ) + sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True) + predictive_samples = sampler(homo_predictive_posterior) + observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1, 1)) ** 2).mean( + dim=0 + ) + + # save mean and variance to check if they change later + saved_mean = homo_posterior.mean + saved_var = homo_posterior.variance + + for i in range(max_iter): + + # now train hetero model using computed noise + hetero_model = HeteroskedasticSingleTaskGP( + train_X=train_X, train_Y=train_Y, train_Yvar=observed_var + ) + hetero_mll = gpytorch.mlls.ExactMarginalLogLikelihood( + hetero_model.likelihood, hetero_model + ) + try: + botorch.fit.fit_gpytorch_model(hetero_mll) + except Exception as e: + msg = f"Fitting failed on iteration {i}. Returning the current MLL" + warnings.warn(msg, e) + return saved_hetero_mll + + hetero_mll.eval() + with torch.no_grad(): + hetero_posterior = hetero_mll.model.posterior(train_X.clone()) + hetero_predictive_posterior = hetero_mll.model.posterior( + train_X.clone(), observation_noise=True + ) + + new_mean = hetero_posterior.mean + new_var = hetero_posterior.variance + + mean_equality = torch.all( + torch.lt(torch.abs(torch.add(saved_mean, -new_mean)), atol_mean) + ) + max_change_in_means = torch.max(torch.abs(torch.add(saved_mean, -new_mean))) + + var_equality = torch.all( + torch.lt(torch.abs(torch.add(saved_var, -new_var)), atol_var) + ) + max_change_in_var = torch.max(torch.abs(torch.add(saved_var, -new_var))) + + if mean_equality and var_equality: + return hetero_mll + else: + saved_hetero_mll = hetero_mll + + saved_mean = new_mean + saved_var = new_var + + # get new noise estimate + sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True) + predictive_samples = sampler(hetero_predictive_posterior) + observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1, 1)) ** 2).mean( + dim=0 + ) + + msg = f"Did not reach convergence after {max_iter} iterations. Returning the current MLL." + warnings.warn(msg) + return hetero_mll