forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement PReLU in a compositional way (pytorch#91238)
The PReLU implementation was all over the place. This lead to a number of bugs like pytorch#68760. We fix it by: - Keeping the weird broadcasting logic it has as a CompositeImplicit kernel that calls into a second kernel - This second kernel is just a good-ol' pointwise kernel. - We implement the derivative for the pointwise kernel via TI as well for speed. - We implement the second derivative for the pointwise kernel and the forward AD derivatives compositionally This fixes a number of issues: - We don't perform copies any more when the inputs are not contiguous - The derivatives are now correct - We fix vmap and many other functorch-related issues. - CPU and CUDA now share the relevant broadcasting logic - The implementation is about 1/3 the length. Fixes pytorch#68760 Fixes pytorch#89895 Pull Request resolved: pytorch#91238 Approved by: https://github.com/kshitij12345, https://github.com/jbschlosser, https://github.com/albanD
- Loading branch information
1 parent
0e8565d
commit 484dd40
Showing
23 changed files
with
152 additions
and
916 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.