Skip to content

Commit a8ebd0b

Browse files
authored
DeformConv2d: SymInt support + meta-implem + opchecks (#8063)
1 parent 668348e commit a8ebd0b

File tree

5 files changed

+218
-45
lines changed

5 files changed

+218
-45
lines changed

test/test_ops.py

+13
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,7 @@ def test_is_leaf_node(self, device):
10191019
@pytest.mark.parametrize("device", cpu_and_cuda())
10201020
@pytest.mark.parametrize("contiguous", (True, False))
10211021
@pytest.mark.parametrize("batch_sz", (0, 33))
1022+
@pytest.mark.opcheck_only_one()
10221023
def test_forward(self, device, contiguous, batch_sz, dtype=None):
10231024
dtype = dtype or self.dtype
10241025
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
@@ -1071,6 +1072,7 @@ def test_wrong_sizes(self):
10711072
@pytest.mark.parametrize("device", cpu_and_cuda())
10721073
@pytest.mark.parametrize("contiguous", (True, False))
10731074
@pytest.mark.parametrize("batch_sz", (0, 33))
1075+
@pytest.mark.opcheck_only_one()
10741076
def test_backward(self, device, contiguous, batch_sz):
10751077
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
10761078
device, contiguous, batch_sz, self.dtype
@@ -1120,6 +1122,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
11201122

11211123
@needs_cuda
11221124
@pytest.mark.parametrize("contiguous", (True, False))
1125+
@pytest.mark.opcheck_only_one()
11231126
def test_compare_cpu_cuda_grads(self, contiguous):
11241127
# Test from https://github.com/pytorch/vision/issues/2598
11251128
# Run on CUDA only
@@ -1154,6 +1157,7 @@ def test_compare_cpu_cuda_grads(self, contiguous):
11541157
@needs_cuda
11551158
@pytest.mark.parametrize("batch_sz", (0, 33))
11561159
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
1160+
@pytest.mark.opcheck_only_one()
11571161
def test_autocast(self, batch_sz, dtype):
11581162
with torch.cuda.amp.autocast():
11591163
self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
@@ -1163,6 +1167,15 @@ def test_forward_scriptability(self):
11631167
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
11641168

11651169

1170+
optests.generate_opcheck_tests(
1171+
testcase=TestDeformConv,
1172+
namespaces=["torchvision"],
1173+
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
1174+
additional_decorators=[],
1175+
test_utils=OPTESTS,
1176+
)
1177+
1178+
11661179
class TestFrozenBNT:
11671180
def test_frozenbatchnorm2d_repr(self):
11681181
num_features = 32

torchvision/_meta_registrations.py

+51
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,54 @@ def meta_nms(dets, scores, iou_threshold):
172172
ctx = torch._custom_ops.get_ctx()
173173
num_to_keep = ctx.create_unbacked_symint()
174174
return dets.new_empty(num_to_keep, dtype=torch.long)
175+
176+
177+
@register_meta("deform_conv2d")
178+
def meta_deform_conv2d(
179+
input,
180+
weight,
181+
offset,
182+
mask,
183+
bias,
184+
stride_h,
185+
stride_w,
186+
pad_h,
187+
pad_w,
188+
dil_h,
189+
dil_w,
190+
n_weight_grps,
191+
n_offset_grps,
192+
use_mask,
193+
):
194+
195+
out_height, out_width = offset.shape[-2:]
196+
out_channels = weight.shape[0]
197+
batch_size = input.shape[0]
198+
return input.new_empty((batch_size, out_channels, out_height, out_width))
199+
200+
201+
@register_meta("_deform_conv2d_backward")
202+
def meta_deform_conv2d_backward(
203+
grad,
204+
input,
205+
weight,
206+
offset,
207+
mask,
208+
bias,
209+
stride_h,
210+
stride_w,
211+
pad_h,
212+
pad_w,
213+
dilation_h,
214+
dilation_w,
215+
groups,
216+
offset_groups,
217+
use_mask,
218+
):
219+
220+
grad_input = input.new_empty(input.shape)
221+
grad_weight = weight.new_empty(weight.shape)
222+
grad_offset = offset.new_empty(offset.shape)
223+
grad_mask = mask.new_empty(mask.shape)
224+
grad_bias = bias.new_empty(bias.shape)
225+
return grad_input, grad_weight, grad_offset, grad_mask, grad_bias

torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp

+43-43
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@ class DeformConv2dFunction
1818
const torch::autograd::Variable& offset,
1919
const torch::autograd::Variable& mask,
2020
const torch::autograd::Variable& bias,
21-
int64_t stride_h,
22-
int64_t stride_w,
23-
int64_t pad_h,
24-
int64_t pad_w,
25-
int64_t dilation_h,
26-
int64_t dilation_w,
27-
int64_t groups,
28-
int64_t offset_groups,
21+
c10::SymInt stride_h,
22+
c10::SymInt stride_w,
23+
c10::SymInt pad_h,
24+
c10::SymInt pad_w,
25+
c10::SymInt dilation_h,
26+
c10::SymInt dilation_w,
27+
c10::SymInt groups,
28+
c10::SymInt offset_groups,
2929
bool use_mask) {
3030
at::AutoDispatchBelowADInplaceOrView g;
31-
auto output = deform_conv2d(
31+
auto output = deform_conv2d_symint(
3232
input,
3333
weight,
3434
offset,
@@ -70,17 +70,17 @@ class DeformConv2dFunction
7070
auto mask = saved[3];
7171
auto bias = saved[4];
7272

73-
auto stride_h = ctx->saved_data["stride_h"].toInt();
74-
auto stride_w = ctx->saved_data["stride_w"].toInt();
75-
auto pad_h = ctx->saved_data["pad_h"].toInt();
76-
auto pad_w = ctx->saved_data["pad_w"].toInt();
77-
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
78-
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
79-
auto groups = ctx->saved_data["groups"].toInt();
80-
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
73+
auto stride_h = ctx->saved_data["stride_h"].toSymInt();
74+
auto stride_w = ctx->saved_data["stride_w"].toSymInt();
75+
auto pad_h = ctx->saved_data["pad_h"].toSymInt();
76+
auto pad_w = ctx->saved_data["pad_w"].toSymInt();
77+
auto dilation_h = ctx->saved_data["dilation_h"].toSymInt();
78+
auto dilation_w = ctx->saved_data["dilation_w"].toSymInt();
79+
auto groups = ctx->saved_data["groups"].toSymInt();
80+
auto offset_groups = ctx->saved_data["offset_groups"].toSymInt();
8181
auto use_mask = ctx->saved_data["use_mask"].toBool();
8282

83-
auto grads = detail::_deform_conv2d_backward(
83+
auto grads = detail::_deform_conv2d_backward_symint(
8484
grad_output[0],
8585
input,
8686
weight,
@@ -133,17 +133,17 @@ class DeformConv2dBackwardFunction
133133
const torch::autograd::Variable& offset,
134134
const torch::autograd::Variable& mask,
135135
const torch::autograd::Variable& bias,
136-
int64_t stride_h,
137-
int64_t stride_w,
138-
int64_t pad_h,
139-
int64_t pad_w,
140-
int64_t dilation_h,
141-
int64_t dilation_w,
142-
int64_t groups,
143-
int64_t offset_groups,
136+
c10::SymInt stride_h,
137+
c10::SymInt stride_w,
138+
c10::SymInt pad_h,
139+
c10::SymInt pad_w,
140+
c10::SymInt dilation_h,
141+
c10::SymInt dilation_w,
142+
c10::SymInt groups,
143+
c10::SymInt offset_groups,
144144
bool use_mask) {
145145
at::AutoDispatchBelowADInplaceOrView g;
146-
auto result = detail::_deform_conv2d_backward(
146+
auto result = detail::_deform_conv2d_backward_symint(
147147
grad,
148148
input,
149149
weight,
@@ -188,14 +188,14 @@ at::Tensor deform_conv2d_autograd(
188188
const at::Tensor& offset,
189189
const at::Tensor& mask,
190190
const at::Tensor& bias,
191-
int64_t stride_h,
192-
int64_t stride_w,
193-
int64_t pad_h,
194-
int64_t pad_w,
195-
int64_t dilation_h,
196-
int64_t dilation_w,
197-
int64_t groups,
198-
int64_t offset_groups,
191+
c10::SymInt stride_h,
192+
c10::SymInt stride_w,
193+
c10::SymInt pad_h,
194+
c10::SymInt pad_w,
195+
c10::SymInt dilation_h,
196+
c10::SymInt dilation_w,
197+
c10::SymInt groups,
198+
c10::SymInt offset_groups,
199199
bool use_mask) {
200200
return DeformConv2dFunction::apply(
201201
input,
@@ -222,14 +222,14 @@ deform_conv2d_backward_autograd(
222222
const at::Tensor& offset,
223223
const at::Tensor& mask,
224224
const at::Tensor& bias,
225-
int64_t stride_h,
226-
int64_t stride_w,
227-
int64_t pad_h,
228-
int64_t pad_w,
229-
int64_t dilation_h,
230-
int64_t dilation_w,
231-
int64_t groups,
232-
int64_t offset_groups,
225+
c10::SymInt stride_h,
226+
c10::SymInt stride_w,
227+
c10::SymInt pad_h,
228+
c10::SymInt pad_w,
229+
c10::SymInt dilation_h,
230+
c10::SymInt dilation_w,
231+
c10::SymInt groups,
232+
c10::SymInt offset_groups,
233233
bool use_mask) {
234234
auto result = DeformConv2dBackwardFunction::apply(
235235
grad,

torchvision/csrc/ops/deform_conv2d.cpp

+77-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,42 @@ at::Tensor deform_conv2d(
4343
use_mask);
4444
}
4545

46+
at::Tensor deform_conv2d_symint(
47+
const at::Tensor& input,
48+
const at::Tensor& weight,
49+
const at::Tensor& offset,
50+
const at::Tensor& mask,
51+
const at::Tensor& bias,
52+
c10::SymInt stride_h,
53+
c10::SymInt stride_w,
54+
c10::SymInt pad_h,
55+
c10::SymInt pad_w,
56+
c10::SymInt dilation_h,
57+
c10::SymInt dilation_w,
58+
c10::SymInt groups,
59+
c10::SymInt offset_groups,
60+
bool use_mask) {
61+
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d");
62+
static auto op = c10::Dispatcher::singleton()
63+
.findSchemaOrThrow("torchvision::deform_conv2d", "")
64+
.typed<decltype(deform_conv2d_symint)>();
65+
return op.call(
66+
input,
67+
weight,
68+
offset,
69+
mask,
70+
bias,
71+
stride_h,
72+
stride_w,
73+
pad_h,
74+
pad_w,
75+
dilation_h,
76+
dilation_w,
77+
groups,
78+
offset_groups,
79+
use_mask);
80+
}
81+
4682
namespace detail {
4783

4884
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -84,13 +120,52 @@ _deform_conv2d_backward(
84120
use_mask);
85121
}
86122

123+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
124+
_deform_conv2d_backward_symint(
125+
const at::Tensor& grad,
126+
const at::Tensor& input,
127+
const at::Tensor& weight,
128+
const at::Tensor& offset,
129+
const at::Tensor& mask,
130+
const at::Tensor& bias,
131+
c10::SymInt stride_h,
132+
c10::SymInt stride_w,
133+
c10::SymInt pad_h,
134+
c10::SymInt pad_w,
135+
c10::SymInt dilation_h,
136+
c10::SymInt dilation_w,
137+
c10::SymInt groups,
138+
c10::SymInt offset_groups,
139+
bool use_mask) {
140+
static auto op =
141+
c10::Dispatcher::singleton()
142+
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
143+
.typed<decltype(_deform_conv2d_backward_symint)>();
144+
return op.call(
145+
grad,
146+
input,
147+
weight,
148+
offset,
149+
mask,
150+
bias,
151+
stride_h,
152+
stride_w,
153+
pad_h,
154+
pad_w,
155+
dilation_h,
156+
dilation_w,
157+
groups,
158+
offset_groups,
159+
use_mask);
160+
}
161+
87162
} // namespace detail
88163

89164
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
90165
m.def(TORCH_SELECTIVE_SCHEMA(
91-
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"));
166+
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor"));
92167
m.def(TORCH_SELECTIVE_SCHEMA(
93-
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
168+
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
94169
}
95170

96171
} // namespace ops

torchvision/csrc/ops/deform_conv2d.h

+34
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,22 @@ VISION_API at::Tensor deform_conv2d(
2222
int64_t offset_groups,
2323
bool use_mask);
2424

25+
VISION_API at::Tensor deform_conv2d_symint(
26+
const at::Tensor& input,
27+
const at::Tensor& weight,
28+
const at::Tensor& offset,
29+
const at::Tensor& mask,
30+
const at::Tensor& bias,
31+
c10::SymInt stride_h,
32+
c10::SymInt stride_w,
33+
c10::SymInt pad_h,
34+
c10::SymInt pad_w,
35+
c10::SymInt dilation_h,
36+
c10::SymInt dilation_w,
37+
c10::SymInt groups,
38+
c10::SymInt offset_groups,
39+
bool use_mask);
40+
2541
namespace detail {
2642

2743
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -42,6 +58,24 @@ _deform_conv2d_backward(
4258
int64_t offset_groups,
4359
bool use_mask);
4460

61+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
62+
_deform_conv2d_backward_symint(
63+
const at::Tensor& grad,
64+
const at::Tensor& input,
65+
const at::Tensor& weight,
66+
const at::Tensor& offset,
67+
const at::Tensor& mask,
68+
const at::Tensor& bias,
69+
c10::SymInt stride_h,
70+
c10::SymInt stride_w,
71+
c10::SymInt pad_h,
72+
c10::SymInt pad_w,
73+
c10::SymInt dilation_h,
74+
c10::SymInt dilation_w,
75+
c10::SymInt groups,
76+
c10::SymInt offset_groups,
77+
bool use_mask);
78+
4579
} // namespace detail
4680

4781
} // namespace ops

0 commit comments

Comments
 (0)