Ensemble Posterior (#1636)
Summary:

## Motivation

As discussed in #1064, this is an attempt to add an `EnsemblePosterior` to BoTorch that could be used, for example, by NN ensembles.

I am having trouble implementing `rsample` properly. My current implementation is based on `DeterministicPosterior`, and I do not think it is correct; instead, we should sample directly from the individual predictions of the ensemble. However, I am not sure how to interpret `sample_shape` in this context.
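
To make the question concrete, here is a minimal sketch (hypothetical shapes; not the code in this PR) of index-based sampling from eagerly computed ensemble predictions, where `sample_shape` determines how many ensemble members are drawn (with replacement) per joint sample:

```python
import torch

# Hypothetical ensemble predictions: s=5 members, n=4 points, m=2 outputs.
values = torch.randn(5, 4, 2)
weights = torch.ones(5) / 5          # equal weight per ensemble member
sample_shape = torch.Size([16])      # draw 16 joint samples

# Draw one member index per requested sample and index into the predictions,
# so that `sample_shape` prefixes the output shape: 16 x 4 x 2.
idx = torch.multinomial(weights, num_samples=sample_shape.numel(), replacement=True)
samples = values[idx.reshape(sample_shape), ...]
```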

As the sampler, I registered `StochasticSampler` for the new posterior class, but I am not sure whether that is correct either. I also have a related question about `StochasticSampler`: its docstring states that it should not be used in combination with `optimize_acqf`, yet `StochasticSampler` is assigned to `DeterministicPosterior`. Does this mean that a `ModelList` consisting of a `DeterministicModel` and GPs cannot be used with `optimize_acqf`?

Balandat: any suggestions on this?

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #1636

Test Plan: Unit tests. Not yet implemented/finished as it is still WIP.

Reviewed By: saitcakmak

Differential Revision: D43017184

Pulled By: Balandat

fbshipit-source-id: fd2ede2dbba82a40c466f8a178138ced0fcba5fe
jduerholt authored and facebook-github-bot committed Feb 11, 2023
1 parent 4445045 commit a8efd76
Showing 15 changed files with 520 additions and 59 deletions.
62 changes: 7 additions & 55 deletions botorch/models/deterministic.py
@@ -26,18 +26,16 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional, Union
from abc import abstractmethod
from typing import Callable, List, Optional, Union

import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import UnsupportedError
from botorch.models.ensemble import EnsembleModel
from botorch.models.model import Model
from botorch.posteriors.deterministic import DeterministicPosterior
from torch import Tensor


class DeterministicModel(Model, ABC):
class DeterministicModel(EnsembleModel):
r"""
Abstract base class for deterministic models.
@@ -57,55 +55,9 @@ def forward(self, X: Tensor) -> Tensor:
"""
pass # pragma: no cover

@property
def num_outputs(self) -> int:
r"""The number of outputs of the model."""
return self._num_outputs

def posterior(
self,
X: Tensor,
output_indices: Optional[List[int]] = None,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> DeterministicPosterior:
r"""Compute the (deterministic) posterior at X.
Args:
X: A `batch_shape x n x d`-dim input tensor `X`.
output_indices: A list of indices, corresponding to the outputs over
which to compute the posterior. If omitted, computes the posterior
over all model outputs.
posterior_transform: An optional PosteriorTransform.
Returns:
A `DeterministicPosterior` object, representing `batch_shape` joint
posteriors over `n` points and the outputs selected by `output_indices`.
"""
# Apply the input transforms in `eval` mode.
self.eval()
X = self.transform_inputs(X)
# Note: we use a Tensor instance check so that `observation_noise = True`
# just gets ignored. This avoids having to do a bunch of case distinctions
# when using a ModelList.
if isinstance(kwargs.get("observation_noise"), Tensor):
# TODO: Consider returning an MVN here instead
raise UnsupportedError(
"Deterministic models do not support observation noise."
)
values = self.forward(X)
# NOTE: The `outcome_transform` `untransform`s the predictions rather than the
# `posterior` (as is done in GP models). This is more general since it works
# even if the transform doesn't support `untransform_posterior`.
if hasattr(self, "outcome_transform"):
values, _ = self.outcome_transform.untransform(values)
if output_indices is not None:
values = values[..., output_indices]
posterior = DeterministicPosterior(values=values)
if posterior_transform is not None:
return posterior_transform(posterior)
else:
return posterior
def _forward(self, X: Tensor) -> Tensor:
r"""Compatibilizes the `DeterministicModel` with `EnsemblePosterior`"""
return self.forward(X=X).unsqueeze(-3)


class GenericDeterministicModel(DeterministicModel):
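
For illustration, a hedged sketch of what these changes imply for an existing deterministic model such as `GenericDeterministicModel` (toy inputs assumed): the `unsqueeze(-3)` in `_forward` lifts the `n x m` prediction to an ensemble of size `s = 1`, so `posterior` now yields an `EnsemblePosterior` with zero variance.

```python
import torch
from botorch.models.deterministic import GenericDeterministicModel

# Toy deterministic model: f maps an `n x 2` input to an `n x 1` output.
model = GenericDeterministicModel(f=lambda X: X.sum(dim=-1, keepdim=True))
post = model.posterior(torch.rand(5, 2))  # EnsemblePosterior with s=1
print(post.mean.shape)      # torch.Size([5, 1])
print(post.variance.max())  # tensor(0.) -- single-member ensemble
```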
93 changes: 93 additions & 0 deletions botorch/models/ensemble.py
@@ -0,0 +1,93 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Ensemble Models: Simple wrappers that allow the usage of ensembles
via the BoTorch Model and Posterior APIs.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List, Optional

from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.posteriors.ensemble import EnsemblePosterior
from torch import Tensor


class EnsembleModel(Model, ABC):
r"""
Abstract base class for ensemble models.
:meta private:
"""

@abstractmethod
def forward(self, X: Tensor) -> Tensor:
r"""Compute the (ensemble) model output at X.
Args:
X: A `batch_shape x n x d`-dim input tensor `X`.
Returns:
A `batch_shape x s x n x m`-dimensional output tensor where
`s` is the size of the ensemble.
"""
pass # pragma: no cover

def _forward(self, X: Tensor) -> Tensor:
return self.forward(X=X)

@property
def num_outputs(self) -> int:
r"""The number of outputs of the model."""
return self._num_outputs

def posterior(
self,
X: Tensor,
output_indices: Optional[List[int]] = None,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> EnsemblePosterior:
r"""Compute the ensemble posterior at X.
Args:
X: A `batch_shape x q x d`-dim input tensor `X`.
output_indices: A list of indices, corresponding to the outputs over
which to compute the posterior. If omitted, computes the posterior
over all model outputs.
posterior_transform: An optional PosteriorTransform.
Returns:
An `EnsemblePosterior` object, representing `batch_shape` joint
posteriors over `n` points and the outputs selected by `output_indices`.
"""
# Apply the input transforms in `eval` mode.
self.eval()
X = self.transform_inputs(X)
# Note: we use a Tensor instance check so that `observation_noise = True`
# just gets ignored. This avoids having to do a bunch of case distinctions
# when using a ModelList.
if isinstance(kwargs.get("observation_noise"), Tensor):
# TODO: Consider returning an MVN here instead
raise UnsupportedError("Ensemble models do not support observation noise.")
values = self._forward(X)
# NOTE: The `outcome_transform` `untransform`s the predictions rather than the
# `posterior` (as is done in GP models). This is more general since it works
# even if the transform doesn't support `untransform_posterior`.
if hasattr(self, "outcome_transform"):
values, _ = self.outcome_transform.untransform(values)
if output_indices is not None:
values = values[..., output_indices]
posterior = EnsemblePosterior(values=values)
if posterior_transform is not None:
return posterior_transform(posterior)
else:
return posterior
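
As a usage illustration, a minimal sketch (hypothetical networks and shapes, not part of this diff) of a concrete subclass: `forward` stacks the member predictions along the ensemble dimension `s`, and the inherited `posterior` wraps them in an `EnsemblePosterior`.

```python
import torch
from typing import List
from torch import Tensor, nn
from botorch.models.ensemble import EnsembleModel


class ToyNNEnsemble(EnsembleModel):
    """Hypothetical ensemble of small MLPs illustrating the new API."""

    def __init__(self, networks: List[nn.Module], num_outputs: int = 1) -> None:
        super().__init__()
        self.networks = nn.ModuleList(networks)
        self._num_outputs = num_outputs

    def forward(self, X: Tensor) -> Tensor:
        # Stack member predictions along dim -3, yielding `batch_shape x s x n x m`.
        return torch.stack([net(X) for net in self.networks], dim=-3)


nets = [nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1)) for _ in range(5)]
model = ToyNNEnsemble(nets)
posterior = model.posterior(torch.rand(10, 3))
print(posterior.mean.shape, posterior.variance.shape)  # torch.Size([10, 1]) twice
```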
11 changes: 10 additions & 1 deletion botorch/posteriors/deterministic.py
@@ -12,20 +12,29 @@
from __future__ import annotations

from typing import Optional
from warnings import warn

import torch
from botorch.posteriors.posterior import Posterior
from torch import Tensor


class DeterministicPosterior(Posterior):
r"""Deterministic posterior."""
r"""Deterministic posterior.
[DEPRECATED] Use `EnsemblePosterior` instead.
"""

def __init__(self, values: Tensor) -> None:
r"""
Args:
values: Values of the samples produced by this posterior.
"""
warn(
"`DeterministicPosterior` is marked for deprecation, consider using "
"`EnsemblePosterior`.",
DeprecationWarning,
)
self.values = values

@property
141 changes: 141 additions & 0 deletions botorch/posteriors/ensemble.py
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Ensemble posteriors. Used in conjunction with ensemble models.
"""

from __future__ import annotations

from typing import Optional

import torch
from botorch.posteriors.posterior import Posterior
from torch import Tensor


class EnsemblePosterior(Posterior):
r"""Ensemble posterior, that should be used for ensemble models that compute
eagerly a finite number of samples per X value as for example a deep ensemble
or a random forest."""

def __init__(self, values: Tensor) -> None:
r"""
Args:
values: Values of the samples produced by this posterior as
a `(b) x s x q x m` tensor where `m` is the output size of the
model and `s` is the ensemble size.
"""
if values.ndim < 3:
raise ValueError("Values has to be at least three-dimensional.")
self.values = values

@property
def ensemble_size(self) -> int:
r"""The size of the ensemble"""
return self.values.shape[-3]

@property
def weights(self) -> Tensor:
r"""The weights of the individual models in the ensemble.
Equally weighted by default."""
return torch.ones(self.ensemble_size) / self.ensemble_size

@property
def device(self) -> torch.device:
r"""The torch device of the posterior."""
return self.values.device

@property
def dtype(self) -> torch.dtype:
r"""The torch dtype of the posterior."""
return self.values.dtype

@property
def mean(self) -> Tensor:
r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
return self.values.mean(dim=-3)

@property
def variance(self) -> Tensor:
r"""The variance of the posterior as a `(b) x n x m`-dim Tensor.
Computed as the sample variance across the ensemble outputs.
"""
if self.ensemble_size == 1:
return torch.zeros_like(self.values.squeeze(-3))
return self.values.var(dim=-3)

def _extended_shape(
self, sample_shape: torch.Size = torch.Size() # noqa: B008
) -> torch.Size:
r"""Returns the shape of the samples produced by the posterior with
the given `sample_shape`.
"""
return sample_shape + self.values.shape[:-3] + self.values.shape[-2:]

def rsample(
self,
sample_shape: Optional[torch.Size] = None,
) -> Tensor:
r"""Sample from the posterior (with gradients).
Based on the sample shape, base samples are generated and passed to
`rsample_from_base_samples`.
Args:
sample_shape: A `torch.Size` object specifying the sample shape. To
draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
of `n` samples each, set to `torch.Size([b, n])`.
Returns:
Samples from the posterior, a tensor of shape
`self._extended_shape(sample_shape=sample_shape)`.
"""
if sample_shape is None:
sample_shape = torch.Size([1])
# get indices as base_samples
base_samples = (
torch.multinomial(
self.weights,
num_samples=sample_shape.numel(),
replacement=True,
)
.reshape(sample_shape)
.to(device=self.device)
)
return self.rsample_from_base_samples(
sample_shape=sample_shape, base_samples=base_samples
)

def rsample_from_base_samples(
self, sample_shape: torch.Size, base_samples: Tensor
) -> Tensor:
r"""Sample from the posterior (with gradients) using base samples.
This is intended to be used with a sampler that produces the corresponding base
samples, and enables acquisition optimization via Sample Average Approximation.
Args:
sample_shape: A `torch.Size` object specifying the sample shape. To
draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
of `n` samples each, set to `torch.Size([b, n])`.
base_samples: A Tensor of indices as base samples of shape
`sample_shape`, typically obtained from `IndexSampler`.
This is used for deterministic optimization. The predictions of
the ensemble corresponding to the indices are then sampled.
Returns:
Samples from the posterior, a tensor of shape
`self._extended_shape(sample_shape=sample_shape)`.
"""
if base_samples.shape != sample_shape:
raise ValueError("Base samples do not match sample shape.")
# move sample axis to front
values = self.values.movedim(-3, 0)
# sample from the first dimension of values
return values[base_samples, ...]
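
A quick, hedged usage sketch (made-up values): constructing an `EnsemblePosterior` directly and drawing reparameterized samples, whose shape follows `_extended_shape`.

```python
import torch
from botorch.posteriors.ensemble import EnsemblePosterior

# Hypothetical predictions of an ensemble with s=4 members, q=3 points, m=2 outputs.
posterior = EnsemblePosterior(values=torch.randn(4, 3, 2))
samples = posterior.rsample(sample_shape=torch.Size([8]))
print(samples.shape)  # torch.Size([8, 3, 2]) == posterior._extended_shape(torch.Size([8]))
```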
3 changes: 3 additions & 0 deletions botorch/sampling/deterministic.py
@@ -18,6 +18,9 @@ class DeterministicSampler(StochasticSampler):
r"""A sampler that simply calls `posterior.rsample`, intended to be used with
`DeterministicModel` & `DeterministicPosterior`.
[DEPRECATED] - Use `IndexSampler` in conjunction with `EnsemblePosterior`
instead of `DeterministicSampler` with `DeterministicPosterior`.
This effectively signals that `StochasticSampler` is safe to use with
deterministic models since their output is deterministic by definition.
"""
12 changes: 11 additions & 1 deletion botorch/sampling/get_sampler.py
@@ -10,13 +10,15 @@
import torch
from botorch.logging import logger
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.ensemble import EnsemblePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.posteriors.torch import TorchPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.base import MCSampler
from botorch.sampling.deterministic import DeterministicSampler
from botorch.sampling.index_sampler import IndexSampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import (
IIDNormalSampler,
@@ -111,10 +113,18 @@ def _get_sampler_list(
def _get_sampler_deterministic(
posterior: DeterministicPosterior, sample_shape: torch.Size, **kwargs: Any
) -> MCSampler:
r"""Get the dummy `StochasticSampler` for the `DeterministicPosterior`."""
r"""Get the dummy `DeterministicSampler` for the `DeterministicPosterior`."""
return DeterministicSampler(sample_shape=sample_shape, **kwargs)


@GetSampler.register(EnsemblePosterior)
def _get_sampler_ensemble(
posterior: EnsemblePosterior, sample_shape: torch.Size, **kwargs: Any
) -> MCSampler:
r"""Get the `IndexSampler` for the `EnsemblePosterior`."""
return IndexSampler(sample_shape=sample_shape, **kwargs)


@GetSampler.register(object)
def _not_found_error(
posterior: Posterior, sample_shape: torch.Size, **kwargs: Any
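
And a hedged end-to-end sketch of the new dispatch (hypothetical tensors): `get_sampler` on an `EnsemblePosterior` now returns an `IndexSampler`, whose base samples are ensemble-member indices consumed by `rsample_from_base_samples`.

```python
import torch
from botorch.posteriors.ensemble import EnsemblePosterior
from botorch.sampling.get_sampler import get_sampler

posterior = EnsemblePosterior(values=torch.randn(4, 3, 2))  # s=4, q=3, m=2
sampler = get_sampler(posterior=posterior, sample_shape=torch.Size([8]))
print(type(sampler).__name__)  # IndexSampler
samples = sampler(posterior)
print(samples.shape)  # torch.Size([8, 3, 2])
```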