From deeca6464957eb8d9731c0d51b856f820f903f86 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Wed, 8 Dec 2021 13:04:31 -0800 Subject: [PATCH] Short-cut t_batch_mode_transform decorator on non-tensor inputs (#991) Summary: This essentially makes `t_batch_mode_transform` a nullop in case the argument `X` passed to the acquisition function is not a `torch.Tensor` object. This allows using acquisition functions that use models with non-standard input types such as strings, which is the case in some applications. Currently this just touches the decorator; in the future we should consider changing the types and signatures of the acquisition functions and models throughout to natively support this more generally. cc wjmaddox Pull Request resolved: https://github.com/pytorch/botorch/pull/991 Reviewed By: dme65 Differential Revision: D32903859 Pulled By: Balandat fbshipit-source-id: c3abc8b40db307358807fe60014e2c0e1fe49c58 --- botorch/utils/transforms.py | 25 +++++++++++++++++-------- test/utils/test_transforms.py | 5 +++++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py index a1d68e001a..552e338b7f 100644 --- a/botorch/utils/transforms.py +++ b/botorch/utils/transforms.py @@ -12,11 +12,13 @@ import warnings from functools import wraps -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, TypeVar import torch from torch import Tensor +ACQF = TypeVar("AcquisitionFunction") + def squeeze_last_dim(Y: Tensor) -> Tensor: r"""Squeeze the last dimension of a Tensor. @@ -172,8 +174,8 @@ def _verify_output_shape(acqf: Any, X: Tensor, output: Tensor) -> bool: def t_batch_mode_transform( expected_q: Optional[int] = None, assert_output_shape: bool = True, -) -> Callable[[Callable[[Any, Tensor], Any]], Callable[[Any, Tensor], Any]]: - r"""Factory for decorators taking a t-batched `X` tensor. 
+) -> Callable[[Callable[[ACQF, Any], Any]], Callable[[ACQF, Any], Any]]: + r"""Factory for decorators enabling consistent t-batch behavior. This method creates decorators for instance methods to transform an input tensor `X` to t-batch mode (i.e. with at least 3 dimensions). This assumes the tensor @@ -181,10 +183,10 @@ def t_batch_mode_transform( is provided, and the output shape if `assert_output_shape` is `True`. Args: - expected_q: The expected q-batch size of X. If specified, this will raise an - AssertionError if X's q-batch size does not equal expected_q. + expected_q: The expected q-batch size of `X`. If specified, this will raise an + AssertionError if `X`'s q-batch size does not equal expected_q. assert_output_shape: If `True`, this will raise an AssertionError if the - output shape does not match either the t-batch shape of X, + output shape does not match either the t-batch shape of `X`, or the `acqf.model.batch_shape` for acquisition functions using batched models. @@ -202,9 +204,16 @@ def t_batch_mode_transform( >>> ... """ - def decorator(method: Callable[[Any, Tensor], Any]) -> Callable[[Any, Tensor], Any]: + def decorator( + method: Callable[[ACQF, Any], Any], + ) -> Callable[[ACQF, Any], Any]: @wraps(method) - def decorated(acqf: Any, X: Tensor, *args: Any, **kwargs: Any) -> Any: + def decorated(acqf: ACQF, X: Any, *args: Any, **kwargs: Any) -> Any: + + # Allow using acquisition functions for other inputs (e.g. 
lists of strings) + if not isinstance(X, Tensor): + return method(acqf, X, *args, **kwargs) + if X.dim() < 2: raise ValueError( f"{type(acqf).__name__} requires X to have at least 2 dimensions," diff --git a/test/utils/test_transforms.py b/test/utils/test_transforms.py index 892e132ec6..9da08925eb 100644 --- a/test/utils/test_transforms.py +++ b/test/utils/test_transforms.py @@ -191,6 +191,11 @@ def test_t_batch_mode_transform(self): Xout = c.broadcast_batch_shape_method(X) self.assertEqual(Xout.shape, c.model.batch_shape) + # test with non-tensor argument + X = ((3, 4), {"foo": True}) + Xout = c.q_method(X) + self.assertEqual(X, Xout) + class TestConcatenatePendingPoints(BotorchTestCase): def test_concatenate_pending_points(self):