Skip to content

Commit a8efd76

Browse files
jduerholt authored and facebook-github-bot committed
Ensemble Posterior (#1636)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation As discussed in #1064, this is an attempt to add an `EnsemblePosterior` to botorch that could be used, for example, by NN ensembles. I have problems with implementing `rsample` properly. I think my current implementation is not correct; it is based on `DeterministicPosterior`, but I think we should directly sample solutions from the individual predictions of the ensemble. However, I do not know how to interpret `sample_shape` in this context. As sampler, I registered the `StochasticSampler` for the new posterior class, but there, too, I am not sure if this is correct. Furthermore, I have another question regarding `StochasticSampler`. It is stated in the docstring of `StochasticSampler` that it should not be used in combination with `optimize_acqf`. But `StochasticSampler` is assigned to the `DeterministicPosterior`. Does it mean that one cannot use a `ModelList` consisting of a `DeterministicModel` and GPs in combination with `optimize_acqf`? Balandat: any suggestions on this? ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes. Pull Request resolved: #1636 Test Plan: Unit tests. Not yet implemented/finished as it is still WIP. Reviewed By: saitcakmak Differential Revision: D43017184 Pulled By: Balandat fbshipit-source-id: fd2ede2dbba82a40c466f8a178138ced0fcba5fe
1 parent 4445045 commit a8efd76

File tree

15 files changed

+520
-59
lines changed

15 files changed

+520
-59
lines changed

botorch/models/deterministic.py

Lines changed: 7 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,16 @@
2626

2727
from __future__ import annotations
2828

29-
from abc import ABC, abstractmethod
30-
from typing import Any, Callable, List, Optional, Union
29+
from abc import abstractmethod
30+
from typing import Callable, List, Optional, Union
3131

3232
import torch
33-
from botorch.acquisition.objective import PosteriorTransform
34-
from botorch.exceptions.errors import UnsupportedError
33+
from botorch.models.ensemble import EnsembleModel
3534
from botorch.models.model import Model
36-
from botorch.posteriors.deterministic import DeterministicPosterior
3735
from torch import Tensor
3836

3937

40-
class DeterministicModel(Model, ABC):
38+
class DeterministicModel(EnsembleModel):
4139
r"""
4240
Abstract base class for deterministic models.
4341
@@ -57,55 +55,9 @@ def forward(self, X: Tensor) -> Tensor:
5755
"""
5856
pass # pragma: no cover
5957

60-
@property
61-
def num_outputs(self) -> int:
62-
r"""The number of outputs of the model."""
63-
return self._num_outputs
64-
65-
def posterior(
66-
self,
67-
X: Tensor,
68-
output_indices: Optional[List[int]] = None,
69-
posterior_transform: Optional[PosteriorTransform] = None,
70-
**kwargs: Any,
71-
) -> DeterministicPosterior:
72-
r"""Compute the (deterministic) posterior at X.
73-
74-
Args:
75-
X: A `batch_shape x n x d`-dim input tensor `X`.
76-
output_indices: A list of indices, corresponding to the outputs over
77-
which to compute the posterior. If omitted, computes the posterior
78-
over all model outputs.
79-
posterior_transform: An optional PosteriorTransform.
80-
81-
Returns:
82-
A `DeterministicPosterior` object, representing `batch_shape` joint
83-
posteriors over `n` points and the outputs selected by `output_indices`.
84-
"""
85-
# Apply the input transforms in `eval` mode.
86-
self.eval()
87-
X = self.transform_inputs(X)
88-
# Note: we use a Tensor instance check so that `observation_noise = True`
89-
# just gets ignored. This avoids having to do a bunch of case distinctions
90-
# when using a ModelList.
91-
if isinstance(kwargs.get("observation_noise"), Tensor):
92-
# TODO: Consider returning an MVN here instead
93-
raise UnsupportedError(
94-
"Deterministic models do not support observation noise."
95-
)
96-
values = self.forward(X)
97-
# NOTE: The `outcome_transform` `untransform`s the predictions rather than the
98-
# `posterior` (as is done in GP models). This is more general since it works
99-
# even if the transform doesn't support `untransform_posterior`.
100-
if hasattr(self, "outcome_transform"):
101-
values, _ = self.outcome_transform.untransform(values)
102-
if output_indices is not None:
103-
values = values[..., output_indices]
104-
posterior = DeterministicPosterior(values=values)
105-
if posterior_transform is not None:
106-
return posterior_transform(posterior)
107-
else:
108-
return posterior
58+
def _forward(self, X: Tensor) -> Tensor:
59+
r"""Compatibilizes the `DeterministicModel` with `EnsemblePosterior`"""
60+
return self.forward(X=X).unsqueeze(-3)
10961

11062

11163
class GenericDeterministicModel(DeterministicModel):

botorch/models/ensemble.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
Ensemble Models: Simple wrappers that allow the usage of ensembles
9+
via the BoTorch Model and Posterior APIs.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from abc import ABC, abstractmethod
15+
from typing import Any, List, Optional
16+
17+
from botorch.acquisition.objective import PosteriorTransform
18+
from botorch.exceptions.errors import UnsupportedError
19+
from botorch.models.model import Model
20+
from botorch.posteriors.ensemble import EnsemblePosterior
21+
from torch import Tensor
22+
23+
24+
class EnsembleModel(Model, ABC):
25+
r"""
26+
Abstract base class for ensemble models.
27+
28+
:meta private:
29+
"""
30+
31+
@abstractmethod
32+
def forward(self, X: Tensor) -> Tensor:
33+
r"""Compute the (ensemble) model output at X.
34+
35+
Args:
36+
X: A `batch_shape x n x d`-dim input tensor `X`.
37+
38+
Returns:
39+
A `batch_shape x s x n x m`-dimensional output tensor where
40+
`s` is the size of the ensemble.
41+
"""
42+
pass # pragma: no cover
43+
44+
def _forward(self, X: Tensor) -> Tensor:
45+
return self.forward(X=X)
46+
47+
@property
48+
def num_outputs(self) -> int:
49+
r"""The number of outputs of the model."""
50+
return self._num_outputs
51+
52+
def posterior(
53+
self,
54+
X: Tensor,
55+
output_indices: Optional[List[int]] = None,
56+
posterior_transform: Optional[PosteriorTransform] = None,
57+
**kwargs: Any,
58+
) -> EnsemblePosterior:
59+
r"""Compute the ensemble posterior at X.
60+
61+
Args:
62+
X: A `batch_shape x q x d`-dim input tensor `X`.
63+
output_indices: A list of indices, corresponding to the outputs over
64+
which to compute the posterior. If omitted, computes the posterior
65+
over all model outputs.
66+
posterior_transform: An optional PosteriorTransform.
67+
68+
Returns:
69+
An `EnsemblePosterior` object, representing `batch_shape` joint
70+
posteriors over `n` points and the outputs selected by `output_indices`.
71+
"""
72+
# Apply the input transforms in `eval` mode.
73+
self.eval()
74+
X = self.transform_inputs(X)
75+
# Note: we use a Tensor instance check so that `observation_noise = True`
76+
# just gets ignored. This avoids having to do a bunch of case distinctions
77+
# when using a ModelList.
78+
if isinstance(kwargs.get("observation_noise"), Tensor):
79+
# TODO: Consider returning an MVN here instead
80+
raise UnsupportedError("Ensemble models do not support observation noise.")
81+
values = self._forward(X)
82+
# NOTE: The `outcome_transform` `untransform`s the predictions rather than the
83+
# `posterior` (as is done in GP models). This is more general since it works
84+
# even if the transform doesn't support `untransform_posterior`.
85+
if hasattr(self, "outcome_transform"):
86+
values, _ = self.outcome_transform.untransform(values)
87+
if output_indices is not None:
88+
values = values[..., output_indices]
89+
posterior = EnsemblePosterior(values=values)
90+
if posterior_transform is not None:
91+
return posterior_transform(posterior)
92+
else:
93+
return posterior

botorch/posteriors/deterministic.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,29 @@
1212
from __future__ import annotations
1313

1414
from typing import Optional
15+
from warnings import warn
1516

1617
import torch
1718
from botorch.posteriors.posterior import Posterior
1819
from torch import Tensor
1920

2021

2122
class DeterministicPosterior(Posterior):
22-
r"""Deterministic posterior."""
23+
r"""Deterministic posterior.
24+
25+
[DEPRECATED] Use `EnsemblePosterior` instead.
26+
"""
2327

2428
def __init__(self, values: Tensor) -> None:
2529
r"""
2630
Args:
2731
values: Values of the samples produced by this posterior.
2832
"""
33+
warn(
34+
"`DeterministicPosterior` is marked for deprecation, consider using "
35+
"`EnsemblePosterior`.",
36+
DeprecationWarning,
37+
)
2938
self.values = values
3039

3140
@property

botorch/posteriors/ensemble.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
Ensemble posteriors. Used in conjunction with ensemble models.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from typing import Optional
14+
15+
import torch
16+
from botorch.posteriors.posterior import Posterior
17+
from torch import Tensor
18+
19+
20+
class EnsemblePosterior(Posterior):
    r"""Posterior defined by a finite set of eagerly computed ensemble predictions.

    Intended for ensemble models that produce a fixed number of predictions per
    X value, as for example a deep ensemble or a random forest.
    """

    def __init__(self, values: Tensor) -> None:
        r"""
        Args:
            values: Values of the samples produced by this posterior as
                a `(b) x s x q x m` tensor where `m` is the output size of the
                model and `s` is the ensemble size.

        Raises:
            ValueError: If `values` has fewer than three dimensions.
        """
        if values.ndim < 3:
            raise ValueError("Values has to be at least three-dimensional.")
        self.values = values

    @property
    def ensemble_size(self) -> int:
        r"""The size of the ensemble, i.e. the length of dimension `-3` of `values`."""
        return self.values.shape[-3]

    @property
    def weights(self) -> Tensor:
        r"""The weights of the individual models in the ensemble.

        All ensemble members are weighted equally (`1 / ensemble_size` each).
        """
        num_members = self.ensemble_size
        uniform = torch.ones(num_members)
        return uniform / num_members

    @property
    def device(self) -> torch.device:
        r"""The torch device of the posterior (that of `values`)."""
        return self.values.device

    @property
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the posterior (that of `values`)."""
        return self.values.dtype

    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a `(b) x q x m`-dim Tensor.

        Computed as the average over the ensemble dimension of `values`.
        """
        return torch.mean(self.values, dim=-3)

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a `(b) x q x m`-dim Tensor.

        Computed as the sample variance across the ensemble outputs. For an
        ensemble of size one, a tensor of zeros is returned.
        """
        if self.ensemble_size != 1:
            return torch.var(self.values, dim=-3)
        # A single ensemble member has no spread; return zeros of the reduced shape.
        single_member = self.values.squeeze(-3)
        return torch.zeros_like(single_member)

    def _extended_shape(
        self, sample_shape: torch.Size = torch.Size()  # noqa: B008
    ) -> torch.Size:
        r"""Returns the shape of the samples produced by the posterior with
        the given `sample_shape`.

        This is `sample_shape` followed by the shape of `values` with the
        ensemble dimension (`-3`) removed.
        """
        full_shape = self.values.shape
        batch_shape, event_shape = full_shape[:-3], full_shape[-2:]
        return sample_shape + batch_shape + event_shape

    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Ensemble member indices are drawn (with replacement) according to
        `weights`, arranged into `sample_shape`, and passed as base samples to
        `rsample_from_base_samples`.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`. Defaults to
                `torch.Size([1])` if omitted.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        if sample_shape is None:
            sample_shape = torch.Size([1])
        # The base samples are indices into the ensemble dimension.
        num_draws = sample_shape.numel()
        member_indices = torch.multinomial(
            self.weights, num_samples=num_draws, replacement=True
        )
        base_samples = member_indices.reshape(sample_shape).to(device=self.device)
        return self.rsample_from_base_samples(
            sample_shape=sample_shape, base_samples=base_samples
        )

    def rsample_from_base_samples(
        self, sample_shape: torch.Size, base_samples: Tensor
    ) -> Tensor:
        r"""Sample from the posterior (with gradients) using base samples.

        This is intended to be used with a sampler that produces the corresponding base
        samples, and enables acquisition optimization via Sample Average Approximation.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: A Tensor of indices as base samples of shape
                `sample_shape`, typically obtained from `IndexSampler`.
                This is used for deterministic optimization. The predictions of
                the ensemble corresponding to the indices are then sampled.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.

        Raises:
            ValueError: If the shape of `base_samples` differs from `sample_shape`.
        """
        if base_samples.shape != sample_shape:
            raise ValueError("Base samples do not match sample shape.")
        # Bring the ensemble dimension to the front so the index tensor selects
        # whole ensemble members along the first dimension.
        members_first = torch.movedim(self.values, -3, 0)
        return members_first[base_samples, ...]

botorch/sampling/deterministic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ class DeterministicSampler(StochasticSampler):
1818
r"""A sampler that simply calls `posterior.rsample`, intended to be used with
1919
`DeterministicModel` & `DeterministicPosterior`.
2020
21+
[DEPRECATED] - Use `IndexSampler` in conjunction with `EnsemblePosterior`
22+
instead of `DeterministicSampler` with `DeterministicPosterior`.
23+
2124
This effectively signals that `StochasticSampler` is safe to use with
2225
deterministic models since their output is deterministic by definition.
2326
"""

botorch/sampling/get_sampler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
import torch
1111
from botorch.logging import logger
1212
from botorch.posteriors.deterministic import DeterministicPosterior
13+
from botorch.posteriors.ensemble import EnsemblePosterior
1314
from botorch.posteriors.gpytorch import GPyTorchPosterior
1415
from botorch.posteriors.posterior import Posterior
1516
from botorch.posteriors.posterior_list import PosteriorList
1617
from botorch.posteriors.torch import TorchPosterior
1718
from botorch.posteriors.transformed import TransformedPosterior
1819
from botorch.sampling.base import MCSampler
1920
from botorch.sampling.deterministic import DeterministicSampler
21+
from botorch.sampling.index_sampler import IndexSampler
2022
from botorch.sampling.list_sampler import ListSampler
2123
from botorch.sampling.normal import (
2224
IIDNormalSampler,
@@ -111,10 +113,18 @@ def _get_sampler_list(
111113
def _get_sampler_deterministic(
112114
posterior: DeterministicPosterior, sample_shape: torch.Size, **kwargs: Any
113115
) -> MCSampler:
114-
r"""Get the dummy `StochasticSampler` for the `DeterministicPosterior`."""
116+
r"""Get the dummy `DeterministicSampler` for the `DeterministicPosterior`."""
115117
return DeterministicSampler(sample_shape=sample_shape, **kwargs)
116118

117119

120+
@GetSampler.register(EnsemblePosterior)
def _get_sampler_ensemble(
    posterior: EnsemblePosterior, sample_shape: torch.Size, **kwargs: Any
) -> MCSampler:
    r"""Get the `IndexSampler` for the `EnsemblePosterior`.

    Args:
        posterior: The posterior this handler is registered for. It is not
            used when constructing the sampler.
        sample_shape: The sample shape passed on to the `IndexSampler`.
        kwargs: Additional keyword arguments forwarded to the `IndexSampler`.

    Returns:
        An `IndexSampler` constructed with the given `sample_shape`.
    """
    index_sampler = IndexSampler(sample_shape=sample_shape, **kwargs)
    return index_sampler
126+
127+
118128
@GetSampler.register(object)
119129
def _not_found_error(
120130
posterior: Posterior, sample_shape: torch.Size, **kwargs: Any

0 commit comments

Comments
 (0)