Summary:

## Motivation

As discussed in #1064, this is an attempt to add an `EnsemblePosterior` to BoTorch that could be used, for example, by NN ensembles.

I have problems implementing `rsample` properly. I think my current implementation is not correct: it is based on `DeterministicPosterior`, but I think we should instead sample directly from the individual predictions of the ensemble. However, I do not know how to interpret `sample_shape` in this context.

As the sampler, I registered `StochasticSampler` for the new posterior class, but I am not sure whether this is correct either.

Furthermore, I have another question regarding `StochasticSampler`. Its docstring states that it should not be used in combination with `optimize_acqf`, yet `StochasticSampler` is assigned to `DeterministicPosterior`. Does this mean that one cannot use a `ModelList` consisting of a `DeterministicModel` and GPs in combination with `optimize_acqf`?

Balandat: any suggestions on this?

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #1636

Test Plan: Unit tests. Not yet implemented/finished as it is still WIP.

Reviewed By: saitcakmak

Differential Revision: D43017184

Pulled By: Balandat

fbshipit-source-id: fd2ede2dbba82a40c466f8a178138ced0fcba5fe
1 parent 4445045, commit a8efd76
Showing 15 changed files with 520 additions and 59 deletions.
@@ -0,0 +1,93 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Ensemble Models: Simple wrappers that allow the usage of ensembles
via the BoTorch Model and Posterior APIs.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List, Optional

from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.posteriors.ensemble import EnsemblePosterior
from torch import Tensor


class EnsembleModel(Model, ABC):
    r"""
    Abstract base class for ensemble models.

    :meta private:
    """

    @abstractmethod
    def forward(self, X: Tensor) -> Tensor:
        r"""Compute the (ensemble) model output at X.

        Args:
            X: A `batch_shape x n x d`-dim input tensor `X`.

        Returns:
            A `batch_shape x s x n x m`-dimensional output tensor where
            `s` is the size of the ensemble.
        """
        pass  # pragma: no cover

    def _forward(self, X: Tensor) -> Tensor:
        return self.forward(X=X)

    @property
    def num_outputs(self) -> int:
        r"""The number of outputs of the model."""
        # `_num_outputs` is expected to be set by the concrete subclass.
        return self._num_outputs

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
        **kwargs: Any,
    ) -> EnsemblePosterior:
        r"""Compute the ensemble posterior at X.

        Args:
            X: A `batch_shape x q x d`-dim input tensor `X`.
            output_indices: A list of indices, corresponding to the outputs over
                which to compute the posterior. If omitted, computes the posterior
                over all model outputs.
            posterior_transform: An optional PosteriorTransform.

        Returns:
            An `EnsemblePosterior` object, representing `batch_shape` joint
            posteriors over `q` points and the outputs selected by `output_indices`.
        """
        # Apply the input transforms in `eval` mode.
        self.eval()
        X = self.transform_inputs(X)
        # Note: we use a Tensor instance check so that `observation_noise = True`
        # just gets ignored. This avoids having to do a bunch of case distinctions
        # when using a ModelList.
        if isinstance(kwargs.get("observation_noise"), Tensor):
            # TODO: Consider returning an MVN here instead
            raise UnsupportedError("Ensemble models do not support observation noise.")
        values = self._forward(X)
        # NOTE: The `outcome_transform` `untransform`s the predictions rather than the
        # `posterior` (as is done in GP models). This is more general since it works
        # even if the transform doesn't support `untransform_posterior`.
        if hasattr(self, "outcome_transform"):
            values, _ = self.outcome_transform.untransform(values)
        if output_indices is not None:
            values = values[..., output_indices]
        posterior = EnsemblePosterior(values=values)
        if posterior_transform is not None:
            return posterior_transform(posterior)
        else:
            return posterior
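
The `forward` contract above requires the ensemble axis to sit at position `-3`, giving a `batch_shape x s x n x m` tensor. As a rough illustration only, the sketch below builds such a tensor with plain `torch` by stacking the outputs of a few independent linear layers. The names `members` and `ensemble_forward` and all sizes are hypothetical and are not part of this commit or of the BoTorch API.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): input dim d, output dim m, ensemble size s.
d, m, s = 3, 2, 5

# Stand-in ensemble: s independent linear maps from d inputs to m outputs.
members = [nn.Linear(d, m) for _ in range(s)]


def ensemble_forward(X: torch.Tensor) -> torch.Tensor:
    """Map a `batch_shape x n x d` input to a `batch_shape x s x n x m` output."""
    # Each member returns a `batch_shape x n x m` tensor.
    outputs = [member(X) for member in members]
    # Stacking at dim=-3 places the new ensemble axis just before `n x m`.
    return torch.stack(outputs, dim=-3)


X = torch.rand(4, 7, d)          # batch_shape = (4,), n = 7
values = ensemble_forward(X)     # shape: 4 x 5 x 7 x 2

# Selecting outputs the same way `posterior` does with `output_indices`.
selected = values[..., [0]]      # shape: 4 x 5 x 7 x 1
```

Because the stack uses a negative axis, the same function also handles inputs without a batch dimension, in which case the result is `s x n x m`.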
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Ensemble posteriors. Used in conjunction with ensemble models.
"""

from __future__ import annotations

from typing import Optional

import torch
from botorch.posteriors.posterior import Posterior
from torch import Tensor


class EnsemblePosterior(Posterior):
    r"""Ensemble posterior for models that eagerly compute a finite number of
    samples per X value, for example a deep ensemble or a random forest."""

    def __init__(self, values: Tensor) -> None:
        r"""
        Args:
            values: Values of the samples produced by this posterior as
                a `(b) x s x q x m` tensor where `m` is the output size of the
                model and `s` is the ensemble size.
        """
        if values.ndim < 3:
            raise ValueError("Values has to be at least three-dimensional.")
        self.values = values

    @property
    def ensemble_size(self) -> int:
        r"""The size of the ensemble."""
        return self.values.shape[-3]

    @property
    def weights(self) -> Tensor:
        r"""The weights of the individual models in the ensemble.
        Equally weighted by default."""
        return torch.ones(self.ensemble_size) / self.ensemble_size

    @property
    def device(self) -> torch.device:
        r"""The torch device of the posterior."""
        return self.values.device

    @property
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the posterior."""
        return self.values.dtype

    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a `(b) x q x m`-dim Tensor."""
        return self.values.mean(dim=-3)

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a `(b) x q x m`-dim Tensor.

        Computed as the sample variance across the ensemble outputs.
        """
        # The sample variance of a single member is undefined, so return zeros.
        if self.ensemble_size == 1:
            return torch.zeros_like(self.values.squeeze(-3))
        return self.values.var(dim=-3)

    def _extended_shape(
        self, sample_shape: torch.Size = torch.Size()  # noqa: B008
    ) -> torch.Size:
        r"""Returns the shape of the samples produced by the posterior with
        the given `sample_shape`.
        """
        return sample_shape + self.values.shape[:-3] + self.values.shape[-2:]

    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Based on the sample shape, base samples are generated and passed to
        `rsample_from_base_samples`.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        if sample_shape is None:
            sample_shape = torch.Size([1])
        # get indices as base_samples
        base_samples = (
            torch.multinomial(
                self.weights,
                num_samples=sample_shape.numel(),
                replacement=True,
            )
            .reshape(sample_shape)
            .to(device=self.device)
        )
        return self.rsample_from_base_samples(
            sample_shape=sample_shape, base_samples=base_samples
        )

    def rsample_from_base_samples(
        self, sample_shape: torch.Size, base_samples: Tensor
    ) -> Tensor:
        r"""Sample from the posterior (with gradients) using base samples.

        This is intended to be used with a sampler that produces the corresponding base
        samples, and enables acquisition optimization via Sample Average Approximation.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: A Tensor of indices as base samples of shape
                `sample_shape`, typically obtained from `IndexSampler`.
                This is used for deterministic optimization. The predictions of
                the ensemble corresponding to the indices are then sampled.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        if base_samples.shape != sample_shape:
            raise ValueError("Base samples do not match sample shape.")
        # move sample axis to front
        values = self.values.movedim(-3, 0)
        # sample from the first dimension of values
        return values[base_samples, ...]
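
To make the index-based sampling above concrete, here is a minimal sketch that reproduces the same steps with plain `torch` rather than the `EnsemblePosterior` class itself. The tensor sizes and the standalone variable names are assumptions chosen for illustration. It shows that each drawn sample is simply one ensemble member's prediction, and that gradients still flow back to `values` because tensor indexing is differentiable.

```python
import torch

torch.manual_seed(0)

# Hypothetical ensemble output: b=2 batches, s=4 members, q=3 points, m=1 output.
values = torch.randn(2, 4, 3, 1, requires_grad=True)
ensemble_size = values.shape[-3]

# Moments across the ensemble axis, as in the `mean` and `variance` properties.
mean = values.mean(dim=-3)        # shape: 2 x 3 x 1
variance = values.var(dim=-3)     # sample variance, shape: 2 x 3 x 1

# Draw one member index per requested sample, with equal weights.
sample_shape = torch.Size([6])
weights = torch.ones(ensemble_size) / ensemble_size
base_samples = torch.multinomial(
    weights, num_samples=sample_shape.numel(), replacement=True
).reshape(sample_shape)

# Move the ensemble axis to the front and index it with the drawn indices.
samples = values.movedim(-3, 0)[base_samples, ...]  # shape: 6 x 2 x 3 x 1

# Indexing is differentiable, so a loss on the samples reaches `values`.
samples.sum().backward()
```

Reusing a fixed `base_samples` tensor across calls selects the same members every time, which is what makes the sampled objective deterministic for Sample Average Approximation.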