-
Notifications
You must be signed in to change notification settings - Fork 400
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow broadcasting across dimensions in eval mode; always require X t…
…o be at least 2d Summary: Context: A discussion on allowable shapes for transforms concluded: * We should not allow for broadcasting across the -1 dimension, so the first check in _check_shape should always happen. * The shapes always need to be broadcastable, so the torch.broadcast_shapes check in _check_shape should always happen. * We want to allow for broadcasting across the batch dimension in eval model, so the check that X has dimension of at least len(batch_shape) + 2 should only happen in training mode. * For clarity, we should disallow 1d X, even if broadcastable. BoTorch tends to be strict about requiring explicit dimensions, e.g. GPyTorchModel._validate_tensor_args, and that's a good thing because confusion about tensor dimensions causes a lot of pain. This diff: * Only checks that X has number of dimensions equal to 2 + the number of batch dimensions in training mode. * Disallows <2d X. Differential Revision: D62404492
- Loading branch information
1 parent
33e11f4
commit 1618797
Showing
2 changed files
with
39 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters