-
Notifications
You must be signed in to change notification settings - Fork 401
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make BaseTestFunction.evaluate_true accept 1d inputs (#2492)
Summary: Pull Request resolved: #2492 Context: Currently, every test function except for `AugmentedBranin` has an `evaluate_true` method that works with 1d inputs. It is actually surprising that so many work, since `BaseTestFunction` is currently written so that `BaseTestFunction.forward` casts inputs to 2d before passing them to `BaseTestFunction.evaluate_true`. So currently, it's not clear if we should expect `evaluate_true` to work with 1d inputs, but nonetheless this is happening downstream. This PR: * Requires `evaluate_true` to work with 1d inputs * Removes the logic that expands the dimension of unbatched tensors in `forward` before passign to `evaluate_true` and then removes the batch dimension in favor of leaving unbatched tensors unbatched everywhere * Changes `AugmentedBranin` to work with 1d inputs * Fixes a couple type errors * Expands docstrings Differential Revision: D61916387
- Loading branch information
1 parent
017a124
commit 22040c4
Showing
3 changed files
with
34 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters