Skip to content

Commit

Permalink
MaxPool: Move the redispatch of fast path to operator level. (#1149)
Browse files Browse the repository at this point in the history
The dispatch of `at::max_outf` was originally at the kernel level, move
it to the op level.
Another impact is, before the commit, build of torch-xpu-ops happens to
fail due to missing dependency between CodeGen and Kernel lib in CMake.

Co-authored-by: Feng Yuan <[email protected]>
  • Loading branch information
chunhuanMeng and fengyuan14 authored Dec 11, 2024
1 parent 28cdc6b commit 66ac930
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 90 deletions.
57 changes: 57 additions & 0 deletions src/ATen/native/xpu/DilatedMaxPool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <ATen/native/xpu/sycl/DilatedMaxPool2d.h>
#include <comm/RegisterUtils.h>

#include <xpu/ATen/ops/max.h>
#include <xpu/ATen/ops/max_pool2d_with_indices_backward_native.h>
#include <xpu/ATen/ops/max_pool2d_with_indices_native.h>

Expand Down Expand Up @@ -40,6 +41,62 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_xpu)
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = kernel_size.size() == 1
? kH
: safe_downcast<int, int64_t>(kernel_size[1]);
const int padH = safe_downcast<int, int64_t>(padding[0]);
const int padW =
padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
const int64_t nInputPlane = input.size(-3);
const int64_t inputHeight = input.size(-2);
const int64_t inputWidth = input.size(-1);

const int64_t outputHeight = output.size(-2);
const int64_t outputWidth = output.size(-1);
if (outputHeight == 1 && outputWidth == 1 && inputHeight <= kH &&
inputWidth <= kW && padH == 0 && padW == 0) {
auto smf = input.suggest_memory_format();
Tensor input_ = input.contiguous(smf);
bool is_3d = input.ndimension() == 3;
Tensor indices_, output_;
if (is_3d) {
indices_ = indices.contiguous();
output_ = output.contiguous();
} else {
indices_ = indices.contiguous(smf);
output_ = output.contiguous(smf);
}
if (!is_3d) {
input_.resize_({nbatch, nInputPlane, 1, inputHeight * inputWidth}, smf);
output_.resize_(
{nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf);
indices_.resize_(
{nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf);
at::max_outf(input_, 3, true, output_, indices_);
} else {
at::max_outf(input_, 2, true, output_, indices_);
}

if (!is_3d) {
input_.resize_({nbatch, nInputPlane, inputHeight, inputWidth}, smf);
output_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf);
indices_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf);
}

if ((is_3d && !indices.is_contiguous()) ||
(!is_3d && !indices.is_contiguous(smf))) {
indices.copy_(indices_);
}

if ((is_3d && !output.is_contiguous()) ||
(!is_3d && !output.is_contiguous(smf))) {
output.copy_(output_);
}
return;
}
xpu::max_pool2d_with_indices_kernel(
input,
kernel_size,
Expand Down
141 changes: 51 additions & 90 deletions src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <ATen/Dispatch.h>
#include <ATen/native/Pool.h>
#include <ATen/native/utils/ParamUtils.h>
#include <xpu/ATen/ops/max.h>

#include <ATen/native/xpu/sycl/Atomics.h>
#include <ATen/native/xpu/sycl/BatchKernel.h>
Expand Down Expand Up @@ -542,96 +541,58 @@ void max_pool2d_with_indices_kernel(

const int64_t outputHeight = output.size(-2);
const int64_t outputWidth = output.size(-1);
if (outputHeight == 1 && outputWidth == 1 && inputHeight <= kH &&
inputWidth <= kW && padH == 0 && padW == 0) {
bool is_3d = input_.ndimension() == 3;
Tensor indices_, output_;
if (is_3d) {
indices_ = indices.contiguous();
output_ = output.contiguous();
} else {
indices_ = indices.contiguous(smf);
output_ = output.contiguous(smf);
}
if (!is_3d) {
input.resize_({nbatch, nInputPlane, 1, inputHeight * inputWidth}, smf);
output_.resize_(
{nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf);
indices_.resize_(
{nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf);
at::max_outf(input, 3, true, output_, indices_);
} else {
at::max_outf(input, 2, true, output_, indices_);
}

if (!is_3d) {
input.resize_({nbatch, nInputPlane, inputHeight, inputWidth}, smf);
output_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf);
indices_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf);
}

if ((is_3d && !indices.is_contiguous()) ||
(!is_3d && !indices.is_contiguous(smf))) {
indices.copy_(indices_);
}

if ((is_3d && !output.is_contiguous()) ||
(!is_3d && !output.is_contiguous(smf))) {
output.copy_(output_);
}
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf, kBFloat16, input.scalar_type(), "max_pool2d_xpu", [&] {
switch (smf) {
case MemoryFormat::ChannelsLast: {
launch_max_pool2d_kernel<scalar_t, true>(
output.mutable_data_ptr<scalar_t>(),
indices.mutable_data_ptr<int64_t>(),
input.const_data_ptr<scalar_t>(),
nbatch,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH,
kW,
dH,
dW,
padH,
padW,
dilationH,
dilationW);
break;
}
case MemoryFormat::Contiguous: {
launch_max_pool2d_kernel<scalar_t, false>(
output.mutable_data_ptr<scalar_t>(),
indices.mutable_data_ptr<int64_t>(),
input.const_data_ptr<scalar_t>(),
nbatch,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH,
kW,
dH,
dW,
padH,
padW,
dilationH,
dilationW);
break;
}
default:
TORCH_CHECK(
false,
"Unsupported memory format. Supports only ChannelsLast, Contiguous");

AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf, kBFloat16, input.scalar_type(), "max_pool2d_xpu", [&] {
switch (smf) {
case MemoryFormat::ChannelsLast: {
launch_max_pool2d_kernel<scalar_t, true>(
output.mutable_data_ptr<scalar_t>(),
indices.mutable_data_ptr<int64_t>(),
input.const_data_ptr<scalar_t>(),
nbatch,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH,
kW,
dH,
dW,
padH,
padW,
dilationH,
dilationW);
break;
}
});
}
case MemoryFormat::Contiguous: {
launch_max_pool2d_kernel<scalar_t, false>(
output.mutable_data_ptr<scalar_t>(),
indices.mutable_data_ptr<int64_t>(),
input.const_data_ptr<scalar_t>(),
nbatch,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH,
kW,
dH,
dW,
padH,
padW,
dilationH,
dilationW);
break;
}
default:
TORCH_CHECK(
false,
"Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
});
}

void max_pool2d_with_indices_backward_kernel(
Expand Down

0 comments on commit 66ac930

Please sign in to comment.