diff --git a/src/ATen/native/xpu/ReflectionPad.cpp b/src/ATen/native/xpu/ReflectionPad.cpp index 57ea9bbc6..2488ed229 100644 --- a/src/ATen/native/xpu/ReflectionPad.cpp +++ b/src/ATen/native/xpu/ReflectionPad.cpp @@ -1,10 +1,322 @@ #include #include +#include +#include #include #include +#include namespace at { +void reflection_pad1d_meta( + Tensor& output, + const Tensor& input, + IntArrayRef padding) { + int64_t dim_plane = 0; + int64_t dim_w = 1; + int64_t nbatch = 1; + + if (input.ndimension() == 3) { + nbatch = input.size(0); + dim_w++; + dim_plane++; + } + + at::native::padding::check_valid_input<1>(input, padding); + + /* sizes */ + auto pad_l = padding[0]; + auto pad_r = padding[1]; + + int64_t nplane = input.size(dim_plane); + int64_t input_w = input.size(dim_w); + int64_t output_w = input_w + pad_l + pad_r; + + TORCH_CHECK( + pad_l < input_w && pad_r < input_w, + "Argument #4: Padding size " + "should be less than the corresponding input dimension, but got: padding (", + pad_l, + ", ", + pad_r, + ") at dimension ", + dim_w, + " of input ", + input.sizes()); + + TORCH_CHECK( + output_w >= 1, + "input (W: ", + input_w, + ") is too small. Calculated output W: ", + output_w); + + if (output.defined()) { + if (input.ndimension() == 2) { + xpu::resize_out(output, {nplane, output_w}, {}, input.options()); + } else { + xpu::resize_out(output, {nbatch, nplane, output_w}, {}, input.options()); + } + } else { + if (input.ndimension() == 2) { + output = xpu::create_out({nplane, output_w}, {}, input.options()); + } else { + output = xpu::create_out({nbatch, nplane, output_w}, {}, input.options()); + } + } +} + +void reflection_pad1d_backward_meta( + Tensor& grad_input, + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding) { + int64_t dim_w = 1; + if (input.ndimension() == 3) { + dim_w++; + } + + /* sizes */ + auto pad_l = padding[0]; + auto pad_r = padding[1]; + int64_t input_w = input.size(dim_w); + int64_t output_w = input_w + pad_l + pad_r; + + TORCH_CHECK( + pad_l < input_w && pad_r < input_w, + "Argument #4: Padding size " + "should be less than the corresponding input dimension, but got: padding (", + pad_l, + ", ", + pad_r, + ") at dimension ", + dim_w, + " of input ", + input.sizes()); + + TORCH_CHECK( + output_w == grad_output.size(dim_w), + "grad_output width unexpected." + " Expected: ", + output_w, + ", Got: ", + grad_output.size(dim_w)); + + if (grad_input.defined()) { + xpu::resize_out(grad_input, input.sizes(), {}, input.options()); + } else { + grad_input = xpu::create_out(input.sizes(), {}, input.options()); + } +} + +void reflection_pad3d_meta( + Tensor& output, + const Tensor& input, + IntArrayRef padding) { + int64_t pad_left = padding[0]; + int64_t pad_right = padding[1]; + int64_t pad_top = padding[2]; + int64_t pad_bottom = padding[3]; + int64_t pad_front = padding[4]; + int64_t pad_back = padding[5]; + int64_t dim_w = 3; + int64_t dim_h = 2; + int64_t dim_d = 1; + int64_t dim_plane = 0; + + at::native::padding::check_valid_input<3>(input, padding); + + bool batch_mode = (input.dim() == 5); + if (batch_mode) { + dim_w++; + dim_h++; + dim_d++; + dim_plane++; + } + + int64_t nplane = input.size(dim_plane); + int64_t input_d = input.size(dim_d); + int64_t input_h = input.size(dim_h); + int64_t input_w = input.size(dim_w); + int64_t output_d = input_d + pad_front + pad_back; + int64_t output_h = input_h + pad_top + pad_bottom; + int64_t output_w = input_w + pad_left + pad_right; + + TORCH_CHECK( + pad_left < input_w && pad_right < input_w, + "Argument #4: Padding size " + "should be less than the corresponding input dimension, but got: padding (", + pad_left, + ", ", + pad_right, + ") at dimension ", + dim_w, + " of input ", + input.sizes()); + TORCH_CHECK( + pad_top < input_h && pad_bottom < input_h, + "Argument #6: Padding size " + "should be less than the corresponding input dimension, but got: padding (", + pad_top, + ", ", + pad_bottom, + ") at dimension ", + dim_h, + " of input ", + input.sizes()); + TORCH_CHECK( + pad_front < input_d && pad_back < input_d, + "Argument #8: Padding size " + "should be less than the corresponding input dimension, but got: padding (", + pad_front, + ", ", + pad_back, + ") at dimension ", + dim_d, + " of input ", + input.sizes()); + + TORCH_CHECK( + output_w >= 1 || output_h >= 1 || output_d >= 1, + "input (D: ", + input_d, + " H: ", + input_h, + ", W: ", + input_w, + ") is too small." + " Calculated output D: ", + output_d, + " H: ", + output_h, + " W: ", + output_w); + + if (output.defined()) { + if (batch_mode) { + xpu::resize_out( + output, + {input.size(0), nplane, output_d, output_h, output_w}, + {}, + input.options()); + } else { + xpu::resize_out( + output, {nplane, output_d, output_h, output_w}, {}, input.options()); + } + } else { + if (batch_mode) { + output = xpu::create_out( + {input.size(0), nplane, output_d, output_h, output_w}, + {}, + input.options()); + } else { + output = xpu::create_out( + {nplane, output_d, output_h, output_w}, {}, input.options()); + } + } +} + +void reflection_pad3d_backward_meta( + Tensor& grad_input, + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding) { + TORCH_CHECK(padding.size() == 6, "padding size is expected to be 6"); + TORCH_CHECK(input.dim() > 3); + TORCH_CHECK(grad_output.dim() == input.dim()); + + int64_t pad_left = padding[0]; + int64_t pad_right = padding[1]; + int64_t pad_top = padding[2]; + int64_t pad_bottom = padding[3]; + int64_t pad_front = padding[4]; + int64_t pad_back = padding[5]; + int64_t dim_w = 3; + int64_t dim_h = 2; + int64_t dim_d = 1; + + if (input.dim() == 5) { + // batch mode + dim_w++; + dim_h++; + dim_d++; + } + + int64_t input_d = input.size(dim_d); + int64_t input_h = input.size(dim_h); + int64_t input_w = input.size(dim_w); + int64_t output_d = input_d + pad_front + pad_back; + int64_t output_h = input_h + pad_top + pad_bottom; + int64_t output_w = input_w + pad_left + pad_right; + + TORCH_CHECK( + output_w == grad_output.size(dim_w), + "grad_output width unexpected." + " Expected: ", + output_w, + ", Got: ", + grad_output.size(dim_w)); + TORCH_CHECK( + output_h == grad_output.size(dim_h), + "grad_output height unexpected." + " Expected: ", + output_h, + ", Got: ", + grad_output.size(dim_h)); + TORCH_CHECK( + output_d == grad_output.size(dim_d), + "grad_output depth unexpected." + " Expected: ", + output_d, + ", Got: ", + grad_output.size(dim_d)); + + if (grad_input.defined()) { + xpu::resize_out(grad_input, input.sizes(), {}, input.options()); + } else { + grad_input = xpu::create_out(input.sizes(), {}, input.options()); + } +} + +Tensor XPUNativeFunctions::reflection_pad1d( + const Tensor& input, + IntArrayRef padding) { + Tensor output; + reflection_pad1d_meta(output, input, padding); + native::xpu::reflection_pad1d_kernel(output, input, padding); + return output; +} + +Tensor& XPUNativeFunctions::reflection_pad1d_out( + const Tensor& input, + IntArrayRef padding, + Tensor& output) { + reflection_pad1d_meta(output, input, padding); + native::xpu::reflection_pad1d_kernel(output, input, padding); + return output; +} + +Tensor XPUNativeFunctions::reflection_pad1d_backward( + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding) { + Tensor grad_input; + reflection_pad1d_backward_meta(grad_input, grad_output, input, padding); + native::xpu::reflection_pad1d_backward_kernel( + grad_input, grad_output, input, padding); + return grad_input; +} + +Tensor& XPUNativeFunctions::reflection_pad1d_backward_out( + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding, + Tensor& grad_input) { + native::xpu::reflection_pad1d_backward_kernel( + grad_input, grad_output, input, padding); + return grad_input; +} + Tensor& XPUNativeFunctions::reflection_pad2d_out( const Tensor& input, IntArrayRef padding, @@ -49,4 +361,44 @@ Tensor XPUNativeFunctions::reflection_pad2d_backward( return grad_input; } +Tensor XPUNativeFunctions::reflection_pad3d( + const Tensor& input, + IntArrayRef padding) { + Tensor output; + reflection_pad3d_meta(output, input, padding); + native::xpu::reflection_pad3d_kernel(output, input, padding); + return output; +} + +Tensor& XPUNativeFunctions::reflection_pad3d_out( + const Tensor& input, + IntArrayRef padding, + Tensor& output) { + reflection_pad3d_meta(output, input, padding); + native::xpu::reflection_pad3d_kernel(output, input, padding); + return output; +} + +Tensor XPUNativeFunctions::reflection_pad3d_backward( + const Tensor& grad_output, + const Tensor& input, + at::IntArrayRef padding) { + Tensor grad_input; + reflection_pad3d_backward_meta(grad_input, grad_output, input, padding); + native::xpu::reflection_pad3d_backward_kernel( + grad_input, grad_output, input, padding); + return grad_input; +} + +Tensor& XPUNativeFunctions::reflection_pad3d_backward_out( + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding, + Tensor& grad_input) { + reflection_pad3d_backward_meta(grad_input, grad_output, input, padding); + native::xpu::reflection_pad3d_backward_kernel( + grad_input, grad_output, input, padding); + return grad_input; +} + } // namespace at \ No newline at end of file diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 7bfdd6abd..471081ccd 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -299,8 +299,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "prod", "prod.int_out", "put_", - "reflection_pad1d_backward.grad_input", - "reflection_pad1d.out", "renorm.out", "repeat_interleave.Tensor", "replication_pad1d_backward.grad_input", @@ -380,8 +378,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "vdot", "xlogy.OutTensor", "_upsample_bicubic2d_aa.out", - "reflection_pad3d.out", - "reflection_pad3d_backward.grad_input", "replication_pad3d.out", "replication_pad3d_backward", }; diff --git a/src/ATen/native/xpu/sycl/ReflectionPadKernels.cpp b/src/ATen/native/xpu/sycl/ReflectionPadKernels.cpp index b96bb2d7a..1627a5def 100644 --- a/src/ATen/native/xpu/sycl/ReflectionPadKernels.cpp +++ b/src/ATen/native/xpu/sycl/ReflectionPadKernels.cpp @@ -5,6 +5,7 @@ #pragma GCC diagnostic ignored "-Wreturn-type" #include +#include #include #include #include @@ -15,6 +16,30 @@ namespace at::native::xpu { +inline std::pair get_index_mapping1d( + int64_t input_w, + int64_t output_w, + int64_t output_x, + int64_t pad_l, + const sycl::nd_item<3> item) { + auto input_offset = + (item.get_group(1) + item.get_group(0) * item.get_group_range(1)) * + input_w; + auto output_offset = + (item.get_group(1) + item.get_group(0) * item.get_group_range(1)) * + output_w; + + auto i_start_x = std::max(int64_t(0), -pad_l); + auto o_start_x = std::max(int64_t(0), pad_l); + + int64_t input_x = std::abs(output_x - pad_l) - + std::abs(output_x - (input_w + pad_l - 1)) - output_x + 2 * pad_l + + input_w - 1 - o_start_x + i_start_x; + + return std::make_pair( + input_offset + input_x, output_offset + output_x); +} + inline std::pair get_index_mapping2d( int64_t input_dim_x, int64_t input_dim_y, @@ -26,10 +51,10 @@ inline std::pair get_index_mapping2d( const sycl::nd_item<3> item) { // 3D grid of 1D blocks auto input_offset = - (item.get_group(1) + item.get_group(2) * item.get_group_range(1)) * + (item.get_group(1) + item.get_group(0) * item.get_group_range(1)) * input_dim_x * input_dim_y; auto output_offset = - (item.get_group(1) + item.get_group(2) * item.get_group_range(1)) * + (item.get_group(1) + item.get_group(0) * item.get_group_range(1)) * output_dim_x * output_dim_y; auto output_x = output_xy % output_dim_x; @@ -53,10 +78,122 @@ inline std::pair get_index_mapping2d( output_offset + output_y * output_dim_x + output_x); } +template +struct ReflectionPad1dKernelFunctor { + void operator()(sycl::nd_item<3> item) const { + auto output_x = item.get_global_id(2); + + if (output_x < output_w_) { + // input index and output index mapping + auto index_pair = + get_index_mapping1d(input_w_, output_w_, output_x, pad_l_, item); + output_data_[index_pair.second] = input_data_[index_pair.first]; + } + } + ReflectionPad1dKernelFunctor( + scalar_t* input_data, + scalar_t* output_data, + int64_t input_w, + int64_t pad_l, + int64_t output_w) + : input_data_(input_data), + output_data_(output_data), + input_w_(input_w), + pad_l_(pad_l), + output_w_(output_w) {} + + private: + scalar_t* input_data_; + scalar_t* output_data_; + int64_t input_w_; + int64_t pad_l_; + int64_t output_w_; +}; + +template +void reflection_pad1d_template( + scalar_t* input, + scalar_t* output, + int64_t input_w, + int64_t pad_l, + int64_t pad_r, + int64_t nbatch, + int64_t nplane, + int64_t output_w) { + auto queue = getCurrentSYCLQueue(); + int64_t work_group_size = syclMaxWorkItemsPerEU(); + int64_t work_group_num = at::ceil_div(output_w, work_group_size); + + ReflectionPad1dKernelFunctor kfn( + input, output, input_w, pad_l, output_w); + sycl_kernel_submit( + sycl::range<3>(nbatch, nplane, work_group_size * work_group_num), + sycl::range<3>(1, 1, work_group_size), + queue, + kfn); +} + +template +struct ReflectionPad1dBackwardKernelFunctor { + void operator()(sycl::nd_item<3> item) const { + auto output_x = item.get_global_id(2); + + if (output_x < output_w_) { + // grad input index and grad output index mapping + auto index_pair = + get_index_mapping1d(input_w_, output_w_, output_x, pad_l_, item); + atomicAdd( + (sycl_global_ptr)&grad_input_data_[index_pair.first], + grad_output_data_[index_pair.second]); + } + } + ReflectionPad1dBackwardKernelFunctor( + scalar_t* grad_input_data, + scalar_t* grad_output_data, + int64_t input_w, + int64_t pad_l, + int64_t output_w) + : grad_input_data_(grad_input_data), + grad_output_data_(grad_output_data), + input_w_(input_w), + pad_l_(pad_l), + output_w_(output_w) {} + + private: + scalar_t* grad_input_data_; + scalar_t* grad_output_data_; + int64_t input_w_; + int64_t pad_l_; + int64_t output_w_; +}; + +template +void reflection_pad1d_backward_template( + scalar_t* grad_input, + scalar_t* grad_output, + int64_t input_w, + int64_t pad_l, + int64_t pad_r, + int64_t nbatch, + int64_t nplane, + int64_t output_w) { + auto queue = getCurrentSYCLQueue(); + int64_t work_group_size = syclMaxWorkItemsPerEU(); + int64_t work_group_num = at::ceil_div(output_w, work_group_size); + + ReflectionPad1dBackwardKernelFunctor kfn( + grad_input, grad_output, input_w, pad_l, output_w); + sycl_kernel_submit( + sycl::range<3>(nbatch, nplane, work_group_size * work_group_num), + sycl::range<3>(1, 1, work_group_size), + queue, + kfn); +} + template struct ReflectionPad2dKernellFunctor { void operator()(sycl::nd_item<3> item) const { - auto output_xy = item.get_global_id(0); + auto output_xy = item.get_global_id(2); if (output_xy < output_dim_x_ * output_dim_y_) { // input index and output index mapping @@ -131,8 +268,8 @@ void reflection_pad2d_template( output_dim_x, output_dim_y); sycl_kernel_submit( - sycl::range<3>(work_group_size * work_group_num, nplane, nbatch), - sycl::range<3>(work_group_size, 1, 1), + sycl::range<3>(nbatch, nplane, work_group_size * work_group_num), + sycl::range<3>(1, 1, work_group_size), queue, kfn); } @@ -140,7 +277,7 @@ void reflection_pad2d_template( template struct ReflectionPad2dBackwardKernelFunctor { void operator()(sycl::nd_item<3> item) const { - auto output_xy = item.get_global_id(0); + auto output_xy = item.get_global_id(2); if (output_xy < output_dim_x_ * output_dim_y_) { // grad input index and grad output index mapping @@ -216,12 +353,266 @@ void reflection_pad2d_backward_template( output_dim_x, output_dim_y); sycl_kernel_submit( - sycl::range<3>(work_group_size * work_group_num, nplane, nbatch), - sycl::range<3>(work_group_size, 1, 1), + sycl::range<3>(nbatch, nplane, work_group_size * work_group_num), + sycl::range<3>(1, 1, work_group_size), queue, kfn); } +template +struct ParallelReflectionPad3dKernelFunctor { + void operator()(sycl::nd_item<3> item) const { + auto output_id = item.get_global_id(2); + if (output_id >= output_plane_size_) { + return; + } + + int64_t output_x = output_id % output_.size(4); + int64_t output_y = (output_id / output_.size(4)) % output_.size(3); + int64_t output_z = output_id / (output_.size(3) * output_.size(4)); + + int64_t i_start_x = std::max(int64_t(0), -pad_left_); + int64_t o_start_x = std::max(int64_t(0), pad_left_); + int64_t i_start_y = std::max(int64_t(0), -pad_top_); + int64_t o_start_y = std::max(int64_t(0), pad_top_); + int64_t i_start_z = std::max(int64_t(0), -pad_front_); + int64_t o_start_z = std::max(int64_t(0), pad_front_); + + int64_t input_x = std::abs(output_x - pad_left_) - + std::abs(output_x - (input_.size(4) + pad_left_ - 1)) - output_x + + 2 * pad_left_ + input_.size(4) - 1 - o_start_x + i_start_x; + int64_t input_y = std::abs(output_y - pad_top_) - + std::abs(output_y - (input_.size(3) + pad_top_ - 1)) - output_y + + 2 * pad_top_ + input_.size(3) - 1 - o_start_y + i_start_y; + + int64_t input_z = std::abs(output_z - pad_front_) - + std::abs(output_z - (input_.size(2) + pad_front_ - 1)) - output_z + + 2 * pad_front_ + input_.size(2) - 1 - o_start_z + i_start_z; + + f_(input_, + output_, + item.get_group(1), + item.get_group(0), + output_z, + output_y, + output_x, + input_z, + input_y, + input_x); + } + ParallelReflectionPad3dKernelFunctor( + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, + int64_t pad_left, + int64_t pad_top, + int64_t pad_front, + const F f, + int64_t output_plane_size) + : input_(input), + output_(output), + pad_left_(pad_left), + pad_top_(pad_top), + pad_front_(pad_front), + f_(f), + output_plane_size_(output_plane_size) {} + + private: + PackedTensorAccessor64 input_; + PackedTensorAccessor64 output_; + int64_t pad_left_; + int64_t pad_top_; + int64_t pad_front_; + const F f_; + int64_t output_plane_size_; +}; + +template +inline void parallel_reflection_pad3d( + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, + int64_t pad_left, + int64_t pad_top, + int64_t pad_front, + const F& f) { + auto queue = getCurrentSYCLQueue(); + int64_t output_plane_size = output.size(2) * output.size(3) * output.size(4); + int64_t work_group_size = syclMaxWorkItemsPerEU(); + int64_t work_group_num = at::ceil_div(output_plane_size, work_group_size); + int64_t nplane = input.size(1); + int64_t nbatch = input.size(0); + + ParallelReflectionPad3dKernelFunctor kfn( + input, output, pad_left, pad_top, pad_front, f, output_plane_size); + sycl_kernel_submit( + sycl::range<3>(nbatch, nplane, work_group_size * work_group_num), + sycl::range<3>(1, 1, work_group_size), + queue, + kfn); +} + +template +struct reflection_pad3d_kernel_functor { + void operator()( + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, + int64_t plane, + int64_t batch, + int64_t output_z, + int64_t output_y, + int64_t output_x, + int64_t input_z, + int64_t input_y, + int64_t input_x) const { + auto value_to_copy = input[batch][plane][input_z][input_y][input_x]; + output[batch][plane][output_z][output_y][output_x] = value_to_copy; + } +}; + +template +void reflection_pad3d_template( + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, + int64_t pad_left, + int64_t pad_top, + int64_t pad_front) { + reflection_pad3d_kernel_functor f; + parallel_reflection_pad3d(input, output, pad_left, pad_top, pad_front, f); +} + +template +struct reflection_pad3d_backward_kernel_functor { + void operator()( + PackedTensorAccessor64 grad_input, + PackedTensorAccessor64 grad_output, + int64_t plane, + int64_t batch, + int64_t output_z, + int64_t output_y, + int64_t output_x, + int64_t input_z, + int64_t input_y, + int64_t input_x) const { + auto value_to_add = grad_output[batch][plane][output_z][output_y][output_x]; + auto target = (sycl_global_ptr)&grad_input[batch][plane][input_z] + [input_y][input_x]; + atomicAdd(target, value_to_add); + } +}; + +template +void reflection_pad3d_backward_template( + PackedTensorAccessor64 grad_input, + PackedTensorAccessor64 grad_output, + int64_t pad_left, + int64_t pad_top, + int64_t pad_front) { + reflection_pad3d_backward_kernel_functor f; + parallel_reflection_pad3d( + grad_input, grad_output, pad_left, pad_top, pad_front, f); +} + +void reflection_pad1d_kernel( + Tensor& output, + const Tensor& input_, + IntArrayRef padding) { + TORCH_CHECK( + canUse32BitIndexMath(input_), + "input tensor must fit into 32-bit index math"); + + if (output.numel() == 0) { + return; + } + + int64_t dim_plane = 0; + int64_t dim_w = 1; + int64_t nbatch = 1; + + if (input_.ndimension() == 3) { + nbatch = input_.size(0); + dim_plane++; + dim_w++; + } + + int64_t pad_l = padding[0]; + int64_t pad_r = padding[1]; + + int64_t nplane = input_.size(dim_plane); + int64_t input_w = input_.size(dim_w); + int64_t output_w = input_w + pad_l + pad_r; + + Tensor input = input_.contiguous(); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_xpu", [&] { + reflection_pad1d_template( + input.data_ptr(), + output.data_ptr(), + input_w, + pad_l, + pad_r, + nbatch, + nplane, + output_w); + }); +} + +void reflection_pad1d_backward_kernel( + Tensor& grad_input, + const Tensor& grad_output_, + const Tensor& input, + IntArrayRef padding) { + globalContext().alertNotDeterministic("reflection_pad1d_backward_out_xpu"); + grad_input.zero_(); + + if (grad_input.numel() == 0) { + return; + } + + TORCH_CHECK( + canUse32BitIndexMath(input), + "input tensor must fit into 32-bit index math"); + + TORCH_CHECK( + canUse32BitIndexMath(grad_output_), + "input tensor must fit into 32-bit index math"); + + int64_t dim_plane = 0; + int64_t dim_w = 1; + int64_t nbatch = 1; + + if (input.ndimension() == 3) { + nbatch = input.size(0); + dim_plane++; + dim_w++; + } + + int64_t pad_l = padding[0]; + int64_t pad_r = padding[1]; + + int64_t nplane = input.size(dim_plane); + int64_t input_w = input.size(dim_w); + int64_t output_w = input_w + pad_l + pad_r; + + Tensor grad_output = grad_output_.contiguous(); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, + kBFloat16, + grad_input.scalar_type(), + "reflection_pad1d_backward_xpu", + [&] { + reflection_pad1d_backward_template( + grad_input.data_ptr(), + grad_output.data_ptr(), + input_w, + pad_l, + pad_r, + nbatch, + nplane, + output_w); + }); +} + void reflection_pad2d_kernel( Tensor& output, const Tensor& input_, @@ -396,6 +787,90 @@ void reflection_pad2d_backward_kernel( }); } +void reflection_pad3d_kernel( + Tensor& output, + const Tensor& input_, + IntArrayRef padding) { + TORCH_CHECK( + canUse32BitIndexMath(input_), + "input tensor must fit into 32-bit index math"); + + if (output.numel() == 0) { + return; + } + + int64_t pad_left = padding[0]; + int64_t pad_top = padding[2]; + int64_t pad_front = padding[4]; + + auto input = input_.contiguous(); + bool batch_mode = (input.dim() == 5); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + kHalf, kBFloat16, input.scalar_type(), "reflection_pad3d_xpu", [&] { + auto input_inner = input; + auto output_inner = output; + if (!batch_mode) { + input_inner = input.unsqueeze(0); + output_inner = output.unsqueeze(0); + } + + auto input_packed = input_inner.packed_accessor64(); + auto output_packed = output_inner.packed_accessor64(); + + reflection_pad3d_template( + input_packed, output_packed, pad_left, pad_top, pad_front); + }); +} + +void reflection_pad3d_backward_kernel( + Tensor& grad_input, + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding) { + globalContext().alertNotDeterministic("reflection_pad3d_backward_out_xpu"); + TORCH_CHECK( + canUse32BitIndexMath(input), + "input tensor must fit into 32-bit index math"); + TORCH_CHECK( + canUse32BitIndexMath(grad_output), + "input tensor must fit into 32-bit index math"); + + if (grad_input.numel() == 0) { + return; + } + grad_input.zero_(); + + int64_t pad_left = padding[0]; + int64_t pad_top = padding[2]; + int64_t pad_front = padding[4]; + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "reflection_pad3d_backward_xpu", + [&] { + auto grad_input_ = grad_input; + auto grad_output_ = grad_output; + if (input.dim() == 4) { + // non-batch mode + grad_input_ = grad_input.unsqueeze(0); + grad_output_ = grad_output.unsqueeze(0); + } + + auto grad_input_packed = grad_input_.packed_accessor64(); + auto grad_output_packed = grad_output_.packed_accessor64(); + + reflection_pad3d_backward_template( + grad_input_packed, + grad_output_packed, + pad_left, + pad_top, + pad_front); + }); +} + } // namespace at::native::xpu #pragma GCC diagnostic pop diff --git a/src/ATen/native/xpu/sycl/ReflectionPadKernels.h b/src/ATen/native/xpu/sycl/ReflectionPadKernels.h index 8f103b73b..a21f6c8ee 100644 --- a/src/ATen/native/xpu/sycl/ReflectionPadKernels.h +++ b/src/ATen/native/xpu/sycl/ReflectionPadKernels.h @@ -4,6 +4,17 @@ namespace at::native::xpu { +void reflection_pad1d_kernel( + Tensor& output, + const Tensor& input_, + IntArrayRef padding); + +void reflection_pad1d_backward_kernel( + Tensor& grad_input, + const Tensor& grad_output_, + const Tensor& input, + IntArrayRef padding); + void reflection_pad2d_kernel( Tensor& output, const Tensor& input_, @@ -15,4 +26,15 @@ void reflection_pad2d_backward_kernel( const Tensor& input, IntArrayRef padding); +void reflection_pad3d_kernel( + Tensor& output, + const Tensor& input_, + IntArrayRef padding); + +void reflection_pad3d_backward_kernel( + Tensor& grad_input, + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding); + } // namespace at::native::xpu \ No newline at end of file diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 2ecc6790b..f103a7795 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -464,10 +464,18 @@ supported: - searchsorted.Scalar - searchsorted.Scalar_out - trace + - reflection_pad1d + - reflection_pad1d.out + - reflection_pad1d_backward + - reflection_pad1d_backward.grad_input - reflection_pad2d - reflection_pad2d.out - reflection_pad2d_backward - reflection_pad2d_backward.grad_input + - reflection_pad3d + - reflection_pad3d.out + - reflection_pad3d_backward + - reflection_pad3d_backward.grad_input - native_group_norm - native_group_norm_backward - elu