Added adaptive_max_poolNd batch rule #263
Conversation
@vfdev-5 Could you use one of the boxed fallbacks to handle this? Or modify one of the boxed fallbacks?
@Chillee which boxed fallbacks are you thinking of?
Looks similar at first glance to
@Chillee I updated the PR according to your suggestion. A single CUDA CI job is failing, maybe due to a flaky test ...
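For readers new to these fallbacks: the pattern they implement is to fold the vmap batch dim into an existing batched dim, call the original kernel, then split the batch dim back out of the results. A minimal hand-written sketch of that pattern for adaptive_max_pool2d, assuming only libtorch (functorch's boxed helpers do this generically over the op's argument stack; the function name here is illustrative):

```cpp
#include <torch/torch.h>
#include <tuple>

// Hand-written version of the pattern (illustrative only; the boxed
// fallbacks in functorch implement this generically for any op).
std::tuple<at::Tensor, at::Tensor> adaptive_max_pool2d_batched(
    const at::Tensor& self,  // [B, N, C, H, W], vmap batch dim B in front
    at::IntArrayRef output_size) {
  const int64_t bdim_size = self.size(0);
  // Fold B into the example dim so the kernel sees an ordinary 4-D batch.
  auto flat = self.flatten(/*start_dim=*/0, /*end_dim=*/1);  // [B*N, C, H, W]
  auto [values, indices] = at::adaptive_max_pool2d(flat, output_size);
  // Split B back out of dim 0 of both outputs.
  auto unflatten0 = [&](const at::Tensor& t) {
    std::vector<int64_t> sizes = {bdim_size, t.size(0) / bdim_size};
    auto rest = t.sizes().slice(1);
    sizes.insert(sizes.end(), rest.begin(), rest.end());
    return t.reshape(sizes);
  };
  return std::make_tuple(unflatten0(values), unflatten0(indices));
}
```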
functorch/csrc/BatchRulesPooling.cpp
Outdated
@@ -82,6 +82,46 @@ max_pool2d_with_indices_batch_rule(
    reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
}

// We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
Perhaps it's better to just augment the boxed rule with the option to make some args contiguous?
@Chillee good idea! I implemented it with one tensor that can be made contiguous (we can make it more generic later if needed). What do you think?
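A rough sketch of what such an augmented boxed rule could look like (hypothetical names and plumbing; the version that actually landed is the ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1 macro mentioned in the description):

```cpp
#include <torch/torch.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical sketch: a boxed kernel templated on the 0-based index of one
// tensor argument to materialize as contiguous before redispatching, e.g.
// `indices` in the CUDA adaptive_max_poolNd backward kernels, which read it
// with raw pointer arithmetic and so can't accept a strided batch view.
template <size_t contig_idx>
void boxed_rule_with_contig_arg(const c10::OperatorHandle& op,
                                torch::jit::Stack* stack) {
  const auto num_args = op.schema().arguments().size();
  const auto args_begin = stack->size() - num_args;
  auto& ival = (*stack)[args_begin + contig_idx];
  if (ival.isTensor()) {
    ival = ival.toTensor().contiguous();
  }
  // ... handle the optional batch dims as the plain boxed fallback does ...
  op.callBoxed(stack);
}
```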
Looks good to me!
Force-pushed from 952a75b to f5777a6
* Added adaptive_max_poolNd fw/bw batch rules
* Replaced EXISTING_BDIM_MULTIOUT by EXISTING_BDIM_ALL_BOXED
* Removed specific implementations with indices.contiguous() for max_pool2d_with_indices_backward, adaptive_max_pool2d_backward and adaptive_max_pool3d_backward, and added ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1 to handle that
Description:
- Added adaptive_max_poolNd fw/bw batch rules
- Updated tests

Related to #240

Notes:
I created two additional macros to handle adaptive_max_pool2d and adaptive_max_pool3d_backward.
I'm not sure we could make a single generic rule covering both max_pool2d_with_indices_backward_batch_rule and adaptive_max_pool3d_backward, since max_pool2d_with_indices_backward takes several non-tensor args in the middle, between gradOutput, input and indices.
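To make that last note concrete: adaptive_max_pool2d_backward(grad_output, self, indices) takes only tensor args, while max_pool2d_with_indices_backward interleaves kernel_size, stride, padding, dilation and ceil_mode between self and indices, so a fixed-arity template doesn't fit both. A boxed rule sidesteps this because it walks the argument stack and touches only the IValues that are tensors. A minimal sketch, independent of functorch's actual helpers:

```cpp
#include <torch/torch.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Sketch: a boxed rule is argument-position-agnostic. It flattens the batch
// dim on every tensor argument and leaves non-tensor args (ints, bools, ...)
// untouched, so the same code serves both backward schemas above.
void flatten_bdim_on_all_tensor_args(const c10::OperatorHandle& op,
                                     torch::jit::Stack* stack) {
  const auto num_args = op.schema().arguments().size();
  const auto args_begin = stack->size() - num_args;
  for (size_t i = 0; i < num_args; ++i) {
    auto& ival = (*stack)[args_begin + i];
    if (!ival.isTensor()) continue;        // scalars pass through unchanged
    ival = ival.toTensor().flatten(0, 1);  // [B, N, ...] -> [B*N, ...]
  }
  op.callBoxed(stack);
  // (outputs would be unflattened here, mirroring the forward sketch above)
}
```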