diff --git a/src/ATen/native/xpu/GridSampler.cpp b/src/ATen/native/xpu/GridSampler.cpp
index fa9a5d17e..5d69d8c85 100644
--- a/src/ATen/native/xpu/GridSampler.cpp
+++ b/src/ATen/native/xpu/GridSampler.cpp
@@ -47,5 +47,46 @@ std::tuple<Tensor, Tensor> grid_sampler_2d_backward_xpu(
       output_mask);
   return std::make_tuple(grad_input, grad_grid);
 }
+
+Tensor grid_sampler_3d_xpu(
+    const Tensor& input,
+    const Tensor& grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners) {
+  return xpu::grid_sampler_3d_kernel(
+      input, grid, interpolation_mode, padding_mode, align_corners);
+}
+
+std::tuple<Tensor, Tensor> grid_sampler_3d_backward_xpu(
+    const Tensor& grad_output,
+    const Tensor& input,
+    const Tensor& grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners,
+    std::array<bool, 2> output_mask) {
+  auto input_requires_grad = output_mask[0];
+  Tensor grad_input = ([&]() {
+    if (input_requires_grad) {
+      return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    } else {
+      return Tensor();
+    }
+  })();
+  auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  xpu::grid_sampler_3d_backward_kernel(
+      grad_input,
+      grad_grid,
+      grad_output,
+      input,
+      grid,
+      interpolation_mode,
+      padding_mode,
+      align_corners,
+      output_mask);
+  return std::make_tuple(grad_input, grad_grid);
+}
+
 } // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/sycl/GridSampler.cpp b/src/ATen/native/xpu/sycl/GridSampler.cpp
index 2bfd0505a..c2b51d581 100644
--- a/src/ATen/native/xpu/sycl/GridSampler.cpp
+++ b/src/ATen/native/xpu/sycl/GridSampler.cpp
@@ -23,8 +23,8 @@ using namespace at::xpu::detail;
 template <typename scalar_t, typename index_t>
 struct GridSampler2dKernelFunctor {
   using opmath_t = at::opmath_type<scalar_t>;
-  void operator()(sycl::nd_item<1> item_id) const {
-    auto index = item_id.get_global_linear_id();
+  void operator()(sycl::nd_item<1> item) const {
+    auto index = item.get_global_linear_id();
     if (index >= nthreads_)
       return;
     const index_t w = index % out_W_;
@@ -357,8 +357,8 @@ Tensor grid_sampler_2d_kernel(
 
 template <typename scalar_t, typename index_t>
 struct GridSampler2dBackwardKernelFunctor {
-  void operator()(sycl::nd_item<1> item_id) const {
-    auto index = item_id.get_global_linear_id();
+  void operator()(sycl::nd_item<1> item) const {
+    auto index = item.get_global_linear_id();
     if (index >= nthreads_)
       return;
     const index_t w = index % out_W_;
@@ -838,6 +838,972 @@ void grid_sampler_2d_backward_kernel(
   }
 }
 
+template <typename scalar_t, typename index_t>
+struct GridSampler3dKernelFunctor {
+  using opmath_t = at::opmath_type<scalar_t>;
+  void operator()(sycl::nd_item<1> item) const {
+    auto index = item.get_global_linear_id();
+    if (index >= nthreads_)
+      return;
+
+    const index_t w = index % out_W_;
+    const index_t h = (index / out_W_) % out_H_;
+    const index_t d = (index / (out_H_ * out_W_)) % out_D_;
+    const index_t n = index / (out_D_ * out_H_ * out_W_);
+    const index_t grid_offset =
+        n * grid_sN_ + d * grid_sD_ + h * grid_sH_ + w * grid_sW_;
+
+    // get the corresponding input_ x, y, z co-ordinates from grid_
+    opmath_t ix = grid_.data[grid_offset];
+    opmath_t iy = grid_.data[grid_offset + grid_sCoor_];
+    opmath_t iz = grid_.data[grid_offset + 2 * grid_sCoor_];
+
+    ix = at::native::xpu::grid_sampler_compute_source_index(
+        ix, inp_W_, padding_mode_, align_corners_);
+    iy = at::native::xpu::grid_sampler_compute_source_index(
+        iy, inp_H_, padding_mode_, align_corners_);
+    iz = at::native::xpu::grid_sampler_compute_source_index(
+        iz, inp_D_, padding_mode_, align_corners_);
+
+    if (interpolation_mode_ == GridSamplerInterpolation::Bilinear) {
+      // get corner pixel values from (x, y, z)
+      // for 4d, we used north-east-south-west
+      // for 5d, we add top-bottom
+      index_t ix_tnw = static_cast<index_t>(std::floor(ix));
+      index_t iy_tnw = static_cast<index_t>(std::floor(iy));
+      index_t iz_tnw = static_cast<index_t>(std::floor(iz));
+
+      index_t ix_tne = ix_tnw + 1;
+      index_t iy_tne = iy_tnw;
+      index_t iz_tne = iz_tnw;
+
+      index_t ix_tsw = ix_tnw;
+      index_t iy_tsw = iy_tnw + 1;
+      index_t iz_tsw = iz_tnw;
+
+      index_t ix_tse = ix_tnw + 1;
+      index_t iy_tse = iy_tnw + 1;
+      index_t iz_tse = iz_tnw;
+
+      index_t ix_bnw = ix_tnw;
+      index_t iy_bnw = iy_tnw;
+      index_t iz_bnw = iz_tnw + 1;
+
+      index_t ix_bne = ix_tnw + 1;
+      index_t iy_bne = iy_tnw;
+      index_t iz_bne = iz_tnw + 1;
+
+      index_t ix_bsw = ix_tnw;
+      index_t iy_bsw = iy_tnw + 1;
+      index_t iz_bsw = iz_tnw + 1;
+
+      index_t ix_bse = ix_tnw + 1;
+      index_t iy_bse = iy_tnw + 1;
+      index_t iz_bse = iz_tnw + 1;
+
+      // get surfaces to each neighbor:
+      opmath_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
+      opmath_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
+      opmath_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
+      opmath_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
+      opmath_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
+      opmath_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
+      opmath_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
+      opmath_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
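+      // Each weight is the volume of the sub-box between the sample point
+      // and the diagonally opposite corner, so inside the cell the eight
+      // weights are non-negative and sum to 1 (a sample exactly at the tnw
+      // corner gives tnw == 1 and every other weight == 0).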
+
+      auto inp_ptr_NC = input_.data + n * inp_sN_;
+      auto out_ptr_NCDHW =
+          output_.data + n * out_sN_ + d * out_sD_ + h * out_sH_ + w * out_sW_;
+      for (index_t c = 0; c < C_;
+           ++c, inp_ptr_NC += inp_sC_, out_ptr_NCDHW += out_sC_) {
+        //   (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) *
+        //   tne
+        // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) *
+        //   tse
+        // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) *
+        //   bne
+        // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) *
+        //   bse
+        opmath_t out_acc = 0;
+        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D_, inp_H_, inp_W_)) {
+          out_acc +=
+              inp_ptr_NC
+                  [iz_tnw * inp_sD_ + iy_tnw * inp_sH_ + ix_tnw * inp_sW_] *
+              tnw;
+        }
+        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D_, inp_H_, inp_W_)) {
+          out_acc +=
+              inp_ptr_NC
+                  [iz_tne * inp_sD_ + iy_tne * inp_sH_ + ix_tne * inp_sW_] *
+              tne;
+        }
+        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D_, inp_H_, inp_W_)) {
+          out_acc +=
+              inp_ptr_NC
+                  [iz_tsw * inp_sD_ + iy_tsw * inp_sH_ + ix_tsw * inp_sW_] *
+              tsw;
+        }
+        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D_, inp_H_, inp_W_)) {
+          out_acc +=
+              inp_ptr_NC
+                  [iz_tse * inp_sD_ + iy_tse * inp_sH_ + ix_tse * inp_sW_] *
+              tse;
+        }
+        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D_, inp_H_, inp_W_)) {
+          out_acc +=
+              inp_ptr_NC
+                  [iz_bnw * inp_sD_ + iy_bnw * inp_sH_ + ix_bnw * inp_sW_] *
+              bnw;
+        }
+        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D_, inp_H_, inp_W_)) {
+          out_acc +=
+              inp_ptr_NC
+                  [iz_bne * inp_sD_ + iy_bne * inp_sH_ + ix_bne * inp_sW_] *
+              bne;
+        }
+        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D_, inp_H_, inp_W_)) {
+          out_acc +=
+              inp_ptr_NC
+                  [iz_bsw * inp_sD_ + iy_bsw * inp_sH_ + ix_bsw * inp_sW_] *
+              bsw;
+        }
+        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D_, inp_H_, inp_W_)) {
+          out_acc +=
+              inp_ptr_NC
+                  [iz_bse * inp_sD_ + iy_bse * inp_sH_ + ix_bse * inp_sW_] *
+              bse;
+        }
+        *out_ptr_NCDHW = out_acc;
+      }
+    } else if (interpolation_mode_ == GridSamplerInterpolation::Nearest) {
+      index_t ix_nearest = static_cast<index_t>(std::nearbyint(ix));
+      index_t iy_nearest = static_cast<index_t>(std::nearbyint(iy));
+      index_t iz_nearest = static_cast<index_t>(std::nearbyint(iz));
+
+      // assign nearest neighbor pixel value to output_ pixel
+      auto inp_ptr_NC = input_.data + n * inp_sN_;
+      auto out_ptr_NCDHW =
+          output_.data + n * out_sN_ + d * out_sD_ + h * out_sH_ + w * out_sW_;
+      for (index_t c = 0; c < C_;
+           ++c, inp_ptr_NC += inp_sC_, out_ptr_NCDHW += out_sC_) {
+        if (within_bounds_3d(
+                iz_nearest, iy_nearest, ix_nearest, inp_D_, inp_H_, inp_W_)) {
+          *out_ptr_NCDHW = inp_ptr_NC
+              [iz_nearest * inp_sD_ + iy_nearest * inp_sH_ +
+               ix_nearest * inp_sW_];
+        } else {
+          *out_ptr_NCDHW = static_cast<scalar_t>(0);
+        }
+      }
+    }
+  }
+  GridSampler3dKernelFunctor(
+      const index_t nthreads,
+      TensorInfo<scalar_t, index_t> input,
+      TensorInfo<scalar_t, index_t> grid,
+      TensorInfo<scalar_t, index_t> output,
+      const GridSamplerInterpolation interpolation_mode,
+      const GridSamplerPadding padding_mode,
+      const bool align_corners,
+      index_t C,
+      index_t inp_D,
+      index_t inp_H,
+      index_t inp_W,
+      index_t out_D,
+      index_t out_H,
+      index_t out_W,
+      index_t inp_sN,
+      index_t inp_sC,
+      index_t inp_sD,
+      index_t inp_sH,
+      index_t inp_sW,
+      index_t grid_sN,
+      index_t grid_sD,
+      index_t grid_sH,
+      index_t grid_sW,
+      index_t grid_sCoor,
+      index_t out_sN,
+      index_t out_sC,
+      index_t out_sD,
+      index_t out_sH,
+      index_t out_sW)
+      : nthreads_(nthreads),
+        input_(input),
+        grid_(grid),
+        output_(output),
+        interpolation_mode_(interpolation_mode),
+        padding_mode_(padding_mode),
+        align_corners_(align_corners),
+        C_(C),
+        inp_D_(inp_D),
+        inp_H_(inp_H),
+        inp_W_(inp_W),
+        out_D_(out_D),
+        out_H_(out_H),
+        out_W_(out_W),
+        inp_sN_(inp_sN),
+        inp_sC_(inp_sC),
+        inp_sD_(inp_sD),
+        inp_sH_(inp_sH),
+        inp_sW_(inp_sW),
+        grid_sN_(grid_sN),
+        grid_sD_(grid_sD),
+        grid_sH_(grid_sH),
+        grid_sW_(grid_sW),
+        grid_sCoor_(grid_sCoor),
+        out_sN_(out_sN),
+        out_sC_(out_sC),
+        out_sD_(out_sD),
+        out_sH_(out_sH),
+        out_sW_(out_sW) {}
+
+ private:
+  const index_t nthreads_;
+  TensorInfo<scalar_t, index_t> input_;
+  TensorInfo<scalar_t, index_t> grid_;
+  TensorInfo<scalar_t, index_t> output_;
+  const GridSamplerInterpolation interpolation_mode_;
+  const GridSamplerPadding padding_mode_;
+  bool align_corners_;
+  index_t C_;
+  index_t inp_D_;
+  index_t inp_H_;
+  index_t inp_W_;
+  index_t out_D_;
+  index_t out_H_;
+  index_t out_W_;
+  index_t inp_sN_;
+  index_t inp_sC_;
+  index_t inp_sD_;
+  index_t inp_sH_;
+  index_t inp_sW_;
+  index_t grid_sN_;
+  index_t grid_sD_;
+  index_t grid_sH_;
+  index_t grid_sW_;
+  index_t grid_sCoor_;
+  index_t out_sN_;
+  index_t out_sC_;
+  index_t out_sD_;
+  index_t out_sH_;
+  index_t out_sW_;
+};
+
+template <typename scalar_t, typename index_t>
+void grid_sampler_3d_forward_template(
+    const index_t nthreads,
+    TensorInfo<scalar_t, index_t> input,
+    TensorInfo<scalar_t, index_t> grid,
+    TensorInfo<scalar_t, index_t> output,
+    const GridSamplerInterpolation interpolation_mode,
+    const GridSamplerPadding padding_mode,
+    bool align_corners) {
+  index_t C = input.sizes[1];
+  index_t inp_D = input.sizes[2];
+  index_t inp_H = input.sizes[3];
+  index_t inp_W = input.sizes[4];
+  index_t out_D = grid.sizes[1];
+  index_t out_H = grid.sizes[2];
+  index_t out_W = grid.sizes[3];
+  index_t inp_sN = input.strides[0];
+  index_t inp_sC = input.strides[1];
+  index_t inp_sD = input.strides[2];
+  index_t inp_sH = input.strides[3];
+  index_t inp_sW = input.strides[4];
+  index_t grid_sN = grid.strides[0];
+  index_t grid_sD = grid.strides[1];
+  index_t grid_sH = grid.strides[2];
+  index_t grid_sW = grid.strides[3];
+  index_t grid_sCoor = grid.strides[4];
+  index_t out_sN = output.strides[0];
+  index_t out_sC = output.strides[1];
+  index_t out_sD = output.strides[2];
+  index_t out_sH = output.strides[3];
+  index_t out_sW = output.strides[4];
+
+  GridSampler3dKernelFunctor<scalar_t, index_t> kfn(
+      nthreads,
+      input,
+      grid,
+      output,
+      interpolation_mode,
+      padding_mode,
+      align_corners,
+      C,
+      inp_D,
+      inp_H,
+      inp_W,
+      out_D,
+      out_H,
+      out_W,
+      inp_sN,
+      inp_sC,
+      inp_sD,
+      inp_sH,
+      inp_sW,
+      grid_sN,
+      grid_sD,
+      grid_sH,
+      grid_sW,
+      grid_sCoor,
+      out_sN,
+      out_sC,
+      out_sD,
+      out_sH,
+      out_sW);
+
+  const auto wgroup_size = syclMaxWorkGroupSize(kfn);
+  const auto ngroups = (nthreads + wgroup_size - 1) / wgroup_size;
+  auto& queue = getCurrentSYCLQueue();
+
+  sycl_kernel_submit(
+      sycl::range<1>(ngroups * wgroup_size),
+      sycl::range<1>(wgroup_size),
+      queue,
+      kfn);
+}
+
+Tensor grid_sampler_3d_kernel(
+    const Tensor& input,
+    const Tensor& grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners) {
+  // See NOTE [ grid_sampler Native Functions ].
+  // Add checks here in case this is called instead of grid_sampler.
+  check_grid_sampler_common(input, grid);
+  check_grid_sampler_3d(input, grid, interpolation_mode);
+
+  auto N = input.size(0);
+  auto D = grid.size(1);
+  auto H = grid.size(2);
+  auto W = grid.size(3);
+  auto output = at::empty({N, input.size(1), D, H, W}, input.options());
+  int64_t count = N * D * H * W;
+  if (count > 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::BFloat16,
+        at::ScalarType::Half,
+        input.scalar_type(),
+        "grid_sampler_3d_xpu",
+        [&] {
+          if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+              canUse32BitIndexMath(output)) {
+            grid_sampler_3d_forward_template(
+                static_cast<int>(count),
+                getTensorInfo<scalar_t, int>(input),
+                getTensorInfo<scalar_t, int>(grid),
+                getTensorInfo<scalar_t, int>(output),
+                static_cast<GridSamplerInterpolation>(interpolation_mode),
+                static_cast<GridSamplerPadding>(padding_mode),
+                align_corners);
+          } else {
+            grid_sampler_3d_forward_template(
+                count,
+                getTensorInfo<scalar_t, int64_t>(input),
+                getTensorInfo<scalar_t, int64_t>(grid),
+                getTensorInfo<scalar_t, int64_t>(output),
+                static_cast<GridSamplerInterpolation>(interpolation_mode),
+                static_cast<GridSamplerPadding>(padding_mode),
+                align_corners);
+          }
+        });
+  }
+  return output;
+}
+
+template <typename scalar_t, typename index_t>
+struct GridSampler3dBackwardKernelFunctor {
+  void operator()(sycl::nd_item<1> item) const {
+    auto index = item.get_global_linear_id();
+    if (index >= nthreads_)
+      return;
+
+    const index_t w = index % out_W_;
+    const index_t h = (index / out_W_) % out_H_;
+    const index_t d = (index / (out_H_ * out_W_)) % out_D_;
+    const index_t n = index / (out_D_ * out_H_ * out_W_);
+    const auto grid_offset =
+        n * grid_sN_ + d * grid_sD_ + h * grid_sH_ + w * grid_sW_;
+
+    // get the corresponding input_ x, y, z co-ordinates from grid_
+    scalar_t ix = grid_.data[grid_offset];
+    scalar_t iy = grid_.data[grid_offset + grid_sCoor_];
+    scalar_t iz = grid_.data[grid_offset + 2 * grid_sCoor_];
+
+    // multipliers for gradients on ix, iy, and iz
+    scalar_t gix_mult, giy_mult, giz_mult;
+    ix = at::native::xpu::grid_sampler_compute_source_index_set_grad(
+        ix, inp_W_, padding_mode_, align_corners_, &gix_mult);
+    iy = at::native::xpu::grid_sampler_compute_source_index_set_grad(
+        iy, inp_H_, padding_mode_, align_corners_, &giy_mult);
+    iz = at::native::xpu::grid_sampler_compute_source_index_set_grad(
+        iz, inp_D_, padding_mode_, align_corners_, &giz_mult);
+
+    if (interpolation_mode_ == GridSamplerInterpolation::Bilinear) {
+      // get corner pixel values from (x, y, z)
+      // for 4d, we used north-east-south-west
+      // for 5d, we add top-bottom
+      index_t ix_tnw = static_cast<index_t>(std::floor(ix));
+      index_t iy_tnw = static_cast<index_t>(std::floor(iy));
+      index_t iz_tnw = static_cast<index_t>(std::floor(iz));
+
+      index_t ix_tne = ix_tnw + 1;
+      index_t iy_tne = iy_tnw;
+      index_t iz_tne = iz_tnw;
+
+      index_t ix_tsw = ix_tnw;
+      index_t iy_tsw = iy_tnw + 1;
+      index_t iz_tsw = iz_tnw;
+
+      index_t ix_tse = ix_tnw + 1;
+      index_t iy_tse = iy_tnw + 1;
+      index_t iz_tse = iz_tnw;
+
+      index_t ix_bnw = ix_tnw;
+      index_t iy_bnw = iy_tnw;
+      index_t iz_bnw = iz_tnw + 1;
+
+      index_t ix_bne = ix_tnw + 1;
+      index_t iy_bne = iy_tnw;
+      index_t iz_bne = iz_tnw + 1;
+
+      index_t ix_bsw = ix_tnw;
+      index_t iy_bsw = iy_tnw + 1;
+      index_t iz_bsw = iz_tnw + 1;
+
+      index_t ix_bse = ix_tnw + 1;
+      index_t iy_bse = iy_tnw + 1;
+      index_t iz_bse = iz_tnw + 1;
+
+      // get surfaces to each neighbor:
+      scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
+      scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
+      scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
+      scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
+      scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
+      scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
+      scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
+      scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
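+      // The gix/giy/giz terms accumulated below follow from the product
+      // rule on these weights: each is a corner's input value times the
+      // derivative of its trilinear weight, times gOut.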
+
+      scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0),
+               giz = static_cast<scalar_t>(0);
+      scalar_t* gOut_ptr_NCDHW = grad_output_.data + n * gOut_sN_ +
+          d * gOut_sD_ + h * gOut_sH_ + w * gOut_sW_;
+      index_t NC_offset = n * gInp_sN_;
+      scalar_t* inp_ptr_NC = input_.data + n * inp_sN_;
+      // calculate bilinear weighted pixel value and set output pixel
+      for (index_t c = 0; c < C_; ++c,
+                   gOut_ptr_NCDHW += gOut_sC_,
+                   NC_offset += gInp_sC_,
+                   inp_ptr_NC += inp_sC_) {
+        scalar_t gOut = *gOut_ptr_NCDHW;
+
+        if (input_requires_grad_) {
+          // calculate and set grad_input_
+          at::native::xpu::safe_add_3d(
+              grad_input_.data,
+              iz_tnw,
+              iy_tnw,
+              ix_tnw,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              tnw * gOut,
+              NC_offset);
+          at::native::xpu::safe_add_3d(
+              grad_input_.data,
+              iz_tne,
+              iy_tne,
+              ix_tne,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              tne * gOut,
+              NC_offset);
+          at::native::xpu::safe_add_3d(
+              grad_input_.data,
+              iz_tsw,
+              iy_tsw,
+              ix_tsw,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              tsw * gOut,
+              NC_offset);
+          at::native::xpu::safe_add_3d(
+              grad_input_.data,
+              iz_tse,
+              iy_tse,
+              ix_tse,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              tse * gOut,
+              NC_offset);
+          at::native::xpu::safe_add_3d(
+              grad_input_.data,
+              iz_bnw,
+              iy_bnw,
+              ix_bnw,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              bnw * gOut,
+              NC_offset);
+          at::native::xpu::safe_add_3d(
+              grad_input_.data,
+              iz_bne,
+              iy_bne,
+              ix_bne,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              bne * gOut,
+              NC_offset);
+          at::native::xpu::safe_add_3d(
+              grad_input_.data,
+              iz_bsw,
+              iy_bsw,
+              ix_bsw,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              bsw * gOut,
+              NC_offset);
+          at::native::xpu::safe_add_3d(
+              grad_input_.data,
+              iz_bse,
+              iy_bse,
+              ix_bse,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              bse * gOut,
+              NC_offset);
+        }
+
+        // calculate grad_grid_
+        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D_, inp_H_, inp_W_)) {
+          scalar_t tnw_val = inp_ptr_NC
+              [iz_tnw * inp_sD_ + iy_tnw * inp_sH_ + ix_tnw * inp_sW_];
+          gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
+          giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
+          giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
+        }
+        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D_, inp_H_, inp_W_)) {
+          scalar_t tne_val = inp_ptr_NC
+              [iz_tne * inp_sD_ + iy_tne * inp_sH_ + ix_tne * inp_sW_];
+          gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
+          giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
+          giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
+        }
+        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D_, inp_H_, inp_W_)) {
+          scalar_t tsw_val = inp_ptr_NC
+              [iz_tsw * inp_sD_ + iy_tsw * inp_sH_ + ix_tsw * inp_sW_];
+          gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
+          giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
+          giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
+        }
+        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D_, inp_H_, inp_W_)) {
+          scalar_t tse_val = inp_ptr_NC
+              [iz_tse * inp_sD_ + iy_tse * inp_sH_ + ix_tse * inp_sW_];
+          gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
+          giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
+          giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
+        }
+        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D_, inp_H_, inp_W_)) {
+          scalar_t bnw_val = inp_ptr_NC
+              [iz_bnw * inp_sD_ + iy_bnw * inp_sH_ + ix_bnw * inp_sW_];
+          gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
+          giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
+          giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
+        }
+        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D_, inp_H_, inp_W_)) {
+          scalar_t bne_val = inp_ptr_NC
+              [iz_bne * inp_sD_ + iy_bne * inp_sH_ + ix_bne * inp_sW_];
+          gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
+          giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
+          giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
+        }
+        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D_, inp_H_, inp_W_)) {
+          scalar_t bsw_val = inp_ptr_NC
+              [iz_bsw * inp_sD_ + iy_bsw * inp_sH_ + ix_bsw * inp_sW_];
+          gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
+          giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
+          giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
+        }
+        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D_, inp_H_, inp_W_)) {
+          scalar_t bse_val = inp_ptr_NC
+              [iz_bse * inp_sD_ + iy_bse * inp_sH_ + ix_bse * inp_sW_];
+          gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
+          giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
+          giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
+        }
+      }
+
+      // assuming grad_grid_ is contiguous
+      // thus we can
+      //   1. use index with gGrid_sW_ to directly compute gGrid_ptr_NDHW
+      //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1],
+      //      gGrid_ptr_NDHW[2]
+      scalar_t* gGrid_ptr_NDHW = grad_grid_.data + index * gGrid_sW_;
+      gGrid_ptr_NDHW[0] = gix_mult * gix;
+      gGrid_ptr_NDHW[1] = giy_mult * giy;
+      gGrid_ptr_NDHW[2] = giz_mult * giz;
+    } else if (interpolation_mode_ == GridSamplerInterpolation::Nearest) {
+      if (input_requires_grad_) {
+        auto ix_nearest = static_cast<index_t>(std::nearbyint(ix));
+        auto iy_nearest = static_cast<index_t>(std::nearbyint(iy));
+        auto iz_nearest = static_cast<index_t>(std::nearbyint(iz));
+
+        // assign nearest neighbor pixel value to output pixel
+        scalar_t* gOut_ptr_NCDHW = grad_output_.data + n * gOut_sN_ +
+            d * gOut_sD_ + h * gOut_sH_ + w * gOut_sW_;
+        index_t NC_offset = n * gInp_sN_;
+        for (index_t c = 0; c < C_;
+             ++c, gOut_ptr_NCDHW += gOut_sC_, NC_offset += gInp_sC_) {
+          // calculate and set grad_input_
+          safe_add_3d(
+              grad_input_.data,
+              iz_nearest,
+              iy_nearest,
+              ix_nearest,
+              gInp_sD_,
+              gInp_sH_,
+              gInp_sW_,
+              inp_D_,
+              inp_H_,
+              inp_W_,
+              *gOut_ptr_NCDHW,
+              NC_offset);
+        }
+      }
+
+      // assuming grad_grid_ is contiguous
+      // thus we can
+      //   1. use index with gGrid_sW_ to directly compute gGrid_ptr_NDHW
+      //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1],
+      //      gGrid_ptr_NDHW[2]
+      scalar_t* gGrid_ptr_NDHW = grad_grid_.data + index * gGrid_sW_;
+      gGrid_ptr_NDHW[0] = static_cast<scalar_t>(0);
+      gGrid_ptr_NDHW[1] = static_cast<scalar_t>(0);
+      gGrid_ptr_NDHW[2] = static_cast<scalar_t>(0);
+    }
+  }
+  GridSampler3dBackwardKernelFunctor(
+      const index_t nthreads,
+      TensorInfo<scalar_t, index_t> grad_output,
+      TensorInfo<scalar_t, index_t> input,
+      TensorInfo<scalar_t, index_t> grid,
+      TensorInfo<scalar_t, index_t> grad_input,
+      TensorInfo<scalar_t, index_t> grad_grid,
+      const GridSamplerInterpolation interpolation_mode,
+      const GridSamplerPadding padding_mode,
+      bool align_corners,
+      const bool input_requires_grad,
+      index_t C,
+      index_t inp_D,
+      index_t inp_H,
+      index_t inp_W,
+      index_t out_D,
+      index_t out_H,
+      index_t out_W,
+      index_t inp_sN,
+      index_t inp_sC,
+      index_t inp_sD,
+      index_t inp_sH,
+      index_t inp_sW,
+      index_t grid_sN,
+      index_t grid_sD,
+      index_t grid_sH,
+      index_t grid_sW,
+      index_t grid_sCoor,
+      index_t gOut_sN,
+      index_t gOut_sC,
+      index_t gOut_sD,
+      index_t gOut_sH,
+      index_t gOut_sW,
+      int64_t gInp_sN,
+      int64_t gInp_sC,
+      int64_t gInp_sD,
+      int64_t gInp_sH,
+      int64_t gInp_sW,
+      index_t gGrid_sW)
+      : nthreads_(nthreads),
+        grad_output_(grad_output),
+        input_(input),
+        grid_(grid),
+        grad_input_(grad_input),
+        grad_grid_(grad_grid),
+        interpolation_mode_(interpolation_mode),
+        padding_mode_(padding_mode),
+        align_corners_(align_corners),
+        input_requires_grad_(input_requires_grad),
+        C_(C),
+        inp_D_(inp_D),
+        inp_H_(inp_H),
+        inp_W_(inp_W),
+        out_D_(out_D),
+        out_H_(out_H),
+        out_W_(out_W),
+        inp_sN_(inp_sN),
+        inp_sC_(inp_sC),
+        inp_sD_(inp_sD),
+        inp_sH_(inp_sH),
+        inp_sW_(inp_sW),
+        grid_sN_(grid_sN),
+        grid_sD_(grid_sD),
+        grid_sH_(grid_sH),
+        grid_sW_(grid_sW),
+        grid_sCoor_(grid_sCoor),
+        gOut_sN_(gOut_sN),
+        gOut_sC_(gOut_sC),
+        gOut_sD_(gOut_sD),
+        gOut_sH_(gOut_sH),
+        gOut_sW_(gOut_sW),
+        gInp_sN_(gInp_sN),
+        gInp_sC_(gInp_sC),
+        gInp_sD_(gInp_sD),
+        gInp_sH_(gInp_sH),
+        gInp_sW_(gInp_sW),
+        gGrid_sW_(gGrid_sW) {}
+
+ private:
+  const index_t nthreads_;
+  TensorInfo<scalar_t, index_t> grad_output_;
+  TensorInfo<scalar_t, index_t> input_;
+  TensorInfo<scalar_t, index_t> grid_;
+  TensorInfo<scalar_t, index_t> grad_input_;
+  TensorInfo<scalar_t, index_t> grad_grid_;
+  const GridSamplerInterpolation interpolation_mode_;
+  const GridSamplerPadding padding_mode_;
+  bool align_corners_;
+  const bool input_requires_grad_;
+  index_t C_;
+  index_t inp_D_;
+  index_t inp_H_;
+  index_t inp_W_;
+  index_t out_D_;
+  index_t out_H_;
+  index_t out_W_;
+  index_t inp_sN_;
+  index_t inp_sC_;
+  index_t inp_sD_;
+  index_t inp_sH_;
+  index_t inp_sW_;
+  index_t grid_sN_;
+  index_t grid_sD_;
+  index_t grid_sH_;
+  index_t grid_sW_;
+  index_t grid_sCoor_;
+  index_t gOut_sN_;
+  index_t gOut_sC_;
+  index_t gOut_sD_;
+  index_t gOut_sH_;
+  index_t gOut_sW_;
+  int64_t gInp_sN_;
+  int64_t gInp_sC_;
+  int64_t gInp_sD_;
+  int64_t gInp_sH_;
+  int64_t gInp_sW_;
+  index_t gGrid_sW_;
+};
+
+template <typename scalar_t, typename index_t>
+void grid_sampler_3d_backward_template(
+    const index_t nthreads,
+    TensorInfo<scalar_t, index_t> grad_output,
+    TensorInfo<scalar_t, index_t> input,
+    TensorInfo<scalar_t, index_t> grid,
+    TensorInfo<scalar_t, index_t> grad_input, // initialized to zeros
+    // (or unused if input_requires_grad is false)
+    TensorInfo<scalar_t, index_t> grad_grid, // initialized to empty
+    const GridSamplerInterpolation interpolation_mode,
+    const GridSamplerPadding padding_mode,
+    bool align_corners,
+    const bool input_requires_grad) {
+  index_t C = input.sizes[1];
+  index_t inp_D = input.sizes[2];
+  index_t inp_H = input.sizes[3];
+  index_t inp_W = input.sizes[4];
+  index_t out_D = grid.sizes[1];
+  index_t out_H = grid.sizes[2];
+  index_t out_W = grid.sizes[3];
+  index_t inp_sN = input.strides[0];
+  index_t inp_sC = input.strides[1];
+  index_t inp_sD = input.strides[2];
+  index_t inp_sH = input.strides[3];
+  index_t inp_sW = input.strides[4];
+  index_t grid_sN = grid.strides[0];
+  index_t grid_sD = grid.strides[1];
+  index_t grid_sH = grid.strides[2];
+  index_t grid_sW = grid.strides[3];
+  index_t grid_sCoor = grid.strides[4];
+  index_t gOut_sN = grad_output.strides[0];
+  index_t gOut_sC = grad_output.strides[1];
+  index_t gOut_sD = grad_output.strides[2];
+  index_t gOut_sH = grad_output.strides[3];
+  index_t gOut_sW = grad_output.strides[4];
+  // gInp_* are not really needed if input_requires_grad is false.
+  int64_t gInp_sN = 0;
+  int64_t gInp_sC = 0;
+  int64_t gInp_sD = 0;
+  int64_t gInp_sH = 0;
+  int64_t gInp_sW = 0;
+  if (input_requires_grad) {
+    gInp_sN = grad_input.strides[0];
+    gInp_sC = grad_input.strides[1];
+    gInp_sD = grad_input.strides[2];
+    gInp_sH = grad_input.strides[3];
+    gInp_sW = grad_input.strides[4];
+  }
+  index_t gGrid_sW = grad_grid.strides[3];
+
+  GridSampler3dBackwardKernelFunctor<scalar_t, index_t> kfn(
+      nthreads,
+      grad_output,
+      input,
+      grid,
+      grad_input,
+      grad_grid,
+      interpolation_mode,
+      padding_mode,
+      align_corners,
+      input_requires_grad,
+      C,
+      inp_D,
+      inp_H,
+      inp_W,
+      out_D,
+      out_H,
+      out_W,
+      inp_sN,
+      inp_sC,
+      inp_sD,
+      inp_sH,
+      inp_sW,
+      grid_sN,
+      grid_sD,
+      grid_sH,
+      grid_sW,
+      grid_sCoor,
+      gOut_sN,
+      gOut_sC,
+      gOut_sD,
+      gOut_sH,
+      gOut_sW,
+      gInp_sN,
+      gInp_sC,
+      gInp_sD,
+      gInp_sH,
+      gInp_sW,
+      gGrid_sW);
+
+  const auto wgroup_size = syclMaxWorkGroupSize(kfn);
+  const auto ngroups = (nthreads + wgroup_size - 1) / wgroup_size;
+  auto& queue = getCurrentSYCLQueue();
+
+  sycl_kernel_submit(
+      sycl::range<1>(ngroups * wgroup_size),
+      sycl::range<1>(wgroup_size),
+      queue,
+      kfn);
+}
+
+void grid_sampler_3d_backward_kernel(
+    const Tensor& grad_input,
+    const Tensor& grad_grid,
+    const Tensor& grad_output,
+    const Tensor& input,
+    const Tensor& grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners,
+    std::array<bool, 2> output_mask) {
+  // See NOTE [ grid_sampler Native Functions ].
+  // Add checks here in case this is called instead of grid_sampler.
+  check_grid_sampler_common(input, grid);
+  check_grid_sampler_3d(input, grid, interpolation_mode);
+
+  globalContext().alertNotDeterministic("grid_sampler_3d_backward_xpu");
+  auto input_requires_grad = output_mask[0];
+  auto N = input.size(0);
+  auto D = grid.size(1);
+  auto H = grid.size(2);
+  auto W = grid.size(3);
+  int64_t count = N * D * H * W;
+  if (count > 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::BFloat16,
+        at::ScalarType::Half,
+        input.scalar_type(),
+        "grid_sampler_3d_backward_xpu",
+        [&] {
+          if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+              canUse32BitIndexMath(grad_output)) {
+            grid_sampler_3d_backward_template(
+                static_cast<int>(count),
+                getTensorInfo<scalar_t, int>(grad_output),
+                getTensorInfo<scalar_t, int>(input),
+                getTensorInfo<scalar_t, int>(grid),
+                input_requires_grad ? getTensorInfo<scalar_t, int>(grad_input)
+                                    : TensorInfo<scalar_t, int>(),
+                getTensorInfo<scalar_t, int>(grad_grid),
+                static_cast<GridSamplerInterpolation>(interpolation_mode),
+                static_cast<GridSamplerPadding>(padding_mode),
+                align_corners,
+                input_requires_grad);
+          } else {
+            grid_sampler_3d_backward_template(
+                count,
+                getTensorInfo<scalar_t, int64_t>(grad_output),
+                getTensorInfo<scalar_t, int64_t>(input),
+                getTensorInfo<scalar_t, int64_t>(grid),
+                input_requires_grad
+                    ? getTensorInfo<scalar_t, int64_t>(grad_input)
+                    : TensorInfo<scalar_t, int64_t>(),
+                getTensorInfo<scalar_t, int64_t>(grad_grid),
+                static_cast<GridSamplerInterpolation>(interpolation_mode),
+                static_cast<GridSamplerPadding>(padding_mode),
+                align_corners,
+                input_requires_grad);
+          }
+        });
+  }
+}
+
 } // namespace at::native::xpu
 
 #pragma GCC diagnostic pop
diff --git a/src/ATen/native/xpu/sycl/GridSampler.h b/src/ATen/native/xpu/sycl/GridSampler.h
index 56681d526..33192f4cd 100644
--- a/src/ATen/native/xpu/sycl/GridSampler.h
+++ b/src/ATen/native/xpu/sycl/GridSampler.h
@@ -23,6 +23,29 @@ static inline void safe_add_2d(
   }
 }
 
+template <typename scalar_t, typename index_t>
+static inline void safe_add_3d(
+    scalar_t* data,
+    int64_t d,
+    int64_t h,
+    int64_t w,
+    int64_t sD,
+    int64_t sH,
+    int64_t sW,
+    int64_t D,
+    int64_t H,
+    int64_t W,
+    scalar_t delta,
+    index_t NC_offset) {
+  if (within_bounds_3d(d, h, w, D, H, W)) {
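+    // Several sample points can land in the same input voxel, so the
+    // scatter into grad_input must be atomic to avoid lost updates.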
+ "nn.functional.grid_sample", "addr", "cdist", "nn.functional.pdist", @@ -306,6 +306,11 @@ # format hint:{op_name:{(cls_name,test_name):{dtype:tol(atol, rtol)}} _xpu_tolerance_override = { + "nn.functional.grid_sample": { + ("TestCommon", "test_compare_cpu"): { + torch.float32: tol(atol=0.002, rtol=0.008), + } + }, "nn.functional.tanhshrink": { ("TestUnaryUfuncs", "test_reference_numerics_normal"): { torch.complex64: tol(atol=2e-05, rtol=9e-06), diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index dd45febf6..dc6dd9992 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -3669,6 +3669,19 @@ XPU: grid_sampler_2d_backward_xpu autogen: grid_sampler_2d_backward.out +- func: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + dispatch: + XPU: grid_sampler_3d_xpu + autogen: grid_sampler_3d.out + +# `grid_sampler_3d_backward` takes in `output_mask` to optimize performance for +# the case where `input` doesn't require gradient. Gradient for `grid` is always +# computed (only `output_mask[0]` is checked by the implementations). +- func: grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + dispatch: + XPU: grid_sampler_3d_backward_xpu + autogen: grid_sampler_3d_backward.out + - func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index e3e681578..54c89c5c7 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -621,6 +621,8 @@ supported: - linalg_vector_norm.out - grid_sampler_2d - grid_sampler_2d_backward + - grid_sampler_3d + - grid_sampler_3d_backward - acos - acos_ - acos.out