Short-cut t_batch_mode_transform decorator on non-tensor inputs (#991)
Summary:
This essentially makes `t_batch_mode_transform` a no-op in case the argument `X` passed to the acquisition function is not a `torch.Tensor` object. This allows using acquisition functions whose models take non-standard input types such as strings, which is the case in some applications.

Currently this just touches the decorator; in the future we should consider changing the types and signatures of the acquisition functions and models throughout to natively support this more generally.
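For illustration, here is a minimal runnable sketch (not part of this PR; `DummyAcqf` is a made-up class used only to demonstrate the decorator) of what the short-circuit enables once the change is applied: tensor inputs still get the usual t-batch treatment, while any other input is handed to the wrapped method unchanged.

import torch
from botorch.utils.transforms import t_batch_mode_transform


class DummyAcqf:
    # Toy stand-in for an acquisition function; it simply echoes its input.
    @t_batch_mode_transform(assert_output_shape=False)
    def __call__(self, X):
        return X


acqf = DummyAcqf()
print(acqf(torch.rand(4, 2)).shape)  # torch.Size([1, 4, 2]): a t-batch dim is added
print(acqf(["AAG", "CGT"]))          # ['AAG', 'CGT']: non-tensor input passes through unchanged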

cc wjmaddox

Pull Request resolved: #991

Reviewed By: dme65

Differential Revision: D32903859

Pulled By: Balandat

fbshipit-source-id: c3abc8b40db307358807fe60014e2c0e1fe49c58
Balandat authored and facebook-github-bot committed Dec 8, 2021
1 parent c6bc8f9 commit deeca64
Showing 2 changed files with 22 additions and 8 deletions.
25 changes: 17 additions & 8 deletions botorch/utils/transforms.py
@@ -12,11 +12,13 @@

import warnings
from functools import wraps
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, TypeVar

import torch
from torch import Tensor

ACQF = TypeVar("AcquisitionFunction")


def squeeze_last_dim(Y: Tensor) -> Tensor:
r"""Squeeze the last dimension of a Tensor.
@@ -172,19 +174,19 @@ def _verify_output_shape(acqf: Any, X: Tensor, output: Tensor) -> bool:
def t_batch_mode_transform(
expected_q: Optional[int] = None,
assert_output_shape: bool = True,
) -> Callable[[Callable[[Any, Tensor], Any]], Callable[[Any, Tensor], Any]]:
r"""Factory for decorators taking a t-batched `X` tensor.
) -> Callable[[Callable[[ACQF, Any], Any]], Callable[[ACQF, Any], Any]]:
r"""Factory for decorators enabling consistent t-batch behavior.
This method creates decorators for instance methods to transform an input tensor
`X` to t-batch mode (i.e. with at least 3 dimensions). This assumes the tensor
has a q-batch dimension. The decorator also checks the q-batch size if `expected_q`
is provided, and the output shape if `assert_output_shape` is `True`.
Args:
expected_q: The expected q-batch size of X. If specified, this will raise an
AssertionError if X's q-batch size does not equal expected_q.
expected_q: The expected q-batch size of `X`. If specified, this will raise an
AssertionError if `X`'s q-batch size does not equal expected_q.
assert_output_shape: If `True`, this will raise an AssertionError if the
output shape does not match either the t-batch shape of X,
output shape does not match either the t-batch shape of `X`,
or the `acqf.model.batch_shape` for acquisition functions using
batched models.
@@ -202,9 +204,16 @@ def t_batch_mode_transform(
>>> ...
"""

def decorator(method: Callable[[Any, Tensor], Any]) -> Callable[[Any, Tensor], Any]:
def decorator(
method: Callable[[ACQF, Any], Any],
) -> Callable[[ACQF, Any], Any]:
@wraps(method)
def decorated(acqf: Any, X: Tensor, *args: Any, **kwargs: Any) -> Any:
def decorated(acqf: ACQF, X: Any, *args: Any, **kwargs: Any) -> Any:

# Allow using acquisition functions for other inputs (e.g. lists of strings)
if not isinstance(X, Tensor):
return method(acqf, X, *args, **kwargs)

if X.dim() < 2:
raise ValueError(
f"{type(acqf).__name__} requires X to have at least 2 dimensions,"
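To make the control flow above easier to follow outside the diff, here is a simplified, self-contained sketch of the pattern (an assumption-laden illustration, not the actual BoTorch implementation: it omits the `expected_q` and output-shape checks):

from functools import wraps
from typing import Any, Callable

from torch import Tensor


def simple_t_batch_transform(method: Callable[..., Any]) -> Callable[..., Any]:
    """Sketch of the short-circuit: only Tensor inputs get t-batch handling."""

    @wraps(method)
    def decorated(acqf: Any, X: Any, *args: Any, **kwargs: Any) -> Any:
        if not isinstance(X, Tensor):
            # Non-tensor inputs (e.g. lists of strings) bypass all tensor logic.
            return method(acqf, X, *args, **kwargs)
        if X.dim() < 2:
            raise ValueError("X must have at least 2 dimensions.")
        # Add a t-batch dimension if X is only `q x d`.
        X = X if X.dim() > 2 else X.unsqueeze(0)
        return method(acqf, X, *args, **kwargs)

    return decorated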
5 changes: 5 additions & 0 deletions test/utils/test_transforms.py
@@ -191,6 +191,11 @@ def test_t_batch_mode_transform(self):
Xout = c.broadcast_batch_shape_method(X)
self.assertEqual(Xout.shape, c.model.batch_shape)

# test with non-tensor argument
X = ((3, 4), {"foo": True})
Xout = c.q_method(X)
self.assertEqual(X, Xout)


class TestConcatenatePendingPoints(BotorchTestCase):
def test_concatenate_pending_points(self):
