From 787cc7f13bd2c4a892d1c080172077c517d180a2 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Mon, 6 Nov 2023 13:27:25 -0800 Subject: [PATCH] Replace `squeeze` with `sum` in `_inf_max_helper` (#2083) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2083 This commit replaces `squeeze` with `sum` over length-1 dimensions in `_inf_max_helper` for PyTorch 1.13 compatibility. Reviewed By: Balandat Differential Revision: D51030343 fbshipit-source-id: 87ce5c5d71812a70553688d55c82aac56ddebbec --- botorch/utils/safe_math.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/botorch/utils/safe_math.py b/botorch/utils/safe_math.py index 7c4c30c984..4ec2892e90 100644 --- a/botorch/utils/safe_math.py +++ b/botorch/utils/safe_math.py @@ -180,7 +180,11 @@ def _inf_max_helper( y_inf.sum(dim=dim, keepdim=True), M_no_inf + max_fun(y_no_inf, dim=dim, keepdim=True), ) - return res if keepdim else res.squeeze(dim) + # NOTE: Using `sum` instead of `squeeze` because PyTorch < 2.0 does not support + # tuple `dim` arguments. `sum` and `squeeze` are equivalent here because the + # `dim` dimensions have length one after the reductions in the previous lines. + # TODO: Replace `sum` with `squeeze` once PyTorch >= 2.0 is required. + return res if keepdim else res.sum(dim=dim) def _any(x: Tensor, dim: Union[int, Tuple[int, ...]], keepdim: bool = False) -> Tensor: