diff --git a/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/csrc/BatchRulesDecompositions.cpp
index 475748ddb..04a183a4b 100644
--- a/functorch/csrc/BatchRulesDecompositions.cpp
+++ b/functorch/csrc/BatchRulesDecompositions.cpp
@@ -19,6 +19,7 @@ namespace at { namespace functorch {
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(absolute);
   OP_DECOMPOSE(avg_pool1d);
+  OP_DECOMPOSE(adaptive_max_pool1d);
   OP_DECOMPOSE(adaptive_avg_pool1d);
   OP_DECOMPOSE(adaptive_avg_pool2d);
   OP_DECOMPOSE(adaptive_avg_pool3d);
diff --git a/functorch/csrc/BatchRulesHelper.h b/functorch/csrc/BatchRulesHelper.h
index c19002740..8ef838170 100644
--- a/functorch/csrc/BatchRulesHelper.h
+++ b/functorch/csrc/BatchRulesHelper.h
@@ -265,7 +265,7 @@ inline void boxed_existing_bdim_all_batch_rule(
 #define EXISTING_BDIM_ALL_BOXED(op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());

-template <int64_t feature_rank>
+template <int64_t feature_rank, int64_t contig_tensor_index=-1>
 inline void boxed_all_tensors_have_optional_bdim(
     const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   const auto& schema = op.schema();
@@ -302,11 +302,19 @@ inline void boxed_all_tensors_have_optional_bdim(
     }
     if (*is_no_batch_dim_case) {
       TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
-      (*stack)[args_begin + tensor_pos[tensor_idx]] = moveBatchDimToFront(value_, bdim);
+      value_ = moveBatchDimToFront(value_, bdim);
+      if (tensor_idx == contig_tensor_index) {
+        value_ = value_.contiguous();
+      }
+      (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
       continue;
     }
     TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
-    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
+    value_ = reshape_dim_into(*bdim, 0, value_);
+    if (tensor_idx == contig_tensor_index) {
+      value_ = value_.contiguous();
+    }
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
   }

   op.callBoxed(stack);
@@ -330,6 +338,14 @@ inline void boxed_all_tensors_have_optional_bdim(
 #define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());

+#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
+  m.impl(#op, \
+      torch::CppFunction::makeFromBoxedFunction<\
+          boxed_all_tensors_have_optional_bdim<\
+              feature_rank, \
+              contig_tensor_index>\
+          >());
+
 template <typename A, A a, typename C>
 struct ExistingBdimBatchRuleHelper;

diff --git a/functorch/csrc/BatchRulesPooling.cpp b/functorch/csrc/BatchRulesPooling.cpp
index df144bf70..8a2bb2d7d 100644
--- a/functorch/csrc/BatchRulesPooling.cpp
+++ b/functorch/csrc/BatchRulesPooling.cpp
@@ -26,38 +26,6 @@ static Tensor reshape_bdim_into_front(
   return reshape_dim_into(*bdim, 0, value_);
 }

-// We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
-// kernel rightfully assumes that indices is contiguous.
-std::tuple<Tensor,optional<int64_t>> max_pool2d_with_indices_backward_batch_rule(
-    const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
-    const Tensor& input, optional<int64_t> input_bdim,
-    IntArrayRef kernel_size,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    bool ceil_mode,
-    const Tensor& indices, optional<int64_t> indices_bdim) {
-  TORCH_INTERNAL_ASSERT(input_bdim.has_value() ^ !indices_bdim.has_value());
-  const auto bdim_size = get_bdim_size2(gradOutput, gradOutput_bdim, input, input_bdim);
-  const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
-  bool chw_case = input_logical_rank == 3;
-
-  const auto gradOutput_ = reshape_bdim_into_front(gradOutput, gradOutput_bdim, bdim_size, chw_case);
-  const auto input_ = reshape_bdim_into_front(input, input_bdim, bdim_size, chw_case);
-  const auto indices_ = reshape_bdim_into_front(indices, indices_bdim, bdim_size, chw_case);
-
-  const auto result = at::max_pool2d_with_indices_backward(
-      gradOutput_, input_, kernel_size, stride, padding, dilation, ceil_mode,
-      // max_pool2d_with_indices rightfully assumes that indices is contiguous
-      indices_.contiguous());
-
-  if (chw_case) {
-    return std::make_tuple(std::move(result), 0);
-  } else {
-    return std::make_tuple(reshape_dim_outof(0, bdim_size, result), 0);
-  }
-}
-
 std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
 max_pool2d_with_indices_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
@@ -91,8 +59,13 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(avg_pool3d);
   EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
   EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward);
+  EXISTING_BDIM_ALL_BOXED(adaptive_max_pool2d);
+  EXISTING_BDIM_ALL_BOXED(adaptive_max_pool3d);
+  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2);
+  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, adaptive_max_pool3d_backward, 2);
+
   VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
-  VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule);
+  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, max_pool2d_with_indices_backward, 2);
 }

 }}
diff --git a/test/test_ops.py b/test/test_ops.py
index ffb08b3cc..9008c031e 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -492,9 +492,6 @@ def test_vmapvjp(self, device, dtype, op):
         xfail('diagonal_scatter'),
         xfail('double', 'channels_last'),
         xfail('linalg.cross'),
-        xfail('nn.functional.adaptive_max_pool1d'),
-        xfail('nn.functional.adaptive_max_pool2d'),
-        xfail('nn.functional.adaptive_max_pool3d'),
         xfail('nn.functional.conv1d'),
         xfail('nn.functional.gaussian_nll_loss'),
         xfail('nn.functional.hardsigmoid'),
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 837bfb894..2b08818d8 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -3201,9 +3201,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('slice_scatter'),
         xfail('unique_consecutive'),
         xfail('unique'),
-        xfail('nn.functional.adaptive_max_pool1d'),
-        xfail('nn.functional.adaptive_max_pool2d'),
-        xfail('nn.functional.adaptive_max_pool3d'),
         xfail('nn.functional.conv1d'),
         xfail('nn.functional.cosine_embedding_loss'),
         # xfail('nn.functional.cross_entropy'),
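Not part of the patch itself, but for context: a minimal sketch of the kind of call these batch rules are meant to enable, assuming the functorch Python API (functorch.vmap, functorch.grad) and torch.nn.functional.adaptive_max_pool2d; the names, shapes, and comments below are illustrative only and are not taken from the diff.

# Illustrative only (not from the patch): exercises the new adaptive_max_pool2d
# batch rules through vmap, assuming the functorch Python API.
import torch
import torch.nn.functional as F
from functorch import grad, vmap

x = torch.randn(8, 3, 16, 16)  # a batch of 8 CHW inputs

# Forward under vmap: registered via EXISTING_BDIM_ALL_BOXED(adaptive_max_pool2d).
out = vmap(lambda t: F.adaptive_max_pool2d(t, (4, 4)))(x)
assert out.shape == (8, 3, 4, 4)

# Backward under vmap: the backward kernel expects contiguous indices, which is
# the constraint ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2)
# appears to address by calling .contiguous() on the indices argument (tensor index 2).
def pooled_sum(t):
    return F.adaptive_max_pool2d(t, (4, 4)).sum()

per_sample_grads = vmap(grad(pooled_sum))(x)
assert per_sample_grads.shape == x.shape

Reading the registrations, the first macro argument is the feature rank of the unbatched input (3 for the 2d CHW case, 4 for the 3d CDHW case), and the last is the position of the tensor that must be made contiguous.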