From e44280e32f64de9c8deed8ca989c258410ac423d Mon Sep 17 00:00:00 2001
From: Carl Hvarfner
Date: Thu, 1 Aug 2024 14:23:09 -0700
Subject: [PATCH] Improvement of qBayesianActiveLearningByDisagreement (#2457)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2457

Improvement of the implementation of qBayesianActiveLearningByDisagreement:
- Utilizes a Monte Carlo approach for approximating the entropy
- Does not use concatenate_pending_points, as it is not evident that
  fantasizing makes sense in the same way as for standard MC acquisition
  functions
- Can accept posterior transforms
- get_model and get_fully_bayesian_model are used in tests to be consistent
  with other tests (e.g., JES and the subsequent active learning acqfs),
  enabling a move to test_helpers

Reviewed By: saitcakmak

Differential Revision: D60308502

fbshipit-source-id: 6de1dffc4f497ef4823428b2903b19ff8f0d60d7
---
 .../acquisition/bayesian_active_learning.py | 77 ++++++++++------
 botorch/acquisition/input_constructors.py   |  4 +
 .../test_bayesian_active_learning.py        | 88 +++++++++++++------
 3 files changed, 117 insertions(+), 52 deletions(-)

diff --git a/botorch/acquisition/bayesian_active_learning.py b/botorch/acquisition/bayesian_active_learning.py
index b4f06fc24c..a4608ab702 100644
--- a/botorch/acquisition/bayesian_active_learning.py
+++ b/botorch/acquisition/bayesian_active_learning.py
@@ -22,10 +22,11 @@
 
 from typing import Optional
 
-import torch
 from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.acquisition.objective import PosteriorTransform
+from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
 from botorch.models.model import Model
+from botorch.sampling.base import MCSampler
 from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
 from torch import Tensor
 
@@ -54,24 +55,37 @@ class qBayesianActiveLearningByDisagreement(
     def __init__(
         self,
         model: SaasFullyBayesianSingleTaskGP,
+        sampler: Optional[MCSampler] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
         X_pending: Optional[Tensor] = None,
     ) -> None:
         """
         Batch implementation [kirsch2019batchbald]_ of BALD [Houlsby2011bald]_,
         which maximizes the mutual information between the next observation and the
-        hyperparameters of the model. Computed by informational lower bound.
+        hyperparameters of the model. Computed by Monte Carlo integration.
 
         Args:
-            model: A fully bayesian single-outcome model.
-            X_pending: A `batch_shape, m x d`-dim Tensor of `m` design points.
+            model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
+            sampler: The sampler used for drawing samples to approximate the entropy
+                of the Gaussian Mixture posterior.
+            posterior_transform: A PosteriorTransform. If using a multi-output model,
+                a PosteriorTransform that transforms the multi-output posterior into a
+                single-output posterior is required.
+            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points.
+
         """
-        super().__init__(model)
+        super().__init__(model=model)
+        MCSamplerMixin.__init__(self, sampler=sampler)
         self.set_X_pending(X_pending)
+        self.posterior_transform = posterior_transform
 
     @concatenate_pending_points
     @t_batch_mode_transform()
     def forward(self, X: Tensor) -> Tensor:
         r"""Evaluate qBayesianActiveLearningByDisagreement on the candidate set `X`.
+        A Monte Carlo-estimated information gain is computed over a Gaussian Mixture
+        marginal posterior and the Gaussian conditional posteriors to obtain the
+        qBayesianActiveLearningByDisagreement value on the candidate set `X`.
 
         Args:
             X: `batch_shape x q x D`-dim Tensor of input points.
 
         Returns:
             A `batch_shape x num_models`-dim Tensor of BALD values.
         """
-        return self._compute_lower_bound_information_gain(X)
-
-    def _compute_lower_bound_information_gain(self, X: Tensor) -> Tensor:
-        r"""Evaluates the lower bounded information gain on the candidate set `X`.
-
-        Args:
-            X: `batch_shape x q x D`-dim Tensor of input points.
-
-        Returns:
-            A `batch_shape x num_models`-dim Tensor of information gains.
-        """
-        posterior = self.model.posterior(X, observation_noise=True)
-        marg_covar = posterior.mixture_covariance_matrix
-        cond_variances = posterior.variance
-
-        prev_entropy = torch.logdet(marg_covar).unsqueeze(-1)
-        # squeeze excess dim and mean over q-batch
-        post_ub_entropy = torch.log(cond_variances).squeeze(-1).mean(-1)
-
-        return prev_entropy - post_ub_entropy
+        posterior = self.model.posterior(
+            X, observation_noise=True, posterior_transform=self.posterior_transform
+        )
+        # Draw samples from the mixture posterior.
+        # samples: num_samples x batch_shape x num_models x q x num_outputs
+        samples = self.get_posterior_samples(posterior=posterior)
+
+        # Estimate the entropy of 'num_samples' samples from 'num_models' models by
+        # evaluating the log_prob of each sample under the mixture posterior
+        # (which consists of M models). Thus, order N*M^2 computations.
+
+        # Make room and move the model dim to the front, squeeze the num_outputs dim.
+        # prev_samples: num_models x num_samples x batch_shape x 1 x q
+        prev_samples = samples.unsqueeze(0).transpose(0, MCMC_DIM).squeeze(-1)
+
+        # Average the probs over models in the mixture - dim (-2) will be broadcast
+        # with the num_models of the posterior --> querying all samples on all models.
+        # posterior.mvn takes q-dimensional input by default, which removes the q-dim.
+        # component_sample_probs: num_models x num_samples x batch_shape x num_models
+        component_sample_probs = posterior.mvn.log_prob(prev_samples).exp()
+
+        # average over mixture components
+        mixture_sample_probs = component_sample_probs.mean(dim=-1)
+
+        # this is the average over the model and sample dim
+        prev_entropy = -mixture_sample_probs.log().mean(dim=[0, 1])
+
+        # The posterior entropy is an average entropy over Gaussians, so no mixture.
+        post_entropy = -posterior.mvn.log_prob(samples.squeeze(-1)).mean(0)
+        bald = prev_entropy.unsqueeze(-1) - post_entropy
+        return bald
diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py
index c6ed072aac..dcb8abda40 100644
--- a/botorch/acquisition/input_constructors.py
+++ b/botorch/acquisition/input_constructors.py
@@ -1678,9 +1678,13 @@ def construct_inputs_qJES(
 def construct_inputs_BALD(
     model: Model,
     X_pending: Optional[Tensor] = None,
+    sampler: Optional[MCSampler] = None,
+    posterior_transform: Optional[PosteriorTransform] = None,
 ):
     inputs = {
         "model": model,
         "X_pending": X_pending,
+        "sampler": sampler,
+        "posterior_transform": posterior_transform,
     }
     return inputs
diff --git a/test/acquisition/test_bayesian_active_learning.py b/test/acquisition/test_bayesian_active_learning.py
index bff8f61fc1..a4102f8e31 100644
--- a/test/acquisition/test_bayesian_active_learning.py
+++ b/test/acquisition/test_bayesian_active_learning.py
@@ -13,9 +13,32 @@
 from botorch.models import SingleTaskGP
 from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
 from botorch.models.transforms.outcome import Standardize
+from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import BotorchTestCase
 
 
+def get_model(
+    train_X,
+    train_Y,
+    standardize_model,
+    **tkwargs,
+):
+    num_objectives = train_Y.shape[-1]
+
+    if standardize_model:
+        outcome_transform = Standardize(m=num_objectives)
+    else:
+        outcome_transform = None
+
+    model = SingleTaskGP(
+        train_X=train_X,
+        train_Y=train_Y,
+        outcome_transform=outcome_transform,
+    )
+
+    return model
+
+
 def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
 
     mcmc_samples = {
@@ -28,7 +51,7 @@ def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
     return mcmc_samples
 
 
-def get_model(
+def get_fully_bayesian_model(
     train_X,
     train_Y,
     num_models,
@@ -72,21 +95,26 @@ def test_q_bayesian_active_learning_by_disagreement(self):
         tkwargs = {"device": self.device}
         num_objectives = 1
         num_models = 3
+        input_dim = 2
+
+        X_pending_list = [None, torch.rand(2, input_dim)]
 
         for (
             dtype,
             standardize_model,
             infer_noise,
+            X_pending,
         ) in product(
             (torch.float, torch.double),
             (False, True),  # standardize_model
             (True,),  # infer_noise - only one option avail in PyroModels
+            X_pending_list,
         ):
+            X_pending = X_pending.to(**tkwargs) if X_pending is not None else None
             tkwargs["dtype"] = dtype
-            input_dim = 2
             train_X = torch.rand(4, input_dim, **tkwargs)
             train_Y = torch.rand(4, num_objectives, **tkwargs)
-            model = get_model(
+            model = get_fully_bayesian_model(
                 train_X,
                 train_Y,
                 num_models,
@@ -96,32 +124,40 @@ def test_q_bayesian_active_learning_by_disagreement(self):
             )
 
             # test acquisition
-            X_pending_list = [None, torch.rand(2, input_dim, **tkwargs)]
-            for i in range(len(X_pending_list)):
-                X_pending = X_pending_list[i]
-
-                acq = qBayesianActiveLearningByDisagreement(
-                    model=model,
-                    X_pending=X_pending,
-                )
-
-                test_Xs = [
-                    torch.rand(4, 1, input_dim, **tkwargs),
-                    torch.rand(4, 3, input_dim, **tkwargs),
-                    torch.rand(4, 5, 1, input_dim, **tkwargs),
-                    torch.rand(4, 5, 3, input_dim, **tkwargs),
-                ]
-
-                for j in range(len(test_Xs)):
-                    acq_X = acq.forward(test_Xs[j])
-                    acq_X = acq(test_Xs[j])
-                    # assess shape
-                    self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])
+            acq = qBayesianActiveLearningByDisagreement(
+                model=model,
+                X_pending=X_pending,
+            )
+
+            acq2 = qBayesianActiveLearningByDisagreement(
+                model=model, sampler=IIDNormalSampler(torch.Size([9]))
+            )
+            self.assertIsInstance(acq2.sampler, IIDNormalSampler)
+
+            test_Xs = [
+                torch.rand(4, 1, input_dim, **tkwargs),
+                torch.rand(4, 3, input_dim, **tkwargs),
+                torch.rand(4, 5, 1, input_dim, **tkwargs),
+                torch.rand(4, 5, 3, input_dim, **tkwargs),
+                torch.rand(5, 13, input_dim, **tkwargs),
+            ]
+
+            for j in range(len(test_Xs)):
+                acq_X = acq.forward(test_Xs[j])
+                acq_X = acq(test_Xs[j])
+                # assess shape
+                self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])
+
+                self.assertTrue(torch.all(acq_X > 0))
 
         # Support with non-fully bayesian models is not possible. Thus, we
         # throw an error.
-        non_fully_bayesian_model = SingleTaskGP(train_X, train_Y)
-        with self.assertRaises(ValueError):
+        non_fully_bayesian_model = get_model(train_X, train_Y, False)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Fully Bayesian acquisition functions require a "
+            "SaasFullyBayesianSingleTaskGP to run.",
+        ):
             acq = qBayesianActiveLearningByDisagreement(
                 model=non_fully_bayesian_model,
            )
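
Note on the estimator: the new forward pass computes BALD as the Monte Carlo entropy of the Gaussian Mixture (marginal) posterior minus the average entropy of the per-model (conditional) posteriors, with every sample scored under every mixture component (the N*M^2 log_prob evaluations mentioned in the comments). A minimal self-contained sketch of the same computation on a toy one-dimensional Gaussian mixture follows; the component means/scales and the sample count are illustrative assumptions, not values from this patch:

import torch
from torch.distributions import Normal

num_samples = 512

# Toy per-model (conditional) posteriors at one candidate point: N(mu_m, sigma_m^2).
mu = torch.tensor([0.0, 0.5, 1.0])
sigma = torch.tensor([1.0, 0.8, 1.2])
components = Normal(mu, sigma)  # batch of num_models = 3 Gaussians

# Draw num_samples draws from each component: num_samples x num_models.
samples = components.sample(torch.Size([num_samples]))

# Score every sample under every component (N*M^2 evaluations), then average
# component probabilities to get equal-weight mixture probabilities.
component_probs = components.log_prob(samples.unsqueeze(-1)).exp()  # N x M x M
mixture_probs = component_probs.mean(dim=-1)  # N x M

# MC estimate of the mixture ("prev") entropy and the average conditional entropy.
prev_entropy = -mixture_probs.log().mean()
post_entropy = -components.log_prob(samples).mean()

bald = prev_entropy - post_entropy  # mutual information estimate, nonnegative in expectation
print(bald)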
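
For reviewers who want to exercise the new sampler and posterior_transform arguments end to end, here is a hedged usage sketch; the toy data, fitting settings, and optimization parameters are illustrative assumptions, not part of this diff:

import torch
from botorch.acquisition.bayesian_active_learning import (
    qBayesianActiveLearningByDisagreement,
)
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.optim import optimize_acqf
from botorch.sampling.normal import IIDNormalSampler

# Toy single-output data; sizes are illustrative.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 1, dtype=torch.double)

# Fit a fully Bayesian GP (an ensemble over MCMC hyperparameter draws).
model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)

# BALD with an explicit sampler controlling the Monte Carlo entropy estimate.
acqf = qBayesianActiveLearningByDisagreement(
    model=model, sampler=IIDNormalSampler(sample_shape=torch.Size([64]))
)

candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=1,
    num_restarts=8,
    raw_samples=128,
)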