-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Optimize CDF Calculation and Convert NumPy Arrays to Tensors in Bench…
…mark (#399) Summary: ### PR Description This PR addresses the first step in making AEPsych's functions consistently return PyTorch tensors and expect tensors as input, improving compatibility with GPUs and reducing redundant conversions between NumPy arrays and PyTorch tensors(partially solving #365). #### Key changes include: 1. **Conversion of `np.arrays` to tensors** in the following files: - **`aepsych/models/base.py`**: - Refactored the `p_below_threshold` method to operate fully with PyTorch tensors. - Replaced `norm.cdf()` with `torch.distributions.Normal(0, 1).cdf()` for better GPU compatibility. - **`aepsych/benchmark/problem.py`**: - Significant changes made to ensure consistent use of tensors across the pipeline. - The result of `f_threshold()` now directly returns a PyTorch tensor, ensuring consistency. - Additionally, used `detach().cpu().numpy()` in places where the `super().evaluate()` method returns float values, ensuring compatibility. 2. **Updates in `aepsych/tests/test_benchmark.py`**: - Migrated all operations from NumPy to PyTorch. - This includes calculations for Brier score and misclassification error, now utilizing `torch.mean()`, `torch.square()`, `torch.isclose()`, and `torch.all()` to fully align with tensor operations. #### Stability: All test cases have passed successfully in the workflow. Pull Request resolved: #399 Reviewed By: crasanders Differential Revision: D64245698 Pulled By: JasonKChow fbshipit-source-id: 3ed3d7b627f488ec61da5b9013a46cafc8b83556
- Loading branch information
1 parent
b278dd1
commit 45d8e2d
Showing
3 changed files
with
101 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters