Disable nondeterministic implementation for max_pool2d_backward #619

Merged
merged 4 commits on Jul 22, 2024
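This change replaces the runtime choice between the deterministic and nondeterministic max_pool2d backward kernels with a compile-time guard: unless XPU_ALLOW_UNDETERMINISTIC is defined when this file is compiled (for example via a -DXPU_ALLOW_UNDETERMINISTIC compiler define; the build-system hook is not shown in this diff), the deterministic kernel is always launched. Previously the deterministic functor was chosen only when deterministic algorithms were requested at runtime or when the dtype was Half or BFloat16. For reference, a minimal sketch (not part of this PR) of the runtime flag the removed branch consulted via globalContext().deterministicAlgorithms():

// Minimal sketch, not part of this PR: how callers opt into deterministic
// algorithms at runtime. The branch removed below checked
// at::globalContext().deterministicAlgorithms() before picking a kernel;
// after this change the flag no longer influences this particular kernel.
#include <ATen/Context.h>

void request_deterministic_algorithms() {
  // Equivalent to torch.use_deterministic_algorithms(True) on the Python
  // side; warn_only=false makes ops without a deterministic variant error out.
  at::globalContext().setDeterministicAlgorithms(/*b=*/true, /*warn_only=*/false);
}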
124 changes: 67 additions & 57 deletions src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
@@ -416,68 +416,78 @@ void launch_max_pool2d_backward_kernel(
int dilation_h,
int dilation_w) {
auto& queue = at::xpu::getCurrentSYCLQueue();
int64_t gradOutputSize =
numBatch * numPlane * gradOutputSizeH * gradOutputSizeW;
int64_t gradInputSize = numBatch * numPlane * gradInputSizeH * gradInputSizeW;
auto out_cf_c_stride = gradOutputSizeH * gradOutputSizeW;
auto in_cf_c_stride = gradInputSizeH * gradInputSizeW;
auto out_n_stride = numPlane * out_cf_c_stride;
auto in_n_stride = numPlane * in_cf_c_stride;
if (globalContext().deterministicAlgorithms() ||
std::is_same_v<scalar_t, at::Half> ||
std::is_same_v<scalar_t, at::BFloat16>) {
using KernelClass =
MaxPool2dBackwardDeterministicKernelFunctor<scalar_t, is_channels_last>;
BatchKernelConfig cfg = {
1, gradInputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive};
cfg.template build<KernelClass>();
auto kfn = KernelClass(
gradInput,
gradOutput,
indices,
numPlane,
gradInputSizeH,
gradInputSizeW,
gradOutputSizeH,
gradOutputSizeW,
gradInputSize,
out_cf_c_stride,
in_cf_c_stride,
out_n_stride,
in_n_stride,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
cfg);
sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn);
} else {
using KernelClass =
MaxPool2dBackwardKernelFunctor<scalar_t, is_channels_last>;
BatchKernelConfig cfg = {
1, gradOutputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive};
cfg.template build<KernelClass>();
auto kfn = KernelClass(
gradInput,
gradOutput,
indices,
numPlane,
gradInputSizeH,
gradInputSizeW,
gradOutputSizeH,
gradOutputSizeW,
gradOutputSize,
out_cf_c_stride,
in_cf_c_stride,
out_n_stride,
in_n_stride,
cfg);
sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn);
}

#ifndef XPU_ALLOW_UNDETERMINISTIC
// [Deterministic Note]
//
// By default, we disable the nondeterministic path in this kernel so that
// it cannot have any side effect on accuracy. In the future, we will
// re-enable the nondeterministic path to improve performance.
//
// The background is that we found this kernel behaved differently from CUDA
// when running AlexNet. To avoid future problems, we decided to always use
// the deterministic path.
using KernelClass =
MaxPool2dBackwardDeterministicKernelFunctor<scalar_t, is_channels_last>;
BatchKernelConfig cfg = {
1, gradInputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive};
cfg.template build<KernelClass>();
auto kfn = KernelClass(
gradInput,
gradOutput,
indices,
numPlane,
gradInputSizeH,
gradInputSizeW,
gradOutputSizeH,
gradOutputSizeW,
gradInputSize,
out_cf_c_stride,
in_cf_c_stride,
out_n_stride,
in_n_stride,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
cfg);
sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn);
#else
int64_t gradOutputSize =
numBatch * numPlane * gradOutputSizeH * gradOutputSizeW;
using KernelClass =
MaxPool2dBackwardKernelFunctor<scalar_t, is_channels_last>;
BatchKernelConfig cfg = {
1, gradOutputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive};
cfg.template build<KernelClass>();
auto kfn = KernelClass(
gradInput,
gradOutput,
indices,
numPlane,
gradInputSizeH,
gradInputSizeW,
gradOutputSizeH,
gradOutputSizeW,
gradOutputSize,
out_cf_c_stride,
in_cf_c_stride,
out_n_stride,
in_n_stride,
cfg);
sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn);
#endif
}

void max_pool2d_with_indices_kernel(
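
Why the two functors above differ in determinism (an illustrative reading, inferred from the problem sizes and parameters each one is configured with, not from the functor bodies, which are not shown in this diff): MaxPool2dBackwardKernelFunctor is launched over gradOutputSize and scatters each output gradient into gradInput at the position stored in indices, which on a parallel device needs atomic adds whose accumulation order can vary between runs; MaxPool2dBackwardDeterministicKernelFunctor is launched over gradInputSize, so each input element can gather its own contributions in a fixed order, which is also why it takes the kernel, stride, padding, and dilation arguments. A minimal serial C++ sketch of the two strategies, for illustration only:

// Illustrative sketch only (plain serial C++, not the actual SYCL functors):
// the two accumulation strategies behind the kernels above. On a parallel
// device the scatter variant needs atomic adds, and the order in which they
// land is what makes float results run-to-run nondeterministic.
#include <cstdint>
#include <vector>

// Scatter over gradOutput (nondeterministic on device): every output gradient
// is added to the input slot recorded in `indices`; several outputs may hit
// the same slot.
void maxpool2d_backward_scatter(const std::vector<float>& gradOutput,
                                const std::vector<int64_t>& indices,
                                std::vector<float>& gradInput) {
  for (size_t o = 0; o < gradOutput.size(); ++o) {
    gradInput[indices[o]] += gradOutput[o];  // an atomic add on device
  }
}

// Gather over gradInput (deterministic): each input slot accumulates its own
// contributions in a fixed order. The real kernel would use kernel/stride/
// pad/dilation to visit only the output windows that can cover position `i`;
// the brute-force scan here just keeps the sketch short.
void maxpool2d_backward_gather(const std::vector<float>& gradOutput,
                               const std::vector<int64_t>& indices,
                               std::vector<float>& gradInput) {
  for (size_t i = 0; i < gradInput.size(); ++i) {
    float acc = 0.0f;
    for (size_t o = 0; o < gradOutput.size(); ++o) {
      if (indices[o] == static_cast<int64_t>(i)) {
        acc += gradOutput[o];
      }
    }
    gradInput[i] = acc;
  }
}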