Skip to content

Commit

Permalink
Add torchvision::roi_align forward/backward (#1097)
Browse files Browse the repository at this point in the history
- [x] roi_align
- [x] _roi_align_backward

---------

Co-authored-by: Yutao Xu <[email protected]>
  • Loading branch information
chunhuanMeng and xytintel authored Nov 24, 2024
1 parent f7ca0ae commit 27ebbf8
Show file tree
Hide file tree
Showing 6 changed files with 845 additions and 7 deletions.
74 changes: 74 additions & 0 deletions src/ATen/native/xpu/RoiAlign.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/xpu/sycl/RoiAlignKernels.h>
#include <comm/XPUGuard.h>
#include <comm/xpu_aten.h>
#include <torch/library.h>
namespace at::native::xpu {

at::Tensor roi_align(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
TORCH_CHECK(input.is_xpu(), "input must be a XPU tensor");
TORCH_CHECK(rois.is_xpu(), "rois must be a XPU tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");

at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

at::CheckedFrom c = "roi_align_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});

c10::DeviceGuard device_guard(input.device());
return roi_align_kernel(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
}

at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
TORCH_CHECK(grad.is_xpu(), "grad must be a XPU tensor");
TORCH_CHECK(rois.is_xpu(), "rois must be a XPU tensor");

at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};

at::CheckedFrom c = "roi_align_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t});
at::checkAllSameType(c, {grad_t, rois_t});

c10::DeviceGuard device_guard(grad.device());

return roi_align_backward_kernel(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
}

} // namespace at::native::xpu
12 changes: 6 additions & 6 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ static void xpu_fallback_impl(

namespace native::xpu {
Tensor nms(const Tensor& dets, const Tensor& scores, double iou_threshold_);
Tensor roi_align(const Tensor& input, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio, bool aligned);
Tensor _roi_align_backward(const Tensor& grad, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t batch_size, int64_t channels, int64_t height, int64_t width, int64_t sampling_ratio, bool aligned);
}

// Register op's implementation lazily since sometimes the op is not defined,
Expand All @@ -34,8 +36,8 @@ Tensor nms(const Tensor& dets, const Tensor& scores, double iou_threshold_);
// <operator_name: string, is_cpu_fallback: bool>
static std::map<std::string, bool> torchvision_ops_dispatching_table_ = {
{"torchvision::nms", false},
{"torchvision::roi_align", true},
{"torchvision::_roi_align_backward", true},
{"torchvision::roi_align", false},
{"torchvision::_roi_align_backward", false},
};

// Return:
Expand All @@ -51,11 +53,9 @@ static bool lazy_registration_and_redispatch(
// suppose ops of torchvision are all defined (`import torchvision`).
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(at::native::xpu::nms));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
torch::CppFunction::makeFromBoxedFunction<&xpu_fallback_impl>());
TORCH_SELECTIVE_NAME("torchvision::roi_align"),TORCH_FN(at::native::xpu::roi_align));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
torch::CppFunction::makeFromBoxedFunction<&xpu_fallback_impl>());
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),TORCH_FN(at::native::xpu::_roi_align_backward));
};

static const torch::detail::TorchLibraryInit
Expand Down
Loading

0 comments on commit 27ebbf8

Please sign in to comment.