From a8efd76b43c19982d5d0b79f14912b27094b5d93 Mon Sep 17 00:00:00 2001
From: Johannes P. Dürholt
Date: Sat, 11 Feb 2023 13:24:09 -0800
Subject: [PATCH] Ensemble Posterior (#1636)

Summary:
## Motivation

As discussed in https://github.com/pytorch/botorch/issues/1064, this is an attempt to add an `EnsemblePosterior` to botorch that could be used, for example, by NN ensembles.

I have problems with implementing `rsample` properly. I think my current implementation is not correct: it is based on `DeterministicPosterior`, but I think we should sample directly from the individual predictions of the ensemble. However, I do not know how to interpret `sample_shape` in this context.

As the sampler, I registered `StochasticSampler` for the new posterior class, but I am not sure whether this is correct either.

Furthermore, I have another question regarding `StochasticSampler`. Its docstring states that it should not be used in combination with `optimize_acqf`, yet `StochasticSampler` is assigned to the `DeterministicPosterior`. Does this mean that one cannot use a `ModelList` consisting of a `DeterministicModel` and GPs in combination with `optimize_acqf`?

Balandat: any suggestions on this?

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: https://github.com/pytorch/botorch/pull/1636

Test Plan: Unit tests. Not yet implemented/finished, as this is still WIP.

Reviewed By: saitcakmak

Differential Revision: D43017184

Pulled By: Balandat

fbshipit-source-id: fd2ede2dbba82a40c466f8a178138ced0fcba5fe
---
 botorch/models/deterministic.py       |  62 ++---------
 botorch/models/ensemble.py            |  93 +++++++++++++++++
 botorch/posteriors/deterministic.py   |  11 +-
 botorch/posteriors/ensemble.py        | 141 ++++++++++++++++++++++++++
 botorch/sampling/deterministic.py     |   3 +
 botorch/sampling/get_sampler.py       |  12 ++-
 botorch/sampling/index_sampler.py     |  64 ++++++++++++
 sphinx/source/models.rst              |   5 +
 sphinx/source/posteriors.rst          |   5 +
 sphinx/source/sampling.rst            |   5 +
 test/models/test_deterministic.py     |   5 +-
 test/models/test_ensemble.py          |  37 +++++++
 test/posteriors/test_deterministic.py |   7 ++
 test/posteriors/test_ensemble.py      |  89 ++++++++++++++++
 test/sampling/test_index_sampler.py   |  40 ++++++++
 15 files changed, 520 insertions(+), 59 deletions(-)
 create mode 100644 botorch/models/ensemble.py
 create mode 100644 botorch/posteriors/ensemble.py
 create mode 100644 botorch/sampling/index_sampler.py
 create mode 100644 test/models/test_ensemble.py
 create mode 100644 test/posteriors/test_ensemble.py
 create mode 100644 test/sampling/test_index_sampler.py

diff --git a/botorch/models/deterministic.py b/botorch/models/deterministic.py
index fce72b4996..66bdb23649 100644
--- a/botorch/models/deterministic.py
+++ b/botorch/models/deterministic.py
@@ -26,18 +26,16 @@
 from __future__ import annotations
-from abc import ABC, abstractmethod
-from typing import Any, Callable, List, Optional, Union
+from abc import abstractmethod
+from typing import Callable, List, Optional, Union
 import torch
-from botorch.acquisition.objective import PosteriorTransform
-from botorch.exceptions.errors import UnsupportedError
+from botorch.models.ensemble import EnsembleModel
 from botorch.models.model import Model
-from botorch.posteriors.deterministic import DeterministicPosterior
 from torch import Tensor
-class DeterministicModel(Model, ABC):
+class DeterministicModel(EnsembleModel):
     r"""
     Abstract base
class for deterministic models. @@ -57,55 +55,9 @@ def forward(self, X: Tensor) -> Tensor: """ pass # pragma: no cover - @property - def num_outputs(self) -> int: - r"""The number of outputs of the model.""" - return self._num_outputs - - def posterior( - self, - X: Tensor, - output_indices: Optional[List[int]] = None, - posterior_transform: Optional[PosteriorTransform] = None, - **kwargs: Any, - ) -> DeterministicPosterior: - r"""Compute the (deterministic) posterior at X. - - Args: - X: A `batch_shape x n x d`-dim input tensor `X`. - output_indices: A list of indices, corresponding to the outputs over - which to compute the posterior. If omitted, computes the posterior - over all model outputs. - posterior_transform: An optional PosteriorTransform. - - Returns: - A `DeterministicPosterior` object, representing `batch_shape` joint - posteriors over `n` points and the outputs selected by `output_indices`. - """ - # Apply the input transforms in `eval` mode. - self.eval() - X = self.transform_inputs(X) - # Note: we use a Tensor instance check so that `observation_noise = True` - # just gets ignored. This avoids having to do a bunch of case distinctions - # when using a ModelList. - if isinstance(kwargs.get("observation_noise"), Tensor): - # TODO: Consider returning an MVN here instead - raise UnsupportedError( - "Deterministic models do not support observation noise." - ) - values = self.forward(X) - # NOTE: The `outcome_transform` `untransform`s the predictions rather than the - # `posterior` (as is done in GP models). This is more general since it works - # even if the transform doesn't support `untransform_posterior`. - if hasattr(self, "outcome_transform"): - values, _ = self.outcome_transform.untransform(values) - if output_indices is not None: - values = values[..., output_indices] - posterior = DeterministicPosterior(values=values) - if posterior_transform is not None: - return posterior_transform(posterior) - else: - return posterior + def _forward(self, X: Tensor) -> Tensor: + r"""Compatibilizes the `DeterministicModel` with `EnsemblePosterior`""" + return self.forward(X=X).unsqueeze(-3) class GenericDeterministicModel(DeterministicModel): diff --git a/botorch/models/ensemble.py b/botorch/models/ensemble.py new file mode 100644 index 0000000000..25c99f12f1 --- /dev/null +++ b/botorch/models/ensemble.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Ensemble Models: Simple wrappers that allow the usage of ensembles +via the BoTorch Model and Posterior APIs. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from botorch.acquisition.objective import PosteriorTransform +from botorch.exceptions.errors import UnsupportedError +from botorch.models.model import Model +from botorch.posteriors.ensemble import EnsemblePosterior +from torch import Tensor + + +class EnsembleModel(Model, ABC): + r""" + Abstract base class for ensemble models. + + :meta private: + """ + + @abstractmethod + def forward(self, X: Tensor) -> Tensor: + r"""Compute the (ensemble) model output at X. + + Args: + X: A `batch_shape x n x d`-dim input tensor `X`. + + Returns: + A `batch_shape x s x n x m`-dimensional output tensor where + `s` is the size of the ensemble. 
+ """ + pass # pragma: no cover + + def _forward(self, X: Tensor) -> Tensor: + return self.forward(X=X) + + @property + def num_outputs(self) -> int: + r"""The number of outputs of the model.""" + return self._num_outputs + + def posterior( + self, + X: Tensor, + output_indices: Optional[List[int]] = None, + posterior_transform: Optional[PosteriorTransform] = None, + **kwargs: Any, + ) -> EnsemblePosterior: + r"""Compute the ensemble posterior at X. + + Args: + X: A `batch_shape x q x d`-dim input tensor `X`. + output_indices: A list of indices, corresponding to the outputs over + which to compute the posterior. If omitted, computes the posterior + over all model outputs. + posterior_transform: An optional PosteriorTransform. + + Returns: + An `EnsemblePosterior` object, representing `batch_shape` joint + posteriors over `n` points and the outputs selected by `output_indices`. + """ + # Apply the input transforms in `eval` mode. + self.eval() + X = self.transform_inputs(X) + # Note: we use a Tensor instance check so that `observation_noise = True` + # just gets ignored. This avoids having to do a bunch of case distinctions + # when using a ModelList. + if isinstance(kwargs.get("observation_noise"), Tensor): + # TODO: Consider returning an MVN here instead + raise UnsupportedError("Ensemble models do not support observation noise.") + values = self._forward(X) + # NOTE: The `outcome_transform` `untransform`s the predictions rather than the + # `posterior` (as is done in GP models). This is more general since it works + # even if the transform doesn't support `untransform_posterior`. + if hasattr(self, "outcome_transform"): + values, _ = self.outcome_transform.untransform(values) + if output_indices is not None: + values = values[..., output_indices] + posterior = EnsemblePosterior(values=values) + if posterior_transform is not None: + return posterior_transform(posterior) + else: + return posterior diff --git a/botorch/posteriors/deterministic.py b/botorch/posteriors/deterministic.py index 57df6b9b7c..1e39be3291 100644 --- a/botorch/posteriors/deterministic.py +++ b/botorch/posteriors/deterministic.py @@ -12,6 +12,7 @@ from __future__ import annotations from typing import Optional +from warnings import warn import torch from botorch.posteriors.posterior import Posterior @@ -19,13 +20,21 @@ class DeterministicPosterior(Posterior): - r"""Deterministic posterior.""" + r"""Deterministic posterior. + + [DEPRECATED] Use `EnsemblePosterior` instead. + """ def __init__(self, values: Tensor) -> None: r""" Args: values: Values of the samples produced by this posterior. """ + warn( + "`DeterministicPosterior` is marked for deprecation, consider using " + "`EnsemblePosterior`.", + DeprecationWarning, + ) self.values = values @property diff --git a/botorch/posteriors/ensemble.py b/botorch/posteriors/ensemble.py new file mode 100644 index 0000000000..7eeebb4cf2 --- /dev/null +++ b/botorch/posteriors/ensemble.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Ensemble posteriors. Used in conjunction with ensemble models. 
+""" + +from __future__ import annotations + +from typing import Optional + +import torch +from botorch.posteriors.posterior import Posterior +from torch import Tensor + + +class EnsemblePosterior(Posterior): + r"""Ensemble posterior, that should be used for ensemble models that compute + eagerly a finite number of samples per X value as for example a deep ensemble + or a random forest.""" + + def __init__(self, values: Tensor) -> None: + r""" + Args: + values: Values of the samples produced by this posterior as + a `(b) x s x q x m` tensor where `m` is the output size of the + model and `s` is the ensemble size. + """ + if values.ndim < 3: + raise ValueError("Values has to be at least three-dimensional.") + self.values = values + + @property + def ensemble_size(self) -> int: + r"""The size of the ensemble""" + return self.values.shape[-3] + + @property + def weights(self) -> Tensor: + r"""The weights of the individual models in the ensemble. + Equally weighted by default.""" + return torch.ones(self.ensemble_size) / self.ensemble_size + + @property + def device(self) -> torch.device: + r"""The torch device of the posterior.""" + return self.values.device + + @property + def dtype(self) -> torch.dtype: + r"""The torch dtype of the posterior.""" + return self.values.dtype + + @property + def mean(self) -> Tensor: + r"""The mean of the posterior as a `(b) x n x m`-dim Tensor.""" + return self.values.mean(dim=-3) + + @property + def variance(self) -> Tensor: + r"""The variance of the posterior as a `(b) x n x m`-dim Tensor. + + Computed as the sample variance across the ensemble outputs. + """ + if self.ensemble_size == 1: + return torch.zeros_like(self.values.squeeze(-3)) + return self.values.var(dim=-3) + + def _extended_shape( + self, sample_shape: torch.Size = torch.Size() # noqa: B008 + ) -> torch.Size: + r"""Returns the shape of the samples produced by the posterior with + the given `sample_shape`. + """ + return sample_shape + self.values.shape[:-3] + self.values.shape[-2:] + + def rsample( + self, + sample_shape: Optional[torch.Size] = None, + ) -> Tensor: + r"""Sample from the posterior (with gradients). + + Based on the sample shape, base samples are generated and passed to + `rsample_from_base_samples`. + + Args: + sample_shape: A `torch.Size` object specifying the sample shape. To + draw `n` samples, set to `torch.Size([n])`. To draw `b` batches + of `n` samples each, set to `torch.Size([b, n])`. + + Returns: + Samples from the posterior, a tensor of shape + `self._extended_shape(sample_shape=sample_shape)`. + """ + if sample_shape is None: + sample_shape = torch.Size([1]) + # get indices as base_samples + base_samples = ( + torch.multinomial( + self.weights, + num_samples=sample_shape.numel(), + replacement=True, + ) + .reshape(sample_shape) + .to(device=self.device) + ) + return self.rsample_from_base_samples( + sample_shape=sample_shape, base_samples=base_samples + ) + + def rsample_from_base_samples( + self, sample_shape: torch.Size, base_samples: Tensor + ) -> Tensor: + r"""Sample from the posterior (with gradients) using base samples. + + This is intended to be used with a sampler that produces the corresponding base + samples, and enables acquisition optimization via Sample Average Approximation. + + Args: + sample_shape: A `torch.Size` object specifying the sample shape. To + draw `n` samples, set to `torch.Size([n])`. To draw `b` batches + of `n` samples each, set to `torch.Size([b, n])`. 
+ base_samples: A Tensor of indices as base samples of shape + `sample_shape`, typically obtained from `IndexSampler`. + This is used for deterministic optimization. The predictions of + the ensemble corresponding to the indices are then sampled. + + + Returns: + Samples from the posterior, a tensor of shape + `self._extended_shape(sample_shape=sample_shape)`. + """ + if base_samples.shape != sample_shape: + raise ValueError("Base samples do not match sample shape.") + # move sample axis to front + values = self.values.movedim(-3, 0) + # sample from the first dimension of values + return values[base_samples, ...] diff --git a/botorch/sampling/deterministic.py b/botorch/sampling/deterministic.py index 39c59f0932..c5d1c0bb4c 100644 --- a/botorch/sampling/deterministic.py +++ b/botorch/sampling/deterministic.py @@ -18,6 +18,9 @@ class DeterministicSampler(StochasticSampler): r"""A sampler that simply calls `posterior.rsample`, intended to be used with `DeterministicModel` & `DeterministicPosterior`. + [DEPRECATED] - Use `IndexSampler` in conjunction with `EnsemblePosterior` + instead of `DeterministicSampler` with `DeterministicPosterior`. + This is effectively signals that `StochasticSampler` is safe to use with deterministic models since their output is deterministic by definition. """ diff --git a/botorch/sampling/get_sampler.py b/botorch/sampling/get_sampler.py index 447d91a75b..11bf8ceaf9 100644 --- a/botorch/sampling/get_sampler.py +++ b/botorch/sampling/get_sampler.py @@ -10,6 +10,7 @@ import torch from botorch.logging import logger from botorch.posteriors.deterministic import DeterministicPosterior +from botorch.posteriors.ensemble import EnsemblePosterior from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.posteriors.posterior import Posterior from botorch.posteriors.posterior_list import PosteriorList @@ -17,6 +18,7 @@ from botorch.posteriors.transformed import TransformedPosterior from botorch.sampling.base import MCSampler from botorch.sampling.deterministic import DeterministicSampler +from botorch.sampling.index_sampler import IndexSampler from botorch.sampling.list_sampler import ListSampler from botorch.sampling.normal import ( IIDNormalSampler, @@ -111,10 +113,18 @@ def _get_sampler_list( def _get_sampler_deterministic( posterior: DeterministicPosterior, sample_shape: torch.Size, **kwargs: Any ) -> MCSampler: - r"""Get the dummy `StochasticSampler` for the `DeterministicPosterior`.""" + r"""Get the dummy `DeterministicSampler` for the `DeterministicPosterior`.""" return DeterministicSampler(sample_shape=sample_shape, **kwargs) +@GetSampler.register(EnsemblePosterior) +def _get_sampler_ensemble( + posterior: EnsemblePosterior, sample_shape: torch.Size, **kwargs: Any +) -> MCSampler: + r"""Get the `IndexSampler` for the `EnsemblePosterior`.""" + return IndexSampler(sample_shape=sample_shape, **kwargs) + + @GetSampler.register(object) def _not_found_error( posterior: Posterior, sample_shape: torch.Size, **kwargs: Any diff --git a/botorch/sampling/index_sampler.py b/botorch/sampling/index_sampler.py new file mode 100644 index 0000000000..ac64388a67 --- /dev/null +++ b/botorch/sampling/index_sampler.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Sampler to be used with `EnsemblePosteriors` to enable +deterministic optimization of acquisition functions with ensemble models. 
+""" + +from __future__ import annotations + +import torch +from botorch.posteriors.ensemble import EnsemblePosterior +from botorch.sampling.base import MCSampler +from torch import Tensor + + +class IndexSampler(MCSampler): + r"""A sampler that calls `posterior.rsample_from_base_samples` to + generate the samples via index base samples.""" + + def forward(self, posterior: EnsemblePosterior) -> Tensor: + r"""Draws MC samples from the posterior. + + Args: + posterior: The ensemble posterior to sample from. + + Returns: + The samples drawn from the posterior. + """ + self._construct_base_samples(posterior=posterior) + samples = posterior.rsample_from_base_samples( + sample_shape=self.sample_shape, base_samples=self.base_samples + ) + return samples + + def _construct_base_samples(self, posterior: EnsemblePosterior) -> None: + r"""Constructs base samples as indices to sample with them from + the Posterior. + + Args: + posterior: The ensemble posterior to construct the base samples + for. + """ + if self.base_samples is None or self.base_samples.shape != self.sample_shape: + with torch.random.fork_rng(): + torch.manual_seed(self.seed) + base_samples = torch.multinomial( + posterior.weights, + num_samples=self.sample_shape.numel(), + replacement=True, + ).reshape(self.sample_shape) + self.register_buffer("base_samples", base_samples) + if self.base_samples.device != posterior.device: + self.to(device=posterior.device) # pragma: nocover + + def _update_base_samples( + self, posterior: EnsemblePosterior, base_sampler: IndexSampler + ) -> None: + r"""Null operation just needed for compatibility with + `CachedCholeskyAcquisitionFunction`.""" + pass diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst index 0660a7cc1f..533f6fe52c 100644 --- a/sphinx/source/models.rst +++ b/sphinx/source/models.rst @@ -25,6 +25,11 @@ Deterministic Model API .. automodule:: botorch.models.deterministic :members: +Ensemble Model API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.ensemble + :members: + Models ------------------------------------------- diff --git a/sphinx/source/posteriors.rst b/sphinx/source/posteriors.rst index 76ea39a2a5..655b22ea80 100644 --- a/sphinx/source/posteriors.rst +++ b/sphinx/source/posteriors.rst @@ -39,6 +39,11 @@ Determinstic Posterior .. automodule:: botorch.posteriors.deterministic :members: +Ensemble Posterior +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.posteriors.ensemble + :members: + Higher Order GP Posterior ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.posteriors.higher_order diff --git a/sphinx/source/sampling.rst b/sphinx/source/sampling.rst index 0113a50d43..cbfcae2a25 100644 --- a/sphinx/source/sampling.rst +++ b/sphinx/source/sampling.rst @@ -17,6 +17,11 @@ Deterministic Sampler .. automodule:: botorch.sampling.deterministic :members: +Index Sampler +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.index_sampler + :members: + Get Sampler Helper ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
automodule:: botorch.sampling.get_sampler diff --git a/test/models/test_deterministic.py b/test/models/test_deterministic.py index fdc3d97d0f..510c07a1e0 100644 --- a/test/models/test_deterministic.py +++ b/test/models/test_deterministic.py @@ -19,7 +19,7 @@ from botorch.models.gp_regression import SingleTaskGP from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize -from botorch.posteriors.deterministic import DeterministicPosterior +from botorch.posteriors.ensemble import EnsemblePosterior from botorch.utils.testing import BotorchTestCase @@ -61,7 +61,8 @@ def f(X): X = torch.rand(3, 2) # basic test p = model.posterior(X) - self.assertIsInstance(p, DeterministicPosterior) + self.assertIsInstance(p, EnsemblePosterior) + self.assertEqual(p.ensemble_size, 1) self.assertTrue(torch.equal(p.mean, f(X))) # check that observation noise doesn't change things p_noisy = model.posterior(X, observation_noise=True) diff --git a/test/models/test_ensemble.py b/test/models/test_ensemble.py new file mode 100644 index 0000000000..75b3e266d4 --- /dev/null +++ b/test/models/test_ensemble.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from botorch.models.ensemble import EnsembleModel +from botorch.utils.testing import BotorchTestCase + + +class DummyEnsembleModel(EnsembleModel): + r"""A dummy ensemble model.""" + + def __init__(self): + r"""Init model.""" + super().__init__() + self._num_outputs = 2 + self.a = torch.rand(4, 3, 2) + + def forward(self, X): + return torch.stack( + [torch.einsum("...d,dm", X, self.a[i]) for i in range(4)], dim=-3 + ) + + +class TestEnsembleModels(BotorchTestCase): + def test_abstract_base_model(self): + with self.assertRaises(TypeError): + EnsembleModel() + + def test_DummyEnsembleModel(self): + for shape in [(10, 3), (5, 10, 3)]: + e = DummyEnsembleModel() + X = torch.randn(*shape) + p = e.posterior(X) + self.assertEqual(p.ensemble_size, 4) diff --git a/test/posteriors/test_deterministic.py b/test/posteriors/test_deterministic.py index 053913ce17..0dbfa37bb9 100644 --- a/test/posteriors/test_deterministic.py +++ b/test/posteriors/test_deterministic.py @@ -6,6 +6,8 @@ import itertools +from warnings import catch_warnings + import torch from botorch.posteriors.deterministic import DeterministicPosterior from botorch.utils.testing import BotorchTestCase @@ -18,6 +20,11 @@ def test_DeterministicPosterior(self): ): values = torch.randn(*shape, device=self.device, dtype=dtype) p = DeterministicPosterior(values) + with catch_warnings(record=True) as ws: + p = DeterministicPosterior(values) + self.assertTrue( + any("marked for deprecation" in str(w.message) for w in ws) + ) self.assertEqual(p.device.type, self.device.type) self.assertEqual(p.dtype, dtype) self.assertEqual(p._extended_shape(), values.shape) diff --git a/test/posteriors/test_ensemble.py b/test/posteriors/test_ensemble.py new file mode 100644 index 0000000000..aa589091d2 --- /dev/null +++ b/test/posteriors/test_ensemble.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import itertools + +import torch +from botorch.posteriors.ensemble import EnsemblePosterior +from botorch.utils.testing import BotorchTestCase + + +class TestEnsemblePosterior(BotorchTestCase): + def test_EnsemblePosterior_invalid(self): + for shape, dtype in itertools.product( + ((5, 2), (5, 1)), (torch.float, torch.double) + ): + values = torch.randn(*shape, device=self.device, dtype=dtype) + with self.assertRaisesRegex( + ValueError, + "Values has to be at least three-dimensional", + ): + EnsemblePosterior(values) + + def test_EnsemblePosterior_as_Deterministic(self): + for shape, dtype in itertools.product( + ((1, 3, 2), (2, 1, 3, 2)), (torch.float, torch.double) + ): + values = torch.randn(*shape, device=self.device, dtype=dtype) + p = EnsemblePosterior(values) + self.assertEqual(p.ensemble_size, 1) + self.assertEqual(p.device.type, self.device.type) + self.assertEqual(p.dtype, dtype) + self.assertEqual( + p._extended_shape(torch.Size((1,))), + torch.Size((1, 3, 2)) if len(shape) == 3 else torch.Size((1, 2, 3, 2)), + ) + self.assertEqual(p.weights, torch.ones(1)) + with self.assertRaises(NotImplementedError): + p.base_sample_shape + self.assertTrue(torch.equal(p.mean, values.squeeze(-3))) + self.assertTrue( + torch.equal(p.variance, torch.zeros_like(values.squeeze(-3))) + ) + # test sampling + samples = p.rsample() + self.assertTrue(torch.equal(samples, values.squeeze(-3).unsqueeze(0))) + samples = p.rsample(torch.Size([2])) + self.assertEqual(samples.shape, p._extended_shape(torch.Size([2]))) + + def test_EnsemblePosterior(self): + for shape, dtype in itertools.product( + ((16, 5, 2), (2, 16, 5, 2)), (torch.float, torch.double) + ): + values = torch.randn(*shape, device=self.device, dtype=dtype) + p = EnsemblePosterior(values) + self.assertEqual(p.device.type, self.device.type) + self.assertEqual(p.dtype, dtype) + self.assertEqual(p.ensemble_size, 16) + self.assertAllClose( + p.weights, torch.tensor([1.0 / p.ensemble_size] * p.ensemble_size) + ) + # test mean and variance + self.assertTrue(torch.equal(p.mean, values.mean(dim=-3))) + self.assertTrue(torch.equal(p.variance, values.var(dim=-3))) + # test extended shape + self.assertEqual( + p._extended_shape(torch.Size((128,))), + torch.Size((128, 5, 2)) + if len(shape) == 3 + else torch.Size((128, 2, 5, 2)), + ) + # test rsample + samples = p.rsample(torch.Size((1024,))) + self.assertEqual(samples.shape, p._extended_shape(torch.Size((1024,)))) + # test rsample from base samples + # test that produced samples are correct + samples = p.rsample_from_base_samples( + sample_shape=torch.Size((16,)), base_samples=torch.arange(16) + ) + self.assertEqual(samples.shape, p._extended_shape(torch.Size((16,)))) + self.assertAllClose(p.mean, samples.mean(dim=0)) + self.assertAllClose(p.variance, samples.var(dim=0)) + # test error on base_samples, sample_shape mismatch + with self.assertRaises(ValueError): + p.rsample_from_base_samples( + sample_shape=torch.Size((17,)), base_samples=torch.arange(16) + ) diff --git a/test/sampling/test_index_sampler.py b/test/sampling/test_index_sampler.py new file mode 100644 index 0000000000..02736751c5 --- /dev/null +++ b/test/sampling/test_index_sampler.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from botorch.posteriors.ensemble import EnsemblePosterior +from botorch.sampling.index_sampler import IndexSampler +from botorch.utils.testing import BotorchTestCase + + +class TestIndexSampler(BotorchTestCase): + def test_index_sampler(self): + # Basic usage. + posterior = EnsemblePosterior( + values=torch.randn(torch.Size((50, 16, 1, 1))).to(self.device) + ) + sampler = IndexSampler(sample_shape=torch.Size((128,))) + samples = sampler(posterior) + self.assertTrue(samples.shape == torch.Size((128, 50, 1, 1))) + self.assertTrue(sampler.base_samples.max() < 16) + self.assertTrue(sampler.base_samples.min() >= 0) + # check deterministic nature + samples2 = sampler(posterior) + self.assertAllClose(samples, samples2) + # test construct base samples + sampler = IndexSampler(sample_shape=torch.Size((4, 128)), seed=42) + self.assertTrue(sampler.base_samples is None) + sampler._construct_base_samples(posterior=posterior) + self.assertTrue(sampler.base_samples.shape == torch.Size((4, 128))) + self.assertTrue( + sampler.base_samples.device.type + == posterior.device.type + == self.device.type + ) + base_samples = sampler.base_samples + sampler = IndexSampler(sample_shape=torch.Size((4, 128)), seed=42) + sampler._construct_base_samples(posterior=posterior) + self.assertAllClose(base_samples, sampler.base_samples)
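---

For orientation, here is a minimal usage sketch (not part of the patch itself) of how the new pieces fit together: an `EnsembleModel` subclass returns a `batch_shape x s x n x m` prediction tensor from `forward`, `posterior()` wraps it in an `EnsemblePosterior`, and `IndexSampler` draws samples by indexing into the ensemble dimension. The `SimpleEnsemble` class and its random linear weights below are hypothetical stand-ins (mirroring the `DummyEnsembleModel` in the new tests), not part of BoTorch.

```python
import torch
from botorch.models.ensemble import EnsembleModel
from botorch.sampling.index_sampler import IndexSampler


class SimpleEnsemble(EnsembleModel):
    """Hypothetical ensemble of `s` fixed linear models (stand-in for an NN ensemble)."""

    def __init__(self, weights: torch.Tensor) -> None:
        # weights: an `s x d x m` tensor, one linear map per ensemble member
        super().__init__()
        self.register_buffer("ensemble_weights", weights)
        self._num_outputs = weights.shape[-1]

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Must return a `batch_shape x s x n x m` tensor, where `s` is the ensemble size.
        return torch.einsum("...nd,sdm->...snm", X, self.ensemble_weights)


model = SimpleEnsemble(weights=torch.rand(4, 3, 2))  # s=4 members, d=3 inputs, m=2 outputs
X = torch.rand(10, 3)

posterior = model.posterior(X)  # an EnsemblePosterior over the 4 member predictions
assert posterior.ensemble_size == 4
assert posterior.mean.shape == torch.Size([10, 2])  # averaged over the ensemble dimension

# IndexSampler draws base samples that are indices into the ensemble dimension;
# they are constructed once and reused, so repeated evaluations are deterministic.
sampler = IndexSampler(sample_shape=torch.Size([128]), seed=0)
samples = sampler(posterior)
assert samples.shape == torch.Size([128, 10, 2])
```

Because the base samples are fixed indices rather than fresh draws on every call, acquisition functions built on them can be optimized deterministically (Sample Average Approximation), which is the reason `IndexSampler` is registered for `EnsemblePosterior` in `get_sampler` instead of `StochasticSampler`.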