From 27ebbf83fcb934e80e13be7b09c9788403537929 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Sun, 24 Nov 2024 19:31:17 +0800 Subject: [PATCH] Add torchvision::roi_align forward/backward (#1097) - [x] roi_align - [x] _roi_align_backward --------- Co-authored-by: Yutao Xu --- src/ATen/native/xpu/RoiAlign.cpp | 74 +++ src/ATen/native/xpu/XPUFallback.template | 12 +- src/ATen/native/xpu/sycl/RoiAlignKernels.cpp | 516 ++++++++++++++++++ src/ATen/native/xpu/sycl/RoiAlignKernels.h | 27 + test/regressions/test_roi_align.py | 219 ++++++++ .../regressions/test_torchvision_roi_align.py | 4 +- 6 files changed, 845 insertions(+), 7 deletions(-) create mode 100644 src/ATen/native/xpu/RoiAlign.cpp create mode 100644 src/ATen/native/xpu/sycl/RoiAlignKernels.cpp create mode 100644 src/ATen/native/xpu/sycl/RoiAlignKernels.h create mode 100644 test/regressions/test_roi_align.py diff --git a/src/ATen/native/xpu/RoiAlign.cpp b/src/ATen/native/xpu/RoiAlign.cpp new file mode 100644 index 000000000..067bd27d3 --- /dev/null +++ b/src/ATen/native/xpu/RoiAlign.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include +#include +namespace at::native::xpu { + +at::Tensor roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + TORCH_CHECK(input.is_xpu(), "input must be a XPU tensor"); + TORCH_CHECK(rois.is_xpu(), "rois must be a XPU tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + c10::DeviceGuard device_guard(input.device()); + return roi_align_kernel( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned); +} + +at::Tensor _roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + TORCH_CHECK(grad.is_xpu(), "grad must be a XPU tensor"); + TORCH_CHECK(rois.is_xpu(), "rois must be a XPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + c10::DeviceGuard device_guard(grad.device()); + + return roi_align_backward_kernel( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned); +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 47120f75d..10e16e2dc 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -24,6 +24,8 @@ static void xpu_fallback_impl( namespace native::xpu { Tensor nms(const Tensor& dets, const Tensor& scores, double iou_threshold_); +Tensor roi_align(const Tensor& input, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio, bool aligned); +Tensor _roi_align_backward(const Tensor& grad, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t batch_size, int64_t channels, int64_t height, int64_t width, int64_t sampling_ratio, bool aligned); } // Register op's implementation lazily since sometimes the op is not defined, @@ -34,8 +36,8 @@ Tensor nms(const Tensor& dets, const Tensor& scores, double iou_threshold_); // static std::map torchvision_ops_dispatching_table_ = { {"torchvision::nms", false}, - {"torchvision::roi_align", true}, - {"torchvision::_roi_align_backward", true}, + {"torchvision::roi_align", false}, + {"torchvision::_roi_align_backward", false}, }; // Return: @@ -51,11 +53,9 @@ static bool lazy_registration_and_redispatch( // suppose ops of torchvision are all defined (`import torchvision`). m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(at::native::xpu::nms)); m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - torch::CppFunction::makeFromBoxedFunction<&xpu_fallback_impl>()); + TORCH_SELECTIVE_NAME("torchvision::roi_align"),TORCH_FN(at::native::xpu::roi_align)); m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - torch::CppFunction::makeFromBoxedFunction<&xpu_fallback_impl>()); + TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),TORCH_FN(at::native::xpu::_roi_align_backward)); }; static const torch::detail::TorchLibraryInit diff --git a/src/ATen/native/xpu/sycl/RoiAlignKernels.cpp b/src/ATen/native/xpu/sycl/RoiAlignKernels.cpp new file mode 100644 index 000000000..f95d82c42 --- /dev/null +++ b/src/ATen/native/xpu/sycl/RoiAlignKernels.cpp @@ -0,0 +1,516 @@ +#pragma clang diagnostic push +#pragma GCC diagnostic push +// Avoid SYCL compiler return-type error +#pragma clang diagnostic ignored "-Wreturn-type" +#pragma GCC diagnostic ignored "-Wreturn-type" +#include +#include +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +T bilinear_interpolate( + const T* input, + int height, + int width, + T y, + T x, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} +template +struct RoiAlignForwardKernel { + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, nthreads_) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width_; + int ph = (index / pooled_width_) % pooled_height_; + int c = (index / pooled_width_ / pooled_height_) % channels_; + int n = index / pooled_width_ / pooled_height_ / channels_; + + const T* offset_rois = rois_ + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned_ ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale_ - offset; + T roi_start_h = offset_rois[2] * spatial_scale_ - offset; + T roi_end_w = offset_rois[3] * spatial_scale_ - offset; + T roi_end_h = offset_rois[4] * spatial_scale_ - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned_) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + + T bin_size_h = + static_cast(roi_height) / static_cast(pooled_height_); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); + + const T* offset_input = + input_ + (roi_batch_ind * channels_ + c) * height_ * width_; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio_ > 0) + ? sampling_ratio_ + : std::ceil(roi_height / pooled_height_); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio_ > 0) + ? sampling_ratio_ + : std::ceil(roi_width / pooled_width_); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros. + const T count = std::max( + (int)(roi_bin_grid_h * roi_bin_grid_w), (int)(1)); // e.g. = 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T val = + bilinear_interpolate(offset_input, height_, width_, y, x, index); + output_val += val; + } + } + output_val /= count; + + output_[index] = output_val; + } + } + RoiAlignForwardKernel( + int nthreads, + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + bool aligned, + const T* rois, + T* output) + : nthreads_(nthreads), + input_(input), + spatial_scale_(spatial_scale), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + aligned_(aligned), + rois_(rois), + output_(output) {} + + private: + int nthreads_; + const T* input_; + const T spatial_scale_; + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + int sampling_ratio_; + bool aligned_; + const T* rois_; + T* output_; +}; + +template +void bilinear_interpolate_gradient( + int height, + int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +struct RoiAlignBackwardKernel { + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, nthreads_) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width_; + int ph = (index / pooled_width_) % pooled_height_; + int c = (index / pooled_width_ / pooled_height_) % channels_; + int n = index / pooled_width_ / pooled_height_ / channels_; + + const T* offset_rois = rois_ + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned_ ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale_ - offset; + T roi_start_h = offset_rois[2] * spatial_scale_ - offset; + T roi_end_w = offset_rois[3] * spatial_scale_ - offset; + T roi_end_h = offset_rois[4] * spatial_scale_ - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned_) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + + T bin_size_h = + static_cast(roi_height) / static_cast(pooled_height_); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); + + // We need to index the gradient using the tensor strides to access the + // correct values. + const int output_offset = n * n_stride_ + c * c_stride_; + const T* offset_grad_output = grad_output_ + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride_ + pw * w_stride_]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio_ > 0) + ? sampling_ratio_ + : std::ceil(roi_height / pooled_height_); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio_ > 0) + ? sampling_ratio_ + : std::ceil(roi_width / pooled_width_); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + const int input_offset = + (roi_batch_ind * channels_ + c) * height_ * width_; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height_, + width_, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd( + (sycl_global_ptr< + T>)(grad_input_ + input_offset + y_low * width_ + x_low), + static_cast(g1)); + + atomicAdd( + (sycl_global_ptr< + T>)(grad_input_ + input_offset + y_low * width_ + x_high), + static_cast(g2)); + atomicAdd( + (sycl_global_ptr< + T>)(grad_input_ + input_offset + y_high * width_ + x_low), + static_cast(g3)); + atomicAdd( + (sycl_global_ptr< + T>)(grad_input_ + input_offset + y_high * width_ + x_high), + static_cast(g4)); + } // if + } // ix + } // iy + } // XPU_KERNEL_LOOP + } + RoiAlignBackwardKernel( + int nthreads, + const T* grad_output, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + bool aligned, + T* grad_input, + const T* rois, + int n_stride, + int c_stride, + int h_stride, + int w_stride, + const int memory_span) + : nthreads_(nthreads), + grad_output_(grad_output), + spatial_scale_(spatial_scale), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + sampling_ratio_(sampling_ratio), + aligned_(aligned), + grad_input_(grad_input), + rois_(rois), + n_stride_(n_stride), + c_stride_(c_stride), + h_stride_(h_stride), + w_stride_(w_stride), + memory_span_(memory_span) {} + + private: + int nthreads_; + const T* grad_output_; + const T spatial_scale_; + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + int sampling_ratio_; + bool aligned_; + T* grad_input_; + const T* rois_; + int n_stride_; + int c_stride_; + int h_stride_; + int w_stride_; + const int memory_span_; +}; + +Tensor roi_align_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + + auto output_size = num_rois * pooled_height * pooled_width * channels; + int64_t global_range = + ceil_div(static_cast(output_size), static_cast(512)); + int64_t local_range = 512; + + if (output.numel() == 0) { + return output; + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "roi_align_forward_kernel_xpu", [&] { + auto kfn = RoiAlignForwardKernel( + output_size, + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + rois_.data_ptr(), + output.data_ptr()); + sycl_kernel_submit( + global_range * local_range, + local_range, + at::xpu::getCurrentSYCLQueue(), + kfn); + }); + return output; +} + +Tensor roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + int64_t global_range = + ceil_div(static_cast(grad.numel()), static_cast(512)); + int64_t local_range = 512; + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + at::globalContext().alertNotDeterministic("roi_align_backward_kernel_xpu"); + + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "roi_align_backward_kernel_xpu", [&] { + auto kfn = RoiAlignBackwardKernel( + grad.numel(), + grad.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + grad_input.data_ptr(), + rois_.data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride, + grad_input.numel()); + sycl_kernel_submit( + global_range * local_range, + local_range, + at::xpu::getCurrentSYCLQueue(), + kfn); + }); + return grad_input; +} + +} // namespace at::native::xpu + +#pragma GCC diagnostic pop +#pragma clang diagnostic pop \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/RoiAlignKernels.h b/src/ATen/native/xpu/sycl/RoiAlignKernels.h new file mode 100644 index 000000000..9d37f0e27 --- /dev/null +++ b/src/ATen/native/xpu/sycl/RoiAlignKernels.h @@ -0,0 +1,27 @@ +#pragma once + +#include +namespace at::native::xpu { + +TORCH_XPU_API Tensor roi_align_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); + +TORCH_XPU_API Tensor roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned); +} // namespace at::native::xpu \ No newline at end of file diff --git a/test/regressions/test_roi_align.py b/test/regressions/test_roi_align.py new file mode 100644 index 000000000..4f5caaa3c --- /dev/null +++ b/test/regressions/test_roi_align.py @@ -0,0 +1,219 @@ +import torch +from torch.testing._internal.common_utils import TestCase +import torchvision +import math +import numpy as np + + +def bilinear_interpolate(data, y, x, snap_border=False): + height, width = data.shape + if snap_border: + if -1 < y <= 0: + y = 0 + elif height - 1 <= y < height: + y = height - 1 + if -1 < x <= 0: + x = 0 + elif width - 1 <= x < width: + x = width - 1 + y_low = int(math.floor(y)) + x_low = int(math.floor(x)) + y_high = y_low + 1 + x_high = x_low + 1 + wy_h = y - y_low + wx_h = x - x_low + wy_l = 1 - wy_h + wx_l = 1 - wx_h + val = 0 + for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): + for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): + if 0 <= yp < height and 0 <= xp < width: + val += wx * wy * data[yp, xp] + return val + + +def expected_fn( + in_data, + rois, + pool_h, + pool_w, + spatial_scale=1, + sampling_ratio=-1, + aligned=False, + device=None, + dtype=torch.float64, +): + if device is None: + device = torch.device("cpu") + n_channels = in_data.size(1) + out_data = torch.zeros( + rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device + ) + offset = 0.5 if aligned else 0.0 + for r, roi in enumerate(rois): + batch_idx = int(roi[0]) + j_begin, i_begin, j_end, i_end = ( + x.item() * spatial_scale - offset for x in roi[1:] + ) + roi_h = i_end - i_begin + roi_w = j_end - j_begin + bin_h = roi_h / pool_h + bin_w = roi_w / pool_w + for i in range(0, pool_h): + start_h = i_begin + i * bin_h + grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h)) + for j in range(0, pool_w): + start_w = j_begin + j * bin_w + grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w)) + for channel in range(0, n_channels): + val = 0 + for iy in range(0, grid_h): + y = start_h + (iy + 0.5) * bin_h / grid_h + for ix in range(0, grid_w): + x = start_w + (ix + 0.5) * bin_w / grid_w + val += bilinear_interpolate( + in_data[batch_idx, channel, :, :], + y, + x, + snap_border=True, + ) + val /= grid_h * grid_w + out_data[r, channel, i, j] = val + return out_data + + +def expected_grad_fn( + in_data, + rois, + grad_output, + pool_h, + pool_w, + spatial_scale=1, + sampling_ratio=-1, + aligned=False, + device=None, + dtype=torch.float64, +): + if device is None: + device = torch.device("cpu") + n_channels = in_data.size(1) + grad_input = torch.zeros_like(in_data, dtype=dtype, device=device) + offset = 0.5 if aligned else 0.0 + for r, roi in enumerate(rois): + batch_idx = int(roi[0]) + j_begin, i_begin, j_end, i_end = ( + x.item() * spatial_scale - offset for x in roi[1:] + ) + roi_h = i_end - i_begin + roi_w = j_end - j_begin + bin_h = roi_h / pool_h + bin_w = roi_w / pool_w + for i in range(0, pool_h): + start_h = i_begin + i * bin_h + grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h)) + for j in range(0, pool_w): + start_w = j_begin + j * bin_w + grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w)) + for channel in range(0, n_channels): + grad_val = grad_output[r, channel, i, j] / (grid_h * grid_w) + for iy in range(0, grid_h): + y = start_h + (iy + 0.5) * bin_h / grid_h + for ix in range(0, grid_w): + x = start_w + (ix + 0.5) * bin_w / grid_w + y_low = int(math.floor(y)) + x_low = int(math.floor(x)) + y_high = y_low + 1 + x_high = x_low + 1 + wy_h = y - y_low + wx_h = x - x_low + wy_l = 1 - wy_h + wx_l = 1 - wx_h + for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): + for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): + if 0 <= yp < in_data.size(2) and 0 <= xp < in_data.size(3): + grad_input[batch_idx, channel, yp, xp] += wx * wy * grad_val + return grad_input + + +class TestNNMethod(TestCase): + def roi_align_forward_(self, dtype_): + device = torch.device("xpu") + x_dtype = dtype_ + rois_dtype = dtype_ + pool_size = 5 + n_channels = 2 * (pool_size**2) + x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) + rois = torch.tensor( + [ + [0, 0, 0, 9, 9], + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9], + ], # format is (xyxy) + dtype=rois_dtype, + device=device, + ) + pool_h, pool_w = pool_size, pool_size + y = torchvision.ops.roi_align( + x, rois, [pool_h, pool_w], spatial_scale=1, sampling_ratio=-1 + ) + assert y.dtype == x.dtype + gt_y = expected_fn( + x, + rois, + pool_h, + pool_w, + spatial_scale=1, + sampling_ratio=-1, + device=device, + dtype=x_dtype, + ) + tol = 1e-2 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 + torch.testing.assert_close(gt_y.cpu(), y.cpu(), rtol=tol, atol=tol) + + def roi_align_backward_(self, dtype_): + device = torch.device("xpu") + x_dtype = dtype_ + rois_dtype = dtype_ + pool_size = 5 + n_channels = 2 * (pool_size**2) + x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device, requires_grad=True) + rois = torch.tensor( + [ + [0, 0, 0, 9, 9], + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9], + ], # format is (xyxy) + dtype=rois_dtype, + device=device, + ) + pool_h, pool_w = pool_size, pool_size + y = torchvision.ops.roi_align( + x, rois, [pool_h, pool_w], spatial_scale=1, sampling_ratio=-1 + ) + grad_output = torch.rand_like(y) + y.backward(grad_output) + assert x.grad is not None + # Compare gradients + gt_grad = expected_grad_fn( + x.detach(), + rois, + grad_output, + pool_h, + pool_w, + spatial_scale=1, + sampling_ratio=-1, + device=device, + dtype=x_dtype, + ) + tol = 1e-2 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 + torch.testing.assert_close(gt_grad.cpu(), x.grad.cpu(), rtol=tol, atol=tol) + + def test_roi_align_forward(self): + for dtype in [torch.float, torch.half]: + self.roi_align_forward_(dtype) + + def test_roi_align_backward(self): + for dtype in [torch.float, torch.half]: + self.roi_align_backward_(dtype) diff --git a/test/regressions/test_torchvision_roi_align.py b/test/regressions/test_torchvision_roi_align.py index ea7ee71a6..fbfd82f8f 100644 --- a/test/regressions/test_torchvision_roi_align.py +++ b/test/regressions/test_torchvision_roi_align.py @@ -4,6 +4,8 @@ class TestTorchVisionMethod(TestCase): def test_roi_align(self): + atol = 1e-1 + rtol = 5e-5 a_ref = torch.zeros([4, 256, 296, 304]).requires_grad_(True) b_ref = torch.zeros([2292, 5]).requires_grad_(True) @@ -15,4 +17,4 @@ def test_roi_align(self): ref.sum().backward() res.sum().backward() self.assertEqual(ref, res.cpu()) - self.assertEqual(a_ref.grad, a_xpu.grad.cpu()) + self.assertEqual(a_ref.grad, a_xpu.grad.cpu(), rtol=rtol, atol=atol)