Skip to content

Commit 787cc7f

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Replace squeeze with sum in _inf_max_helper (#2083)
Summary: Pull Request resolved: #2083 This commit replaces `squeeze` with `sum` over length-1 dimensions in `_inf_max_helper` for PyTorch 1.13 compatibility. Reviewed By: Balandat Differential Revision: D51030343 fbshipit-source-id: 87ce5c5d71812a70553688d55c82aac56ddebbec
1 parent 6bde5d4 commit 787cc7f

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

botorch/utils/safe_math.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,11 @@ def _inf_max_helper(
180180
y_inf.sum(dim=dim, keepdim=True),
181181
M_no_inf + max_fun(y_no_inf, dim=dim, keepdim=True),
182182
)
183-
return res if keepdim else res.squeeze(dim)
183+
# NOTE: Using `sum` instead of `squeeze` because PyTorch < 2.0 does not support
184+
# tuple `dim` arguments. `sum` and `squeeze` are equivalent here because the
185+
# `dim` dimensions have length one after the reductions in the previous lines.
186+
# TODO: Replace `sum` with `squeeze` once PyTorch >= 2.0 is required.
187+
return res if keepdim else res.sum(dim=dim)
184188

185189

186190
def _any(x: Tensor, dim: Union[int, Tuple[int, ...]], keepdim: bool = False) -> Tensor:

0 commit comments

Comments
 (0)