Skip to content

Commit 668348e

Browse files
authored
PSRoiPool: SymInt support + meta-implem (#8062)
1 parent 85c586c commit 668348e

File tree

4 files changed

+124
-30
lines changed

4 files changed

+124
-30
lines changed

torchvision/_meta_registrations.py

+34
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,40 @@ def meta_roi_pool_backward(
126126
return grad.new_empty((batch_size, channels, height, width))
127127

128128

129+
@register_meta("ps_roi_pool")
130+
def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
131+
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
132+
torch._check(
133+
input.dtype == rois.dtype,
134+
lambda: (
135+
"Expected tensor for input to have the same type as tensor for rois; "
136+
f"but type {input.dtype} does not equal {rois.dtype}"
137+
),
138+
)
139+
channels = input.size(1)
140+
torch._check(
141+
channels % (pooled_height * pooled_width) == 0,
142+
"input channels must be a multiple of pooling height * pooling width",
143+
)
144+
num_rois = rois.size(0)
145+
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
146+
return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
147+
148+
149+
@register_meta("_ps_roi_pool_backward")
150+
def meta_ps_roi_pool_backward(
151+
grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
152+
):
153+
torch._check(
154+
grad.dtype == rois.dtype,
155+
lambda: (
156+
"Expected tensor for grad to have the same type as tensor for rois; "
157+
f"but type {grad.dtype} does not equal {rois.dtype}"
158+
),
159+
)
160+
return grad.new_empty((batch_size, channels, height, width))
161+
162+
129163
@torch._custom_ops.impl_abstract("torchvision::nms")
130164
def meta_nms(dets, scores, iou_threshold):
131165
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")

torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp

+28-28
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
1515
const torch::autograd::Variable& input,
1616
const torch::autograd::Variable& rois,
1717
double spatial_scale,
18-
int64_t pooled_height,
19-
int64_t pooled_width) {
18+
c10::SymInt pooled_height,
19+
c10::SymInt pooled_width) {
2020
ctx->saved_data["spatial_scale"] = spatial_scale;
2121
ctx->saved_data["pooled_height"] = pooled_height;
2222
ctx->saved_data["pooled_width"] = pooled_width;
23-
ctx->saved_data["input_shape"] = input.sizes();
23+
ctx->saved_data["input_shape"] = input.sym_sizes();
2424
at::AutoDispatchBelowADInplaceOrView g;
25-
auto result =
26-
ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
25+
auto result = ps_roi_pool_symint(
26+
input, rois, spatial_scale, pooled_height, pooled_width);
2727

2828
auto output = std::get<0>(result);
2929
auto channel_mapping = std::get<1>(result);
@@ -40,18 +40,18 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
4040
auto saved = ctx->get_saved_variables();
4141
auto rois = saved[0];
4242
auto channel_mapping = saved[1];
43-
auto input_shape = ctx->saved_data["input_shape"].toIntList();
44-
auto grad_in = detail::_ps_roi_pool_backward(
43+
auto input_shape = ctx->saved_data["input_shape"].toList();
44+
auto grad_in = detail::_ps_roi_pool_backward_symint(
4545
grad_output[0],
4646
rois,
4747
channel_mapping,
4848
ctx->saved_data["spatial_scale"].toDouble(),
49-
ctx->saved_data["pooled_height"].toInt(),
50-
ctx->saved_data["pooled_width"].toInt(),
51-
input_shape[0],
52-
input_shape[1],
53-
input_shape[2],
54-
input_shape[3]);
49+
ctx->saved_data["pooled_height"].toSymInt(),
50+
ctx->saved_data["pooled_width"].toSymInt(),
51+
input_shape[0].get().toSymInt(),
52+
input_shape[1].get().toSymInt(),
53+
input_shape[2].get().toSymInt(),
54+
input_shape[3].get().toSymInt());
5555

5656
return {
5757
grad_in,
@@ -72,14 +72,14 @@ class PSROIPoolBackwardFunction
7272
const torch::autograd::Variable& rois,
7373
const torch::autograd::Variable& channel_mapping,
7474
double spatial_scale,
75-
int64_t pooled_height,
76-
int64_t pooled_width,
77-
int64_t batch_size,
78-
int64_t channels,
79-
int64_t height,
80-
int64_t width) {
75+
c10::SymInt pooled_height,
76+
c10::SymInt pooled_width,
77+
c10::SymInt batch_size,
78+
c10::SymInt channels,
79+
c10::SymInt height,
80+
c10::SymInt width) {
8181
at::AutoDispatchBelowADInplaceOrView g;
82-
auto grad_in = detail::_ps_roi_pool_backward(
82+
auto grad_in = detail::_ps_roi_pool_backward_symint(
8383
grad,
8484
rois,
8585
channel_mapping,
@@ -105,8 +105,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
105105
const at::Tensor& input,
106106
const at::Tensor& rois,
107107
double spatial_scale,
108-
int64_t pooled_height,
109-
int64_t pooled_width) {
108+
c10::SymInt pooled_height,
109+
c10::SymInt pooled_width) {
110110
auto result = PSROIPoolFunction::apply(
111111
input, rois, spatial_scale, pooled_height, pooled_width);
112112

@@ -118,12 +118,12 @@ at::Tensor ps_roi_pool_backward_autograd(
118118
const at::Tensor& rois,
119119
const at::Tensor& channel_mapping,
120120
double spatial_scale,
121-
int64_t pooled_height,
122-
int64_t pooled_width,
123-
int64_t batch_size,
124-
int64_t channels,
125-
int64_t height,
126-
int64_t width) {
121+
c10::SymInt pooled_height,
122+
c10::SymInt pooled_width,
123+
c10::SymInt batch_size,
124+
c10::SymInt channels,
125+
c10::SymInt height,
126+
c10::SymInt width) {
127127
return PSROIPoolBackwardFunction::apply(
128128
grad,
129129
rois,

torchvision/csrc/ops/ps_roi_pool.cpp

+43-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
2020
return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
2121
}
2222

23+
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_symint(
24+
const at::Tensor& input,
25+
const at::Tensor& rois,
26+
double spatial_scale,
27+
c10::SymInt pooled_height,
28+
c10::SymInt pooled_width) {
29+
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool");
30+
static auto op = c10::Dispatcher::singleton()
31+
.findSchemaOrThrow("torchvision::ps_roi_pool", "")
32+
.typed<decltype(ps_roi_pool_symint)>();
33+
return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
34+
}
35+
2336
namespace detail {
2437

2538
at::Tensor _ps_roi_pool_backward(
@@ -50,13 +63,41 @@ at::Tensor _ps_roi_pool_backward(
5063
width);
5164
}
5265

66+
at::Tensor _ps_roi_pool_backward_symint(
67+
const at::Tensor& grad,
68+
const at::Tensor& rois,
69+
const at::Tensor& channel_mapping,
70+
double spatial_scale,
71+
c10::SymInt pooled_height,
72+
c10::SymInt pooled_width,
73+
c10::SymInt batch_size,
74+
c10::SymInt channels,
75+
c10::SymInt height,
76+
c10::SymInt width) {
77+
static auto op =
78+
c10::Dispatcher::singleton()
79+
.findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "")
80+
.typed<decltype(_ps_roi_pool_backward_symint)>();
81+
return op.call(
82+
grad,
83+
rois,
84+
channel_mapping,
85+
spatial_scale,
86+
pooled_height,
87+
pooled_width,
88+
batch_size,
89+
channels,
90+
height,
91+
width);
92+
}
93+
5394
} // namespace detail
5495

5596
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
5697
m.def(TORCH_SELECTIVE_SCHEMA(
57-
"torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"));
98+
"torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)"));
5899
m.def(TORCH_SELECTIVE_SCHEMA(
59-
"torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"));
100+
"torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"));
60101
}
61102

62103
} // namespace ops

torchvision/csrc/ops/ps_roi_pool.h

+19
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
1313
int64_t pooled_height,
1414
int64_t pooled_width);
1515

16+
VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool_symint(
17+
const at::Tensor& input,
18+
const at::Tensor& rois,
19+
double spatial_scale,
20+
c10::SymInt pooled_height,
21+
c10::SymInt pooled_width);
22+
1623
namespace detail {
1724

1825
at::Tensor _ps_roi_pool_backward(
@@ -27,6 +34,18 @@ at::Tensor _ps_roi_pool_backward(
2734
int64_t height,
2835
int64_t width);
2936

37+
at::Tensor _ps_roi_pool_backward_symint(
38+
const at::Tensor& grad,
39+
const at::Tensor& rois,
40+
const at::Tensor& channel_mapping,
41+
double spatial_scale,
42+
c10::SymInt pooled_height,
43+
c10::SymInt pooled_width,
44+
c10::SymInt batch_size,
45+
c10::SymInt channels,
46+
c10::SymInt height,
47+
c10::SymInt width);
48+
3049
} // namespace detail
3150

3251
} // namespace ops

0 commit comments

Comments
 (0)