From 66ac930b0ea2704606b98a1d522e886694458b20 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:02:51 +0800 Subject: [PATCH] MaxPool: Move the redispatch of fast path to operator level. (#1149) The dispatch of `at::max_outf` was originally at the kernel level, move it to the op level. Another impact is, before the commit, build of torch-xpu-ops happens to fail due to missing dependency between CodeGen and Kernel lib in CMake. Co-authored-by: Feng Yuan --- src/ATen/native/xpu/DilatedMaxPool2d.cpp | 57 +++++++ src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp | 141 +++++++----------- 2 files changed, 108 insertions(+), 90 deletions(-) diff --git a/src/ATen/native/xpu/DilatedMaxPool2d.cpp b/src/ATen/native/xpu/DilatedMaxPool2d.cpp index 600d29e85..a08227b47 100644 --- a/src/ATen/native/xpu/DilatedMaxPool2d.cpp +++ b/src/ATen/native/xpu/DilatedMaxPool2d.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -40,6 +41,62 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_xpu) bool ceil_mode, const Tensor& output, const Tensor& indices) { + const int kH = safe_downcast(kernel_size[0]); + const int kW = kernel_size.size() == 1 + ? kH + : safe_downcast(kernel_size[1]); + const int padH = safe_downcast(padding[0]); + const int padW = + padding.size() == 1 ? padH : safe_downcast(padding[1]); + + const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1; + const int64_t nInputPlane = input.size(-3); + const int64_t inputHeight = input.size(-2); + const int64_t inputWidth = input.size(-1); + + const int64_t outputHeight = output.size(-2); + const int64_t outputWidth = output.size(-1); + if (outputHeight == 1 && outputWidth == 1 && inputHeight <= kH && + inputWidth <= kW && padH == 0 && padW == 0) { + auto smf = input.suggest_memory_format(); + Tensor input_ = input.contiguous(smf); + bool is_3d = input.ndimension() == 3; + Tensor indices_, output_; + if (is_3d) { + indices_ = indices.contiguous(); + output_ = output.contiguous(); + } else { + indices_ = indices.contiguous(smf); + output_ = output.contiguous(smf); + } + if (!is_3d) { + input_.resize_({nbatch, nInputPlane, 1, inputHeight * inputWidth}, smf); + output_.resize_( + {nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf); + indices_.resize_( + {nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf); + at::max_outf(input_, 3, true, output_, indices_); + } else { + at::max_outf(input_, 2, true, output_, indices_); + } + + if (!is_3d) { + input_.resize_({nbatch, nInputPlane, inputHeight, inputWidth}, smf); + output_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf); + indices_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf); + } + + if ((is_3d && !indices.is_contiguous()) || + (!is_3d && !indices.is_contiguous(smf))) { + indices.copy_(indices_); + } + + if ((is_3d && !output.is_contiguous()) || + (!is_3d && !output.is_contiguous(smf))) { + output.copy_(output_); + } + return; + } xpu::max_pool2d_with_indices_kernel( input, kernel_size, diff --git a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp index cba138a5f..d94db11c9 100644 --- a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp +++ b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include @@ -542,96 +541,58 @@ void max_pool2d_with_indices_kernel( const int64_t outputHeight = output.size(-2); const int64_t outputWidth = output.size(-1); - if (outputHeight == 1 && outputWidth == 1 && inputHeight <= kH && - inputWidth <= kW && padH == 0 && padW == 0) { - bool is_3d = input_.ndimension() == 3; - Tensor indices_, output_; - if (is_3d) { - indices_ = indices.contiguous(); - output_ = output.contiguous(); - } else { - indices_ = indices.contiguous(smf); - output_ = output.contiguous(smf); - } - if (!is_3d) { - input.resize_({nbatch, nInputPlane, 1, inputHeight * inputWidth}, smf); - output_.resize_( - {nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf); - indices_.resize_( - {nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf); - at::max_outf(input, 3, true, output_, indices_); - } else { - at::max_outf(input, 2, true, output_, indices_); - } - - if (!is_3d) { - input.resize_({nbatch, nInputPlane, inputHeight, inputWidth}, smf); - output_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf); - indices_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf); - } - - if ((is_3d && !indices.is_contiguous()) || - (!is_3d && !indices.is_contiguous(smf))) { - indices.copy_(indices_); - } - - if ((is_3d && !output.is_contiguous()) || - (!is_3d && !output.is_contiguous(smf))) { - output.copy_(output_); - } - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, kBFloat16, input.scalar_type(), "max_pool2d_xpu", [&] { - switch (smf) { - case MemoryFormat::ChannelsLast: { - launch_max_pool2d_kernel( - output.mutable_data_ptr(), - indices.mutable_data_ptr(), - input.const_data_ptr(), - nbatch, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH, - kW, - dH, - dW, - padH, - padW, - dilationH, - dilationW); - break; - } - case MemoryFormat::Contiguous: { - launch_max_pool2d_kernel( - output.mutable_data_ptr(), - indices.mutable_data_ptr(), - input.const_data_ptr(), - nbatch, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH, - kW, - dH, - dW, - padH, - padW, - dilationH, - dilationW); - break; - } - default: - TORCH_CHECK( - false, - "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "max_pool2d_xpu", [&] { + switch (smf) { + case MemoryFormat::ChannelsLast: { + launch_max_pool2d_kernel( + output.mutable_data_ptr(), + indices.mutable_data_ptr(), + input.const_data_ptr(), + nbatch, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + break; } - }); - } + case MemoryFormat::Contiguous: { + launch_max_pool2d_kernel( + output.mutable_data_ptr(), + indices.mutable_data_ptr(), + input.const_data_ptr(), + nbatch, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } + }); } void max_pool2d_with_indices_backward_kernel(