From d87b635efb8866358ef7cd93ccfac0f4a1c08fc0 Mon Sep 17 00:00:00 2001 From: "Huaiyu, Zheng" Date: Wed, 24 Jul 2024 02:08:31 +0000 Subject: [PATCH 1/5] add aten::upsample_linear1d aten::upsample_linear1d_backward --- src/ATen/native/xpu/UpSampleLinear1d.cpp | 137 +++++++++++ src/ATen/native/xpu/XPUFallback.template | 2 - .../xpu/sycl/UpSampleLinear1dKernels.cpp | 224 ++++++++++++++++++ .../native/xpu/sycl/UpSampleLinear1dKernels.h | 19 ++ test/xpu/extended/run_test_with_skip.py | 7 +- test/xpu/run_test_with_skip.py | 1 - yaml/xpu_functions.yaml | 4 + 7 files changed, 389 insertions(+), 5 deletions(-) create mode 100644 src/ATen/native/xpu/UpSampleLinear1d.cpp create mode 100644 src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp create mode 100644 src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.h diff --git a/src/ATen/native/xpu/UpSampleLinear1d.cpp b/src/ATen/native/xpu/UpSampleLinear1d.cpp new file mode 100644 index 000000000..b6c49eead --- /dev/null +++ b/src/ATen/native/xpu/UpSampleLinear1d.cpp @@ -0,0 +1,137 @@ +#include +#include +#include +#include +#include "ATen/core/ATen_fwd.h" + +namespace at { + +static C10_UNUSED std::array upsample_1d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 3, + "It is expected input_size equals to 3, but got size ", + input_size.size()); + + int64_t output_width = output_size[0]; + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_width = input_size[2]; + + TORCH_CHECK( + input_width > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (W: ", + input_width, + ") and output (W: ", + output_width, + ")"); + + return {nbatch, channels, output_width}; +} +void upsample_linear1d_meta( + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales, + Tensor& output) { + auto full_output_size = upsample_1d_common_check(input.sizes(), output_size); + + // Allow for empty batch size but not other dimensions + TORCH_CHECK( + (input.size(1) != 0 && input.size(2) != 0) && input.dim() == 3, + "Non-empty 3D data tensor expected but got a tensor with sizes ", + input.sizes()); + + if (output.defined()) { + at::xpu::resize_out(output, full_output_size, {}, input.options()); + } else { + output = at::xpu::create_out(full_output_size, {}, input.options()); + } +} +void upsample_linear1d_backward_meta( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales, + Tensor& grad_input) { + auto full_output_size = upsample_1d_common_check(input_size, output_size); + + TORCH_CHECK( + input_size.size() == 3, + "It is expected input_size equals to 3, but got size ", + input_size.size()); + + check_dim_size(grad_output, 3, 0, full_output_size[0]); + check_dim_size(grad_output, 3, 1, full_output_size[1]); + check_dim_size(grad_output, 3, 2, full_output_size[2]); + + if (grad_input.defined()) { + at::xpu::resize_out(grad_input, input_size, {}, grad_output.options()); + } else { + grad_input = at::xpu::create_out(input_size, {}, grad_output.options()); + } +} + +Tensor XPUNativeFunctions::upsample_linear1d( + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales) { + Tensor output; + return upsample_linear1d_out( + input, output_size, align_corners, scales, output); +} + +Tensor& 
XPUNativeFunctions::upsample_linear1d_out( + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales, + Tensor& output) { + upsample_linear1d_meta(input, output_size, align_corners, scales, output); + + TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2}; + checkAllSameGPU(__func__, {input_arg, output_arg}); + + native::xpu::upsample_linear1d_kernel( + input, output_size, align_corners, scales, output); + return output; +} +Tensor XPUNativeFunctions::upsample_linear1d_backward( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales) { + Tensor grad_input; + return upsample_linear1d_backward_out( + grad_output, output_size, input_size, align_corners, scales, grad_input); +} + +Tensor& XPUNativeFunctions::upsample_linear1d_backward_out( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales, + Tensor& grad_input) { + upsample_linear1d_backward_meta( + grad_output, output_size, input_size, align_corners, scales, grad_input); + + globalContext().alertNotDeterministic("upsample_linear1d_backward_out_xpu"); + TensorArg grad_output_arg{grad_output, "grad_output", 1}, + grad_input_arg{grad_input, "grad_input", 2}; + checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg}); + native::xpu::upsample_linear1d_backward_kernel( + grad_output, output_size, input_size, align_corners, scales, grad_input); + return grad_input; +} + +} // namespace at diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 68d399460..165e5e7ab 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -337,8 +337,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "unique_consecutive", "upsample_bicubic2d_backward.grad_input", "_upsample_bilinear2d_aa.out", - "upsample_linear1d_backward.grad_input", - "upsample_linear1d.out", "upsample_nearest3d.out", "upsample_nearest3d_backward.grad_input", "_upsample_nearest_exact3d.out", diff --git a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp new file mode 100644 index 000000000..2b96c186c --- /dev/null +++ b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp @@ -0,0 +1,224 @@ +#pragma clang diagnostic push +#pragma GCC diagnostic push +// Avoid SYCL compiler return-type error +#pragma clang diagnostic ignored "-Wreturn-type" +#pragma GCC diagnostic ignored "-Wreturn-type" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "ATen/Context.h" +#include "ATen/core/TensorBase.h" + +namespace at::native::xpu { +template +struct UpsampleLinear1dKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + int index = + item.get_local_id(0) + item.get_group(0) * item.get_local_range(0); + + const int batchsize = idata_.size(0); + const int channels = idata_.size(1); + const int width1 = idata_.size(2); + const int width2 = odata_.size(2); + PackedTensorAccessor64 odata_res = odata_; + + if (index < n_) { + const int w2 = index % width2; + // special case: just copy + if (width1 == width2) { + const int w1 = w2; + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; ++c) { + const scalar_t val = idata_[n][c][w1]; + odata_res[n][c][w2] = val; + } + } + return; + } + + const accscalar_t w1r = area_pixel_compute_source_index( + rwidth_, w2, align_corners_, /*cubic=*/false); + const int w1 = 
w1r; + const int w1p = (w1 < width1 - 1) ? 1 : 0; + const accscalar_t w1lambda = w1r - w1; + const accscalar_t w0lambda = static_cast(1) - w1lambda; + + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; ++c) { + const accscalar_t val = + w0lambda * idata_[n][c][w1] + w1lambda * idata_[n][c][w1 + w1p]; + odata_res[n][c][w2] = static_cast(val); + } + } + } + } + UpsampleLinear1dKernelFunctor( + const int n, + const accscalar_t rwidth, + const bool align_corners, + const PackedTensorAccessor64 idata, + PackedTensorAccessor64 odata) + : n_(n), + rwidth_(rwidth), + align_corners_(align_corners), + idata_(idata), + odata_(odata) {} + + private: + const int n_; + const accscalar_t rwidth_; + const bool align_corners_; + const PackedTensorAccessor64 idata_; + PackedTensorAccessor64 odata_; +}; + +void upsample_linear1d_kernel( + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales, + Tensor& output) { + int output_width = output_size[0]; + output.zero_(); + int input_width = input.size(2); + + AT_ASSERT(input_width > 0 && output_width > 0); + + const int num_kernels = output_width; + const int num_threads = 512; + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "upsample_linear1d_kernel", + [&] { + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); + + using accscalar_t = at::acc_type_device; + const accscalar_t rwidth = area_pixel_compute_scale( + input_width, output_width, align_corners, scales); + UpsampleLinear1dKernelFunctor kfn( + num_kernels, rwidth, align_corners, idata, odata); + auto global_range = ceil_div(num_kernels, num_threads); + auto local_range = num_threads; + sycl_kernel_submit( + global_range * local_range, + local_range, + getCurrentSYCLQueue(), + kfn); + }); +} + +template +struct UpsampleLinear1dBackwardKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + int index = + item.get_local_id(0) + item.get_group(0) * item.get_local_range(0); + + const int batchsize = idata_.size(0); + const int channels = idata_.size(1); + const int width1 = idata_.size(2); + const int width2 = odata_.size(2); + PackedTensorAccessor64 idata_res = idata_; + if (index < n_) { + const int w2 = index % width2; + if (width1 == width2) { + const int w1 = w2; + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; ++c) { + const scalar_t val = odata_[n][c][w1]; + idata_res[n][c][w2] = val; + } + } + return; + } + const accscalar_t w1r = area_pixel_compute_source_index( + rwidth_, w2, align_corners_, /*cubic=*/false); + const int w1 = w1r; + const int w1p = (w1 < width1 - 1) ? 
1 : 0; + const accscalar_t w1lambda = w1r - w1; + const accscalar_t w0lambda = static_cast(1) - w1lambda; + + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; ++c) { + const scalar_t d2val = odata_[n][c][w2]; + atomicAdd( + (sycl_global_ptr)(&idata_res[n][c][w1]), + static_cast(w0lambda * d2val)); + atomicAdd( + (sycl_global_ptr)(&idata_res[n][c][w1 + w1p]), + static_cast(w1lambda * d2val)); + } + } + } + } + UpsampleLinear1dBackwardKernelFunctor( + const int n, + const accscalar_t rwidth, + const bool align_corners, + PackedTensorAccessor64 idata, + const PackedTensorAccessor64 odata) + : n_(n), + rwidth_(rwidth), + align_corners_(align_corners), + idata_(idata), + odata_(odata) {} + + private: + const int n_; + const accscalar_t rwidth_; + const bool align_corners_; + PackedTensorAccessor64 idata_; + const PackedTensorAccessor64 odata_; +}; + +void upsample_linear1d_backward_kernel( + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales, + Tensor& grad_input) { + int output_width = output_size[0]; + int input_width = input_size[2]; + Tensor grad_output = grad_output_.contiguous(); + grad_input.zero_(); + + const int num_kernels = output_width; + const int num_threads = 512; + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + grad_output.scalar_type(), + "upsample_linear1d_backward", + [&] { + using accscalar_t = at::acc_type_device; + + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); + const accscalar_t rwidth = area_pixel_compute_scale( + input_width, output_width, align_corners, scales); + UpsampleLinear1dBackwardKernelFunctor kfn( + num_kernels, rwidth, align_corners, idata, odata); + auto global_range = ceil_div(num_kernels, num_threads); + auto local_range = num_threads; + sycl_kernel_submit( + global_range * local_range, + local_range, + getCurrentSYCLQueue(), + kfn); + }); +} +} // namespace at::native::xpu + +#pragma GCC diagnostic pop +#pragma clang diagnostic pop \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.h b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.h new file mode 100644 index 000000000..21db6faae --- /dev/null +++ b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.h @@ -0,0 +1,19 @@ +#include + +namespace at::native::xpu { +void upsample_linear1d_kernel( + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales, + Tensor& output); + +void upsample_linear1d_backward_kernel( + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales, + Tensor& grad_input); + +} // namespace at::native::xpu \ No newline at end of file diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py index ac9a80db9..8d46109f3 100644 --- a/test/xpu/extended/run_test_with_skip.py +++ b/test/xpu/extended/run_test_with_skip.py @@ -137,11 +137,14 @@ "test_compare_cpu_nn_functional_huber_loss_xpu_bfloat16", # Not implemented operators, aten::upsample_linear1d, aten::upsample_bilinear2d, - # aten::upsample_trilinear3d - "nn_functional_interpolate_linear", + # aten::upsample_trilinear3d, "nn_functional_interpolate_bilinear", "nn_functional_interpolate_trilinear", + #The results of XPU and CUDA are consistent, but the results of CPU and CUDA are inconsistent + "test_compare_cpu_nn_functional_interpolate_linear_xpu_bfloat16", + 
"test_compare_cpu_nn_functional_interpolate_linear_xpu_float16", + # bicubic interpolate includes large calculation steps, accuracy reduces in half-precision # Not in CUDA test scope too "test_compare_cpu_nn_functional_interpolate_bicubic_xpu_bfloat16", diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py index eed4eec71..f884f615d 100644 --- a/test/xpu/run_test_with_skip.py +++ b/test/xpu/run_test_with_skip.py @@ -2820,7 +2820,6 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_nondeterministic_alert_histc_xpu", "test_nondeterministic_alert_interpolate_bicubic_xpu", "test_nondeterministic_alert_interpolate_bilinear_xpu", - "test_nondeterministic_alert_interpolate_linear_xpu", "test_nondeterministic_alert_interpolate_trilinear_xpu", "test_nondeterministic_alert_kthvalue_xpu_float64", "test_nondeterministic_alert_median_xpu_float64", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 71ecfecdd..0dbde97be 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -597,3 +597,7 @@ supported: - ceil_ - ceil.out - nan_to_num.out + - upsample_linear1d_backward.grad_input + - upsample_linear1d.out + - upsample_linear1d + - upsample_linear1d_backward From e00b44ac9f28015353eb15d0cf9a637c084f686f Mon Sep 17 00:00:00 2001 From: "Huaiyu, Zheng" Date: Mon, 29 Jul 2024 07:37:46 +0000 Subject: [PATCH 2/5] fix comments --- src/ATen/native/xpu/ReplicationPadding.cpp | 73 ++++++--- src/ATen/native/xpu/UpSample.h | 29 ++++ src/ATen/native/xpu/UpSampleLinear1d.cpp | 36 +---- .../xpu/sycl/ReplicationPaddingKernels.cpp | 141 ++++++++++-------- .../xpu/sycl/UpSampleLinear1dKernels.cpp | 30 ++-- 5 files changed, 179 insertions(+), 130 deletions(-) diff --git a/src/ATen/native/xpu/ReplicationPadding.cpp b/src/ATen/native/xpu/ReplicationPadding.cpp index b4f6d3272..062d5bc1c 100644 --- a/src/ATen/native/xpu/ReplicationPadding.cpp +++ b/src/ATen/native/xpu/ReplicationPadding.cpp @@ -34,9 +34,13 @@ void replication_pad1d_meta( int64_t iwidth = input.size(dimw); int64_t owidth = iwidth + pad_l + pad_r; - TORCH_CHECK(owidth >= 1, - "input (W: ", iwidth, ") is too small." - " Calculated output W: ", owidth); + TORCH_CHECK( + owidth >= 1, + "input (W: ", + iwidth, + ") is too small." + " Calculated output W: ", + owidth); if (output.defined()) { if (input.ndimension() == 2) { @@ -69,11 +73,14 @@ void replication_pad1d_backward_meta( /* sizes */ int64_t iwidth = input.size(dimw); - int64_t owidth = iwidth + pad_l + pad_r; + int64_t owidth = iwidth + pad_l + pad_r; - TORCH_CHECK(owidth == grad_output.size(dimw), - "grad_output width unexpected. Expected: ", owidth, - " Got: ", grad_output.size(dimw)); + TORCH_CHECK( + owidth == grad_output.size(dimw), + "grad_output width unexpected. Expected: ", + owidth, + " Got: ", + grad_output.size(dimw)); if (grad_input.defined()) { xpu::resize_out(grad_input, input.sizes(), {}, input.options()); @@ -110,25 +117,30 @@ void replication_pad2d_meta( int64_t iheight = input.size(dimh); int64_t iwidth = input.size(dimw); int64_t oheight = iheight + pad_t + pad_b; - int64_t owidth = iwidth + pad_l + pad_r; + int64_t owidth = iwidth + pad_l + pad_r; - TORCH_CHECK(owidth >= 1 || oheight >= 1, - "input (H: ", iheight, ", W: ", iwidth, " ) is too small." - " Calculated output H: ", oheight, " W: ", owidth); + TORCH_CHECK( + owidth >= 1 || oheight >= 1, + "input (H: ", + iheight, + ", W: ", + iwidth, + " ) is too small." 
+ " Calculated output H: ", + oheight, + " W: ", + owidth); if (output.defined()) { if (input.dim() == 3) { - xpu::resize_out( - output, - {nslices, oheight, owidth}, {}, input.options()); + xpu::resize_out(output, {nslices, oheight, owidth}, {}, input.options()); } else { xpu::resize_out( output, {nbatch, nslices, oheight, owidth}, {}, input.options()); } } else { if (input.dim() == 3) { - output = xpu::create_out( - {nslices, oheight, owidth}, {}, input.options()); + output = xpu::create_out({nslices, oheight, owidth}, {}, input.options()); } else { output = xpu::create_out( {nbatch, nslices, oheight, owidth}, {}, input.options()); @@ -170,21 +182,34 @@ void replication_pad3d_meta( int64_t iwidth = input.size(dimw); int64_t odepth = idepth + pfront + pback; int64_t oheight = iheight + ptop + pbottom; - int64_t owidth = iwidth + pleft + pright; - - TORCH_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1, - "input (D: ", idepth, " H: ", iheight, ", W: ", iwidth, + int64_t owidth = iwidth + pleft + pright; + + TORCH_CHECK( + owidth >= 1 || oheight >= 1 || odepth >= 1, + "input (D: ", + idepth, + " H: ", + iheight, + ", W: ", + iwidth, ") is too small." - " Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth); + " Calculated output D: ", + odepth, + " H: ", + oheight, + " W: ", + owidth); if (output.defined()) { if (input.dim() == 4) { xpu::resize_out( - output, - {nslices, odepth, oheight, owidth}, {}, input.options()); + output, {nslices, odepth, oheight, owidth}, {}, input.options()); } else { xpu::resize_out( - output, {nbatch, nslices, odepth, oheight, owidth}, {}, input.options()); + output, + {nbatch, nslices, odepth, oheight, owidth}, + {}, + input.options()); } } else { if (input.dim() == 4) { diff --git a/src/ATen/native/xpu/UpSample.h b/src/ATen/native/xpu/UpSample.h index 5ca47c4d4..44e9f5829 100644 --- a/src/ATen/native/xpu/UpSample.h +++ b/src/ATen/native/xpu/UpSample.h @@ -228,4 +228,33 @@ static scalar_t upsample_get_value_bounded( return data[batch][channel][access_y][access_x]; } +static C10_UNUSED std::array upsample_1d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 3, + "It is expected input_size equals to 3, but got size ", + input_size.size()); + + int64_t output_width = output_size[0]; + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_width = input_size[2]; + + TORCH_CHECK( + input_width > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (W: ", + input_width, + ") and output (W: ", + output_width, + ")"); + + return {nbatch, channels, output_width}; +} + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/UpSampleLinear1d.cpp b/src/ATen/native/xpu/UpSampleLinear1d.cpp index b6c49eead..fcce31524 100644 --- a/src/ATen/native/xpu/UpSampleLinear1d.cpp +++ b/src/ATen/native/xpu/UpSampleLinear1d.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -6,41 +7,14 @@ namespace at { -static C10_UNUSED std::array upsample_1d_common_check( - IntArrayRef input_size, - IntArrayRef output_size) { - TORCH_CHECK( - output_size.size() == 1, - "It is expected output_size equals to 1, but got size ", - output_size.size()); - - TORCH_CHECK( - input_size.size() == 3, - "It is expected input_size equals to 3, but got size ", - input_size.size()); - - int64_t output_width = output_size[0]; - int64_t 
nbatch = input_size[0]; - int64_t channels = input_size[1]; - int64_t input_width = input_size[2]; - - TORCH_CHECK( - input_width > 0 && output_width > 0, - "Input and output sizes should be greater than 0, but got input (W: ", - input_width, - ") and output (W: ", - output_width, - ")"); - - return {nbatch, channels, output_width}; -} void upsample_linear1d_meta( const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional scales, Tensor& output) { - auto full_output_size = upsample_1d_common_check(input.sizes(), output_size); + auto full_output_size = + at::native::xpu::upsample_1d_common_check(input.sizes(), output_size); // Allow for empty batch size but not other dimensions TORCH_CHECK( @@ -61,7 +35,8 @@ void upsample_linear1d_backward_meta( bool align_corners, std::optional scales, Tensor& grad_input) { - auto full_output_size = upsample_1d_common_check(input_size, output_size); + auto full_output_size = + at::native::xpu::upsample_1d_common_check(input_size, output_size); TORCH_CHECK( input_size.size() == 3, @@ -125,7 +100,6 @@ Tensor& XPUNativeFunctions::upsample_linear1d_backward_out( upsample_linear1d_backward_meta( grad_output, output_size, input_size, align_corners, scales, grad_input); - globalContext().alertNotDeterministic("upsample_linear1d_backward_out_xpu"); TensorArg grad_output_arg{grad_output, "grad_output", 1}, grad_input_arg{grad_input, "grad_input", 2}; checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg}); diff --git a/src/ATen/native/xpu/sycl/ReplicationPaddingKernels.cpp b/src/ATen/native/xpu/sycl/ReplicationPaddingKernels.cpp index bb42aa327..4ba4eafb9 100644 --- a/src/ATen/native/xpu/sycl/ReplicationPaddingKernels.cpp +++ b/src/ATen/native/xpu/sycl/ReplicationPaddingKernels.cpp @@ -7,8 +7,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -34,7 +34,12 @@ struct ParallelReplicationPad1dKernelFunctor { imin(imax(pad_left_, output_x), input_.size(2) + pad_left_ - 1) - o_start_x + i_start_x; - f_(input_, output_, item.get_group(1), item.get_group(0), output_x, input_x); + f_(input_, + output_, + item.get_group(1), + item.get_group(0), + output_x, + input_x); } } ParallelReplicationPad1dKernelFunctor( @@ -143,31 +148,31 @@ struct ParallelReplicationPad2dKernelFunctor { const int plane = item.get_global_id(1); if (output_id < output_.size(2) * output_.size(3)) { - const int output_x = output_id / output_.size(3); // height - const int output_y = output_id % output_.size(3); // width + const int output_x = output_id / output_.size(3); // height + const int output_y = output_id % output_.size(3); // width const int iStartX = imax(0, -padT_); const int iStartY = imax(0, -padL_); const int oStartX = imax(0, padT_); const int oStartY = imax(0, padL_); - const int input_x = imin(imax(padT_, output_x), input_.size(2) + padT_ - 1) - oStartX + iStartX; - const int input_y = imin(imax(padL_, output_y), input_.size(3) + padL_ - 1) - oStartY + iStartY; + const int input_x = + imin(imax(padT_, output_x), input_.size(2) + padT_ - 1) - oStartX + + iStartX; + const int input_y = + imin(imax(padL_, output_y), input_.size(3) + padL_ - 1) - oStartY + + iStartY; f_(input_, output_, batch, plane, input_x, input_y, output_x, output_y); } -} + } ParallelReplicationPad2dKernelFunctor( PackedTensorAccessor64 input, PackedTensorAccessor64 output, int64_t padT, int64_t padL, const F f) - : input_(input), - output_(output), - padT_(padT), - padL_(padL), - f_(f) {} + : input_(input), output_(output), padT_(padT), padL_(padL), 
f_(f) {} private: PackedTensorAccessor64 input_; @@ -276,22 +281,22 @@ struct ParallelReplicationPad3dKernelFunctor { imin(imax(pad_left_, output_x), input_.size(4) + pad_left_ - 1) - o_start_x + i_start_x; int64_t input_y = - imin(imax(pad_top_, output_y), input_.size(3) + pad_top_ - 1) - o_start_y + - i_start_y; + imin(imax(pad_top_, output_y), input_.size(3) + pad_top_ - 1) - + o_start_y + i_start_y; int64_t input_z = imin(imax(pad_front_, output_z), input_.size(2) + pad_front_ - 1) - o_start_z + i_start_z; f_(input_, - output_, - item.get_group(1), - item.get_group(0), - output_z, - output_y, - output_x, - input_z, - input_y, - input_x); + output_, + item.get_group(1), + item.get_group(0), + output_z, + output_y, + output_x, + input_z, + input_y, + input_x); } } ParallelReplicationPad3dKernelFunctor( @@ -339,8 +344,9 @@ void parallel_replication_pad3d( int64_t nbatch = output.size(0); sycl_kernel_submit( - sycl::range<3>(nbatch, nplane, work_group_size * work_group_num), - sycl::range<3>(1, 1, work_group_size), queue, + sycl::range<3>(nbatch, nplane, work_group_size * work_group_num), + sycl::range<3>(1, 1, work_group_size), + queue, kfn); } @@ -387,8 +393,8 @@ struct ReplicationPad3dBackwardFunctor { int64_t intput_y, int64_t intput_x) const { auto value_to_add = grad_output[batch][plane][output_z][output_y][output_x]; - auto target = - (sycl_global_ptr)&grad_input[batch][plane][intput_z][intput_y][intput_x]; + auto target = (sycl_global_ptr)&grad_input[batch][plane][intput_z] + [intput_y][intput_x]; atomicAdd(target, value_to_add); } }; @@ -409,7 +415,8 @@ void replication_pad1d_kernel( Tensor& output, const Tensor& input, IntArrayRef padding) { - TORCH_CHECK(input.numel() < std::numeric_limits::max(), + TORCH_CHECK( + input.numel() < std::numeric_limits::max(), "replication_pad1d only supports input tensors with less than 2^63 - 1 elements"); if (input.numel() == 0) { @@ -446,9 +453,11 @@ void replication_pad1d_backward_kernel( // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("replication_pad1d_backward_xpu"); - TORCH_CHECK(input.numel() < std::numeric_limits::max(), + TORCH_CHECK( + input.numel() < std::numeric_limits::max(), "replication_pad1d only supports input tensors with less than 2^63 - 1 elements"); - TORCH_CHECK(grad_output.numel() < std::numeric_limits::max(), + TORCH_CHECK( + grad_output.numel() < std::numeric_limits::max(), "replication_pad1d only supports output tensors with less than 2^63 - 1 elements"); if (grad_input.numel() == 0) { @@ -484,26 +493,27 @@ void replication_pad2d_kernel( Tensor& output, const Tensor& input, IntArrayRef padding) { - TORCH_CHECK(canUse32BitIndexMath(input), + TORCH_CHECK( + canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); if (input.numel() == 0) { return; } const auto padL = padding[0]; const auto padT = padding[2]; - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, - input.scalar_type(), "replication_pad2d_xpu", [&] { - Tensor input_ = input; - Tensor output_ = output; - if (input.dim() == 3) { - input_ = input.unsqueeze(0); - output_ = output.unsqueeze(0); - } - auto devInput = input_.packed_accessor64(); - auto devOutput = output_.packed_accessor64(); - replication_pad2d_forward_template(devInput, devOutput, padT, padL); - } - ); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + kHalf, kBFloat16, input.scalar_type(), "replication_pad2d_xpu", [&] { + Tensor input_ = input; + Tensor output_ = output; + if (input.dim() == 3) { + input_ = input.unsqueeze(0); + output_ = 
output.unsqueeze(0); + } + auto devInput = input_.packed_accessor64(); + auto devOutput = output_.packed_accessor64(); + replication_pad2d_forward_template( + devInput, devOutput, padT, padL); + }); } void replication_pad2d_backward_kernel( @@ -514,9 +524,11 @@ void replication_pad2d_backward_kernel( // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("replication_pad2d_backward_xpu"); - TORCH_CHECK(canUse32BitIndexMath(input), + TORCH_CHECK( + canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); - TORCH_CHECK(canUse32BitIndexMath(grad_output), + TORCH_CHECK( + canUse32BitIndexMath(grad_output), "output gradient tensor must fit into 32-bit index math"); TORCH_CHECK(padding.size() == 4, "padding Size is expected to be 4"); @@ -535,13 +547,19 @@ void replication_pad2d_backward_kernel( const auto iheight = input.size(dimh); const auto iwidth = input.size(dimw); const auto oheight = iheight + padT + padB; - const auto owidth = iwidth + padL + padR; + const auto owidth = iwidth + padL + padR; - TORCH_CHECK(owidth == grad_output.size(dimw), - "grad_output width unexpected. Expected: ", owidth, ", Got: ", + TORCH_CHECK( + owidth == grad_output.size(dimw), + "grad_output width unexpected. Expected: ", + owidth, + ", Got: ", grad_output.size(dimw)); - TORCH_CHECK(oheight == grad_output.size(dimh), - "grad_output height unexpected. Expected: ", oheight, ", Got: ", + TORCH_CHECK( + oheight == grad_output.size(dimh), + "grad_output height unexpected. Expected: ", + oheight, + ", Got: ", grad_output.size(dimh)); grad_input.resize_as_(input); @@ -550,9 +568,12 @@ void replication_pad2d_backward_kernel( } grad_input.zero_(); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, - input.scalar_type(), "replication_pad2d_backward_xpu", [&] { - + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "replication_pad2d_backward_xpu", + [&] { auto grad_input_ = grad_input; auto grad_output_ = grad_output; if (numInputDims == 3) { @@ -562,9 +583,9 @@ void replication_pad2d_backward_kernel( auto grad_input_packed = grad_input_.packed_accessor64(); auto grad_output_packed = grad_output_.packed_accessor64(); - replication_pad2d_backward_template(grad_input_packed, grad_output_packed, padT, padL); - } - ); + replication_pad2d_backward_template( + grad_input_packed, grad_output_packed, padT, padL); + }); } void replication_pad3d_kernel( @@ -606,7 +627,8 @@ static inline void shapeAndGradOutputCheck3d( int64_t pad_bottom, int64_t pad_front, int64_t pad_back) { - TORCH_CHECK(canUse32BitIndexMath(input), + TORCH_CHECK( + canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); int64_t num_input_dims = input.dim(); @@ -652,7 +674,8 @@ static inline void shapeAndGradOutputCheck3d( " W: ", owidth); - TORCH_CHECK(canUse32BitIndexMath(grad_output), + TORCH_CHECK( + canUse32BitIndexMath(grad_output), "output gradient tensor must fit into 32-bit index math"); TORCH_CHECK( diff --git a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp index 2b96c186c..7745c4a00 100644 --- a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -90,9 +89,6 @@ void upsample_linear1d_kernel( AT_ASSERT(input_width > 0 && output_width > 0); - const int num_kernels = 
output_width; - const int num_threads = 512; - AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, @@ -105,10 +101,12 @@ void upsample_linear1d_kernel( using accscalar_t = at::acc_type_device; const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners, scales); - UpsampleLinear1dKernelFunctor kfn( - num_kernels, rwidth, align_corners, idata, odata); - auto global_range = ceil_div(num_kernels, num_threads); - auto local_range = num_threads; + const int num_kernels = output_width; + using KernelClass = + UpsampleLinear1dKernelFunctor; + int64_t local_range = syclMaxWorkGroupSize(); + KernelClass kfn(num_kernels, rwidth, align_corners, idata, odata); + auto global_range = (num_kernels + local_range - 1) / local_range; sycl_kernel_submit( global_range * local_range, local_range, @@ -187,14 +185,13 @@ void upsample_linear1d_backward_kernel( bool align_corners, std::optional scales, Tensor& grad_input) { + globalContext().alertNotDeterministic("upsample_linear1d_backward_xpu"); + int output_width = output_size[0]; int input_width = input_size[2]; Tensor grad_output = grad_output_.contiguous(); grad_input.zero_(); - const int num_kernels = output_width; - const int num_threads = 512; - AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, @@ -202,15 +199,16 @@ void upsample_linear1d_backward_kernel( "upsample_linear1d_backward", [&] { using accscalar_t = at::acc_type_device; - + const int num_kernels = output_width; auto idata = grad_input.packed_accessor64(); auto odata = grad_output.packed_accessor64(); const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners, scales); - UpsampleLinear1dBackwardKernelFunctor kfn( - num_kernels, rwidth, align_corners, idata, odata); - auto global_range = ceil_div(num_kernels, num_threads); - auto local_range = num_threads; + using KernelClass = + UpsampleLinear1dBackwardKernelFunctor; + int64_t local_range = syclMaxWorkGroupSize(); + KernelClass kfn(num_kernels, rwidth, align_corners, idata, odata); + auto global_range = (num_kernels + local_range - 1) / local_range; sycl_kernel_submit( global_range * local_range, local_range, From b8417e1963e2db6aba5931745f8ec41c49fae596 Mon Sep 17 00:00:00 2001 From: "Huaiyu, Zheng" Date: Mon, 29 Jul 2024 08:25:41 +0000 Subject: [PATCH 3/5] fix test_nondeterministic_alert_interpolate_linear_xpu --- src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp index 7745c4a00..4e25ad94b 100644 --- a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp @@ -185,7 +185,7 @@ void upsample_linear1d_backward_kernel( bool align_corners, std::optional scales, Tensor& grad_input) { - globalContext().alertNotDeterministic("upsample_linear1d_backward_xpu"); + globalContext().alertNotDeterministic("upsample_linear1d_backward_out_xpu"); int output_width = output_size[0]; int input_width = input_size[2]; From 7a8c3b72d5e4a8a96415c39b9e1dbd183ae2f0de Mon Sep 17 00:00:00 2001 From: "Huaiyu, Zheng" Date: Tue, 30 Jul 2024 01:09:15 +0000 Subject: [PATCH 4/5] fix comments --- src/ATen/native/xpu/Loss.cpp | 1 - .../xpu/sycl/UpSampleLinear1dKernels.cpp | 18 ++++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/ATen/native/xpu/Loss.cpp b/src/ATen/native/xpu/Loss.cpp index 
050ff07b9..f2ca7d9c0 100644 --- a/src/ATen/native/xpu/Loss.cpp +++ b/src/ATen/native/xpu/Loss.cpp @@ -80,7 +80,6 @@ Tensor& XPUNativeFunctions::mse_loss_backward_out( return grad_input; } - Tensor& XPUNativeFunctions::smooth_l1_loss_out( const Tensor& input, const Tensor& target, diff --git a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp index 4e25ad94b..8b82a6a71 100644 --- a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp @@ -93,7 +93,7 @@ void upsample_linear1d_kernel( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), - "upsample_linear1d_kernel", + "upsample_linear1d_xpu", [&] { auto idata = input.packed_accessor64(); auto odata = output.packed_accessor64(); @@ -102,10 +102,9 @@ void upsample_linear1d_kernel( const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners, scales); const int num_kernels = output_width; - using KernelClass = - UpsampleLinear1dKernelFunctor; - int64_t local_range = syclMaxWorkGroupSize(); - KernelClass kfn(num_kernels, rwidth, align_corners, idata, odata); + UpsampleLinear1dKernelFunctor kfn( + num_kernels, rwidth, align_corners, idata, odata); + const auto local_range = syclMaxWorkGroupSize(kfn); auto global_range = (num_kernels + local_range - 1) / local_range; sycl_kernel_submit( global_range * local_range, @@ -196,7 +195,7 @@ void upsample_linear1d_backward_kernel( at::ScalarType::Half, at::ScalarType::BFloat16, grad_output.scalar_type(), - "upsample_linear1d_backward", + "upsample_linear1d_backward_xpu", [&] { using accscalar_t = at::acc_type_device; const int num_kernels = output_width; @@ -204,10 +203,9 @@ void upsample_linear1d_backward_kernel( auto odata = grad_output.packed_accessor64(); const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners, scales); - using KernelClass = - UpsampleLinear1dBackwardKernelFunctor; - int64_t local_range = syclMaxWorkGroupSize(); - KernelClass kfn(num_kernels, rwidth, align_corners, idata, odata); + UpsampleLinear1dBackwardKernelFunctor kfn( + num_kernels, rwidth, align_corners, idata, odata); + const auto local_range = syclMaxWorkGroupSize(kfn); auto global_range = (num_kernels + local_range - 1) / local_range; sycl_kernel_submit( global_range * local_range, From 3d4da941942d2c0e4654d73a04c34a99edc10c2a Mon Sep 17 00:00:00 2001 From: "Huaiyu, Zheng" Date: Tue, 30 Jul 2024 02:30:57 +0000 Subject: [PATCH 5/5] fix comments --- .../native/xpu/sycl/UpSampleLinear1dKernels.cpp | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp index 8b82a6a71..71fc04ab1 100644 --- a/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleLinear1dKernels.cpp @@ -105,12 +105,10 @@ void upsample_linear1d_kernel( UpsampleLinear1dKernelFunctor kfn( num_kernels, rwidth, align_corners, idata, odata); const auto local_range = syclMaxWorkGroupSize(kfn); - auto global_range = (num_kernels + local_range - 1) / local_range; + auto global_range = + (num_kernels + local_range - 1) / local_range * local_range; sycl_kernel_submit( - global_range * local_range, - local_range, - getCurrentSYCLQueue(), - kfn); + global_range, local_range, getCurrentSYCLQueue(), kfn); }); } @@ -206,12 +204,10 @@ void upsample_linear1d_backward_kernel( UpsampleLinear1dBackwardKernelFunctor 
kfn( num_kernels, rwidth, align_corners, idata, odata); const auto local_range = syclMaxWorkGroupSize(kfn); - auto global_range = (num_kernels + local_range - 1) / local_range; + auto global_range = + (num_kernels + local_range - 1) / local_range * local_range; sycl_kernel_submit( - global_range * local_range, - local_range, - getCurrentSYCLQueue(), - kfn); + global_range, local_range, getCurrentSYCLQueue(), kfn); }); } } // namespace at::native::xpu