@@ -43,6 +43,42 @@ at::Tensor deform_conv2d(
43
43
use_mask);
44
44
}
45
45
46
+ at::Tensor deform_conv2d_symint (
47
+ const at::Tensor& input,
48
+ const at::Tensor& weight,
49
+ const at::Tensor& offset,
50
+ const at::Tensor& mask,
51
+ const at::Tensor& bias,
52
+ c10::SymInt stride_h,
53
+ c10::SymInt stride_w,
54
+ c10::SymInt pad_h,
55
+ c10::SymInt pad_w,
56
+ c10::SymInt dilation_h,
57
+ c10::SymInt dilation_w,
58
+ c10::SymInt groups,
59
+ c10::SymInt offset_groups,
60
+ bool use_mask) {
61
+ C10_LOG_API_USAGE_ONCE (" torchvision.csrc.ops.deform_conv2d.deform_conv2d" );
62
+ static auto op = c10::Dispatcher::singleton ()
63
+ .findSchemaOrThrow (" torchvision::deform_conv2d" , " " )
64
+ .typed <decltype (deform_conv2d_symint)>();
65
+ return op.call (
66
+ input,
67
+ weight,
68
+ offset,
69
+ mask,
70
+ bias,
71
+ stride_h,
72
+ stride_w,
73
+ pad_h,
74
+ pad_w,
75
+ dilation_h,
76
+ dilation_w,
77
+ groups,
78
+ offset_groups,
79
+ use_mask);
80
+ }
81
+
46
82
namespace detail {
47
83
48
84
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -84,13 +120,52 @@ _deform_conv2d_backward(
84
120
use_mask);
85
121
}
86
122
123
+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
124
+ _deform_conv2d_backward_symint (
125
+ const at::Tensor& grad,
126
+ const at::Tensor& input,
127
+ const at::Tensor& weight,
128
+ const at::Tensor& offset,
129
+ const at::Tensor& mask,
130
+ const at::Tensor& bias,
131
+ c10::SymInt stride_h,
132
+ c10::SymInt stride_w,
133
+ c10::SymInt pad_h,
134
+ c10::SymInt pad_w,
135
+ c10::SymInt dilation_h,
136
+ c10::SymInt dilation_w,
137
+ c10::SymInt groups,
138
+ c10::SymInt offset_groups,
139
+ bool use_mask) {
140
+ static auto op =
141
+ c10::Dispatcher::singleton ()
142
+ .findSchemaOrThrow (" torchvision::_deform_conv2d_backward" , " " )
143
+ .typed <decltype (_deform_conv2d_backward_symint)>();
144
+ return op.call (
145
+ grad,
146
+ input,
147
+ weight,
148
+ offset,
149
+ mask,
150
+ bias,
151
+ stride_h,
152
+ stride_w,
153
+ pad_h,
154
+ pad_w,
155
+ dilation_h,
156
+ dilation_w,
157
+ groups,
158
+ offset_groups,
159
+ use_mask);
160
+ }
161
+
87
162
} // namespace detail
88
163
89
164
TORCH_LIBRARY_FRAGMENT (torchvision, m) {
90
165
m.def (TORCH_SELECTIVE_SCHEMA (
91
- " torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor" ));
166
+ " torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor" ));
92
167
m.def (TORCH_SELECTIVE_SCHEMA (
93
- " torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)" ));
168
+ " torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)" ));
94
169
}
95
170
96
171
} // namespace ops
0 commit comments