Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions torchvision/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,36 @@ def meta_ps_roi_align_backward(
return grad.new_empty((batch_size, channels, height, width))


@register_meta("roi_pool")
def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
num_rois = rois.size(0)
channels = input.size(1)
out_size = (num_rois, channels, pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)


@register_meta("_roi_pool_backward")
def meta_roi_pool_backward(
grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))


@torch._custom_ops.impl_abstract("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
Expand Down
56 changes: 28 additions & 28 deletions torchvision/csrc/ops/autograd/roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
c10::SymInt pooled_height,
c10::SymInt pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
ctx->saved_data["input_shape"] = input.sym_sizes();
at::AutoDispatchBelowADInplaceOrView g;
auto result =
roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
auto result = roi_pool_symint(
input, rois, spatial_scale, pooled_height, pooled_width);

auto output = std::get<0>(result);
auto argmax = std::get<1>(result);
Expand All @@ -40,18 +40,18 @@ class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto argmax = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_roi_pool_backward(
auto input_shape = ctx->saved_data["input_shape"].toList();
auto grad_in = detail::_roi_pool_backward_symint(
grad_output[0],
rois,
argmax,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
ctx->saved_data["pooled_height"].toSymInt(),
ctx->saved_data["pooled_width"].toSymInt(),
input_shape[0].get().toSymInt(),
input_shape[1].get().toSymInt(),
input_shape[2].get().toSymInt(),
input_shape[3].get().toSymInt());

return {
grad_in,
Expand All @@ -72,14 +72,14 @@ class ROIPoolBackwardFunction
const torch::autograd::Variable& rois,
const torch::autograd::Variable& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width) {
at::AutoDispatchBelowADInplaceOrView g;
auto grad_in = detail::_roi_pool_backward(
auto grad_in = detail::_roi_pool_backward_symint(
grad,
rois,
argmax,
Expand All @@ -105,8 +105,8 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
c10::SymInt pooled_height,
c10::SymInt pooled_width) {
auto result = ROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);

Expand All @@ -118,12 +118,12 @@ at::Tensor roi_pool_backward_autograd(
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width) {
return ROIPoolBackwardFunction::apply(
grad,
rois,
Expand Down
44 changes: 42 additions & 2 deletions torchvision/csrc/ops/roi_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ std::tuple<at::Tensor, at::Tensor> roi_pool(
return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
}

// SymInt-aware entry point for torchvision::roi_pool. Looks the operator up
// through the dispatcher and forwards all arguments, so symbolic pooled
// sizes (c10::SymInt) can flow through unchanged.
std::tuple<at::Tensor, at::Tensor> roi_pool_symint(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool");
  // Resolve and cache the typed operator handle on first use.
  static auto roi_pool_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::roi_pool", "")
          .typed<decltype(roi_pool_symint)>();
  return roi_pool_op.call(
      input, rois, spatial_scale, pooled_height, pooled_width);
}

namespace detail {

at::Tensor _roi_pool_backward(
Expand Down Expand Up @@ -49,13 +62,40 @@ at::Tensor _roi_pool_backward(
width);
}

// SymInt-aware backward for roi_pool. Dispatches to the registered
// torchvision::_roi_pool_backward operator; the input dimensions are
// accepted as c10::SymInt so they may be symbolic under tracing.
at::Tensor _roi_pool_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width) {
  // Resolve and cache the typed operator handle on first use.
  static auto backward_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_pool_backward", "")
          .typed<decltype(_roi_pool_backward_symint)>();
  return backward_op.call(
      grad,
      rois,
      argmax,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width);
}

} // namespace detail

// Register the roi_pool operator schemas with the PyTorch dispatcher.
// Pooled sizes and input dimensions are declared as SymInt so the ops can
// participate in symbolic shape tracing (e.g. torch.compile / dynamic shapes).
// NOTE: the diff residue duplicating each schema with the old `int`-typed
// signature is removed here — registering the same op twice with
// conflicting schemas would fail at library load.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"));
}

} // namespace ops
Expand Down
19 changes: 19 additions & 0 deletions torchvision/csrc/ops/roi_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool(
int64_t pooled_height,
int64_t pooled_width);

// SymInt overload of roi_pool: identical contract, but pooled sizes are
// c10::SymInt so symbolic shapes can propagate through tracing.
// Returns (pooled output, argmax indices).
VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool_symint(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width);

namespace detail {

at::Tensor _roi_pool_backward(
Expand All @@ -27,6 +34,18 @@ at::Tensor _roi_pool_backward(
int64_t height,
int64_t width);

// SymInt overload of _roi_pool_backward: input dimensions are c10::SymInt
// so they may be symbolic under tracing. Returns the gradient w.r.t. the
// forward input, shaped (batch_size, channels, height, width).
at::Tensor _roi_pool_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width);

} // namespace detail

} // namespace ops
Expand Down