From b666bc811c87d04c1dab8f82860c295b03d2f790 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 15 Nov 2021 09:42:50 +0000
Subject: [PATCH 1/3] Added adaptive_max_poolNd fw/bw batch rules

Description:
- Added adaptive_max_poolNd fw/bw batch rules
- Updated tests

Related to #240

Notes: I created two additional macros to handle adaptive_max_pool2d and
adaptive_max_pool3d_backward. I am not sure whether we could write one
generic rule covering both max_pool2d_with_indices_backward_batch_rule and
adaptive_max_pool3d_backward, since max_pool2d_with_indices_backward takes
several arguments in the middle, between gradOutput/input and indices.
---
 functorch/csrc/BatchRulesDecompositions.cpp |  1 +
 functorch/csrc/BatchRulesPooling.cpp        | 75 +++++++++++++++++++++
 test/test_ops.py                            |  3 -
 test/test_vmap.py                           |  3 -
 4 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/csrc/BatchRulesDecompositions.cpp
index 475748ddb..04a183a4b 100644
--- a/functorch/csrc/BatchRulesDecompositions.cpp
+++ b/functorch/csrc/BatchRulesDecompositions.cpp
@@ -19,6 +19,7 @@ namespace at { namespace functorch {
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(absolute);
   OP_DECOMPOSE(avg_pool1d);
+  OP_DECOMPOSE(adaptive_max_pool1d);
   OP_DECOMPOSE(adaptive_avg_pool1d);
   OP_DECOMPOSE(adaptive_avg_pool2d);
   OP_DECOMPOSE(adaptive_avg_pool3d);
diff --git a/functorch/csrc/BatchRulesPooling.cpp b/functorch/csrc/BatchRulesPooling.cpp
index df144bf70..f0b418faf 100644
--- a/functorch/csrc/BatchRulesPooling.cpp
+++ b/functorch/csrc/BatchRulesPooling.cpp
@@ -82,6 +82,76 @@ max_pool2d_with_indices_batch_rule(
       reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
 }
 
+template <typename F, F Func, typename... T>
+struct ExistingBdimMultiOutputBatchRuleHelper;
+
+template <typename F, F Func, typename A, typename... T>
+struct ExistingBdimMultiOutputBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
+  static std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> apply(
+      const Tensor& self,
+      optional<int64_t> self_bdim,
+      T... extra_args) {
+    auto self_ = reshape_dim_into(*self_bdim, 0, self);
+    auto out = Func(self_, std::forward<T>(extra_args)...);
+    auto odim = self.sizes()[*self_bdim];
+    return std::make_tuple(
+      reshape_dim_outof(0, odim, std::get<0>(out)),
+      0,
+      reshape_dim_outof(0, odim, std::get<1>(out)),
+      0
+    );
+  }
+};
+
+#define EXISTING_BDIM_MULTIOUT_BATCH_RULE(fn) SINGLE_ARG(\
+    ExistingBdimMultiOutputBatchRuleHelper<\
+      decltype(&fn),\
+      &fn,\
+      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
+
+#define EXISTING_BDIM_MULTIOUT(op) \
+  VMAP_SUPPORT(#op, EXISTING_BDIM_MULTIOUT_BATCH_RULE(ATEN_FN(op)));
+
+// We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
+// kernel rightfully assumes that indices is contiguous.
+template <int N, typename F, F func>
+std::tuple<Tensor,optional<int64_t>> adaptive_max_poolNd_backward_batch_rule(
+    const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& indices, optional<int64_t> indices_bdim) {
+  TORCH_INTERNAL_ASSERT(input_bdim.has_value() ^ !indices_bdim.has_value());
+  const auto bdim_size = get_bdim_size2(gradOutput, gradOutput_bdim, input, input_bdim);
+  const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
+  bool is_no_batch_dim_case = input_logical_rank == N + 1;
+
+  const auto gradOutput_ = reshape_bdim_into_front(gradOutput, gradOutput_bdim, bdim_size, is_no_batch_dim_case);
+  const auto input_ = reshape_bdim_into_front(input, input_bdim, bdim_size, is_no_batch_dim_case);
+  const auto indices_ = reshape_bdim_into_front(indices, indices_bdim, bdim_size, is_no_batch_dim_case);
+
+  const auto result = func(gradOutput_, input_, indices_.contiguous());
+
+  if (is_no_batch_dim_case) {
+    return std::make_tuple(std::move(result), 0);
+  } else {
+    return std::make_tuple(reshape_dim_outof(0, bdim_size, result), 0);
+  }
+}
+
+template <int N, typename F, F func>
+struct AdaptiveMaxPoolNdBatchRuleHelper {
+  static std::tuple<Tensor,optional<int64_t>> apply(
+      const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
+      const Tensor& input, optional<int64_t> input_bdim,
+      const Tensor& indices, optional<int64_t> indices_bdim) {
+    return adaptive_max_poolNd_backward_batch_rule<N, F, func>(
+        gradOutput, gradOutput_bdim, input, input_bdim, indices, indices_bdim);
+  }
+};
+
+#define ADAPTIVE_MAX_POOL_ND_BATCH_RULE(fn, n) \
+  AdaptiveMaxPoolNdBatchRuleHelper<n, decltype(&at::fn), &at::fn>::apply
+
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(_adaptive_avg_pool2d);
   EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward);
@@ -91,6 +161,11 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(avg_pool3d);
   EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
   EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward);
+  EXISTING_BDIM_MULTIOUT(adaptive_max_pool2d);
+  EXISTING_BDIM_MULTIOUT(adaptive_max_pool3d);
+  VMAP_SUPPORT("adaptive_max_pool2d_backward", ADAPTIVE_MAX_POOL_ND_BATCH_RULE(adaptive_max_pool2d_backward, 2));
+  VMAP_SUPPORT("adaptive_max_pool3d_backward", ADAPTIVE_MAX_POOL_ND_BATCH_RULE(adaptive_max_pool3d_backward, 3));
+
   VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
   VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule);
 }
diff --git a/test/test_ops.py b/test/test_ops.py
index ffb08b3cc..9008c031e 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -492,9 +492,6 @@ def test_vmapvjp(self, device, dtype, op):
         xfail('diagonal_scatter'),
         xfail('double', 'channels_last'),
         xfail('linalg.cross'),
-        xfail('nn.functional.adaptive_max_pool1d'),
-        xfail('nn.functional.adaptive_max_pool2d'),
-        xfail('nn.functional.adaptive_max_pool3d'),
         xfail('nn.functional.conv1d'),
         xfail('nn.functional.gaussian_nll_loss'),
         xfail('nn.functional.hardsigmoid'),
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 837bfb894..2b08818d8 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -3201,9 +3201,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('slice_scatter'),
         xfail('unique_consecutive'),
         xfail('unique'),
-        xfail('nn.functional.adaptive_max_pool1d'),
-        xfail('nn.functional.adaptive_max_pool2d'),
-        xfail('nn.functional.adaptive_max_pool3d'),
         xfail('nn.functional.conv1d'),
         xfail('nn.functional.cosine_embedding_loss'),
         # xfail('nn.functional.cross_entropy'),

From e24fc35bb8762a02b041facdb562362964010b5d Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 17 Nov 2021 10:50:51 +0000
Subject: [PATCH 2/3] Replaced EXISTING_BDIM_MULTIOUT by
 EXISTING_BDIM_ALL_BOXED
---
 functorch/csrc/BatchRulesPooling.cpp | 34 ++--------------------------
 1 file changed, 2 insertions(+), 32 deletions(-)

diff --git a/functorch/csrc/BatchRulesPooling.cpp b/functorch/csrc/BatchRulesPooling.cpp
index f0b418faf..e4d7bd5c9 100644
--- a/functorch/csrc/BatchRulesPooling.cpp
+++ b/functorch/csrc/BatchRulesPooling.cpp
@@ -82,36 +82,6 @@ max_pool2d_with_indices_batch_rule(
       reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
 }
 
-template <typename F, F Func, typename... T>
-struct ExistingBdimMultiOutputBatchRuleHelper;
-
-template <typename F, F Func, typename A, typename... T>
-struct ExistingBdimMultiOutputBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
-  static std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> apply(
-      const Tensor& self,
-      optional<int64_t> self_bdim,
-      T... extra_args) {
-    auto self_ = reshape_dim_into(*self_bdim, 0, self);
-    auto out = Func(self_, std::forward<T>(extra_args)...);
-    auto odim = self.sizes()[*self_bdim];
-    return std::make_tuple(
-      reshape_dim_outof(0, odim, std::get<0>(out)),
-      0,
-      reshape_dim_outof(0, odim, std::get<1>(out)),
-      0
-    );
-  }
-};
-
-#define EXISTING_BDIM_MULTIOUT_BATCH_RULE(fn) SINGLE_ARG(\
-    ExistingBdimMultiOutputBatchRuleHelper<\
-      decltype(&fn),\
-      &fn,\
-      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
-
-#define EXISTING_BDIM_MULTIOUT(op) \
-  VMAP_SUPPORT(#op, EXISTING_BDIM_MULTIOUT_BATCH_RULE(ATEN_FN(op)));
-
 // We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
 // kernel rightfully assumes that indices is contiguous.
 template <int N, typename F, F func>
@@ -161,8 +131,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(avg_pool3d);
   EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
   EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward);
-  EXISTING_BDIM_MULTIOUT(adaptive_max_pool2d);
-  EXISTING_BDIM_MULTIOUT(adaptive_max_pool3d);
+  EXISTING_BDIM_ALL_BOXED(adaptive_max_pool2d);
+  EXISTING_BDIM_ALL_BOXED(adaptive_max_pool3d);
   VMAP_SUPPORT("adaptive_max_pool2d_backward", ADAPTIVE_MAX_POOL_ND_BATCH_RULE(adaptive_max_pool2d_backward, 2));
   VMAP_SUPPORT("adaptive_max_pool3d_backward", ADAPTIVE_MAX_POOL_ND_BATCH_RULE(adaptive_max_pool3d_backward, 3));
 

From f5777a66e1f0e53cc99bd0bffc7bbc7bc33aad10 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 23 Nov 2021 12:41:38 +0000
Subject: [PATCH 3/3] Removed specific implementations with
 indices.contiguous() for
 - max_pool2d_with_indices_backward
 - adaptive_max_pool2d_backward
 - adaptive_max_pool3d_backward
 and added ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1 to handle that
---
 functorch/csrc/BatchRulesHelper.h    | 22 ++++++--
 functorch/csrc/BatchRulesPooling.cpp | 78 ++--------------------------
 2 files changed, 22 insertions(+), 78 deletions(-)

diff --git a/functorch/csrc/BatchRulesHelper.h b/functorch/csrc/BatchRulesHelper.h
index c19002740..8ef838170 100644
--- a/functorch/csrc/BatchRulesHelper.h
+++ b/functorch/csrc/BatchRulesHelper.h
@@ -265,7 +265,7 @@ inline void boxed_existing_bdim_all_batch_rule(
 #define EXISTING_BDIM_ALL_BOXED(op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
 
-template <int64_t feature_rank>
+template <int64_t feature_rank, int64_t contig_tensor_index=-1>
 inline void boxed_all_tensors_have_optional_bdim(
     const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   const auto& schema = op.schema();
@@ -302,11 +302,19 @@ inline void boxed_all_tensors_have_optional_bdim(
     }
     if (*is_no_batch_dim_case) {
       TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
-      (*stack)[args_begin + tensor_pos[tensor_idx]] = moveBatchDimToFront(value_, bdim);
+      value_ = moveBatchDimToFront(value_, bdim);
+      if (tensor_idx == contig_tensor_index) {
+        value_ = value_.contiguous();
+      }
+      (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
       continue;
     }
     TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
-    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
+    value_ = reshape_dim_into(*bdim, 0, value_);
+    if (tensor_idx == contig_tensor_index) {
+      value_ = value_.contiguous();
+    }
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
   }
 
   op.callBoxed(stack);
@@ -330,6 +338,14 @@ inline void boxed_all_tensors_have_optional_bdim(
 #define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());
 
+#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
+  m.impl(#op, \
+     torch::CppFunction::makeFromBoxedFunction<\
+       boxed_all_tensors_have_optional_bdim<\
+         feature_rank, \
+         contig_tensor_index>\
+       >());
+
 template <typename F, F Func, typename... T>
 struct ExistingBdimBatchRuleHelper;
diff --git a/functorch/csrc/BatchRulesPooling.cpp b/functorch/csrc/BatchRulesPooling.cpp
index e4d7bd5c9..8a2bb2d7d 100644
--- a/functorch/csrc/BatchRulesPooling.cpp
+++ b/functorch/csrc/BatchRulesPooling.cpp
@@ -26,38 +26,6 @@ static Tensor reshape_bdim_into_front(
   return reshape_dim_into(*bdim, 0, value_);
 }
 
-// We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
-// kernel rightfully assumes that indices is contiguous.
-std::tuple<Tensor,optional<int64_t>> max_pool2d_with_indices_backward_batch_rule(
-    const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
-    const Tensor& input, optional<int64_t> input_bdim,
-    IntArrayRef kernel_size,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    bool ceil_mode,
-    const Tensor& indices, optional<int64_t> indices_bdim) {
-  TORCH_INTERNAL_ASSERT(input_bdim.has_value() ^ !indices_bdim.has_value());
-  const auto bdim_size = get_bdim_size2(gradOutput, gradOutput_bdim, input, input_bdim);
-  const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
-  bool chw_case = input_logical_rank == 3;
-
-  const auto gradOutput_ = reshape_bdim_into_front(gradOutput, gradOutput_bdim, bdim_size, chw_case);
-  const auto input_ = reshape_bdim_into_front(input, input_bdim, bdim_size, chw_case);
-  const auto indices_ = reshape_bdim_into_front(indices, indices_bdim, bdim_size, chw_case);
-
-  const auto result = at::max_pool2d_with_indices_backward(
-      gradOutput_, input_, kernel_size, stride, padding, dilation, ceil_mode,
-      // max_pool2d_with_indices rightfully assumes that indices is contiguous
-      indices_.contiguous());
-
-  if (chw_case) {
-    return std::make_tuple(std::move(result), 0);
-  } else {
-    return std::make_tuple(reshape_dim_outof(0, bdim_size, result), 0);
-  }
-}
-
 std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
 max_pool2d_with_indices_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
@@ -82,46 +50,6 @@ max_pool2d_with_indices_batch_rule(
       reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
 }
 
-// We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
-// kernel rightfully assumes that indices is contiguous.
-template <int N, typename F, F func>
-std::tuple<Tensor,optional<int64_t>> adaptive_max_poolNd_backward_batch_rule(
-    const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
-    const Tensor& input, optional<int64_t> input_bdim,
-    const Tensor& indices, optional<int64_t> indices_bdim) {
-  TORCH_INTERNAL_ASSERT(input_bdim.has_value() ^ !indices_bdim.has_value());
-  const auto bdim_size = get_bdim_size2(gradOutput, gradOutput_bdim, input, input_bdim);
-  const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
-  bool is_no_batch_dim_case = input_logical_rank == N + 1;
-
-  const auto gradOutput_ = reshape_bdim_into_front(gradOutput, gradOutput_bdim, bdim_size, is_no_batch_dim_case);
-  const auto input_ = reshape_bdim_into_front(input, input_bdim, bdim_size, is_no_batch_dim_case);
-  const auto indices_ = reshape_bdim_into_front(indices, indices_bdim, bdim_size, is_no_batch_dim_case);
-
-  const auto result = func(gradOutput_, input_, indices_.contiguous());
-
-  if (is_no_batch_dim_case) {
-    return std::make_tuple(std::move(result), 0);
-  } else {
-    return std::make_tuple(reshape_dim_outof(0, bdim_size, result), 0);
-  }
-}
-
-template <int N, typename F, F func>
-struct AdaptiveMaxPoolNdBatchRuleHelper {
-  static std::tuple<Tensor,optional<int64_t>> apply(
-      const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
-      const Tensor& input, optional<int64_t> input_bdim,
-      const Tensor& indices, optional<int64_t> indices_bdim) {
-    return adaptive_max_poolNd_backward_batch_rule<N, F, func>(
-        gradOutput, gradOutput_bdim, input, input_bdim, indices, indices_bdim);
-  }
-};
-
-#define ADAPTIVE_MAX_POOL_ND_BATCH_RULE(fn, n) \
-  AdaptiveMaxPoolNdBatchRuleHelper<n, decltype(&at::fn), &at::fn>::apply
-
-
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(_adaptive_avg_pool2d);
   EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward);
@@ -133,11 +61,11 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward);
   EXISTING_BDIM_ALL_BOXED(adaptive_max_pool2d);
   EXISTING_BDIM_ALL_BOXED(adaptive_max_pool3d);
-  VMAP_SUPPORT("adaptive_max_pool2d_backward", ADAPTIVE_MAX_POOL_ND_BATCH_RULE(adaptive_max_pool2d_backward, 2));
-  VMAP_SUPPORT("adaptive_max_pool3d_backward", ADAPTIVE_MAX_POOL_ND_BATCH_RULE(adaptive_max_pool3d_backward, 3));
+  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2);
+  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, adaptive_max_pool3d_backward, 2);
 
   VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
-  VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule);
+  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, max_pool2d_with_indices_backward, 2);
 }
 
 }}
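
Note (reviewer addition, not part of the patches): below is a minimal Python
sketch of the behavior these batch rules enable, using the functorch API of
this period (functorch.vmap / functorch.grad). The shapes, the output_size,
and the pooled_sum helper are illustrative assumptions, not code from the
patches.

    import torch
    import torch.nn.functional as F
    from functorch import grad, vmap

    # A batch of 8 samples, each treated by vmap as an unbatched (C, H, W) image.
    x = torch.randn(8, 3, 16, 16)

    # Forward batch rule: vmap applies adaptive_max_pool2d per sample.
    out = vmap(lambda t: F.adaptive_max_pool2d(t, output_size=(4, 4)))(x)
    assert out.shape == (8, 3, 4, 4)

    def pooled_sum(t):
        # Scalar-valued loss so that grad() can differentiate it.
        return F.adaptive_max_pool2d(t, output_size=(4, 4)).sum()

    # Backward batch rule: grad under vmap dispatches to
    # adaptive_max_pool2d_backward; the boxed rule registered above makes
    # the indices tensor contiguous before the kernel sees it.
    per_sample_grads = vmap(grad(pooled_sum))(x)
    assert per_sample_grads.shape == x.shape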