@@ -15,15 +15,15 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
15
15
const torch::autograd::Variable& input,
16
16
const torch::autograd::Variable& rois,
17
17
double spatial_scale,
18
- int64_t pooled_height,
19
- int64_t pooled_width) {
18
+ c10::SymInt pooled_height,
19
+ c10::SymInt pooled_width) {
20
20
ctx->saved_data [" spatial_scale" ] = spatial_scale;
21
21
ctx->saved_data [" pooled_height" ] = pooled_height;
22
22
ctx->saved_data [" pooled_width" ] = pooled_width;
23
- ctx->saved_data [" input_shape" ] = input.sizes ();
23
+ ctx->saved_data [" input_shape" ] = input.sym_sizes ();
24
24
at::AutoDispatchBelowADInplaceOrView g;
25
- auto result =
26
- ps_roi_pool ( input, rois, spatial_scale, pooled_height, pooled_width);
25
+ auto result = ps_roi_pool_symint (
26
+ input, rois, spatial_scale, pooled_height, pooled_width);
27
27
28
28
auto output = std::get<0 >(result);
29
29
auto channel_mapping = std::get<1 >(result);
@@ -40,18 +40,18 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
40
40
auto saved = ctx->get_saved_variables ();
41
41
auto rois = saved[0 ];
42
42
auto channel_mapping = saved[1 ];
43
- auto input_shape = ctx->saved_data [" input_shape" ].toIntList ();
44
- auto grad_in = detail::_ps_roi_pool_backward (
43
+ auto input_shape = ctx->saved_data [" input_shape" ].toList ();
44
+ auto grad_in = detail::_ps_roi_pool_backward_symint (
45
45
grad_output[0 ],
46
46
rois,
47
47
channel_mapping,
48
48
ctx->saved_data [" spatial_scale" ].toDouble (),
49
- ctx->saved_data [" pooled_height" ].toInt (),
50
- ctx->saved_data [" pooled_width" ].toInt (),
51
- input_shape[0 ],
52
- input_shape[1 ],
53
- input_shape[2 ],
54
- input_shape[3 ]);
49
+ ctx->saved_data [" pooled_height" ].toSymInt (),
50
+ ctx->saved_data [" pooled_width" ].toSymInt (),
51
+ input_shape[0 ]. get (). toSymInt () ,
52
+ input_shape[1 ]. get (). toSymInt () ,
53
+ input_shape[2 ]. get (). toSymInt () ,
54
+ input_shape[3 ]. get (). toSymInt () );
55
55
56
56
return {
57
57
grad_in,
@@ -72,14 +72,14 @@ class PSROIPoolBackwardFunction
72
72
const torch::autograd::Variable& rois,
73
73
const torch::autograd::Variable& channel_mapping,
74
74
double spatial_scale,
75
- int64_t pooled_height,
76
- int64_t pooled_width,
77
- int64_t batch_size,
78
- int64_t channels,
79
- int64_t height,
80
- int64_t width) {
75
+ c10::SymInt pooled_height,
76
+ c10::SymInt pooled_width,
77
+ c10::SymInt batch_size,
78
+ c10::SymInt channels,
79
+ c10::SymInt height,
80
+ c10::SymInt width) {
81
81
at::AutoDispatchBelowADInplaceOrView g;
82
- auto grad_in = detail::_ps_roi_pool_backward (
82
+ auto grad_in = detail::_ps_roi_pool_backward_symint (
83
83
grad,
84
84
rois,
85
85
channel_mapping,
@@ -105,8 +105,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
105
105
const at::Tensor& input,
106
106
const at::Tensor& rois,
107
107
double spatial_scale,
108
- int64_t pooled_height,
109
- int64_t pooled_width) {
108
+ c10::SymInt pooled_height,
109
+ c10::SymInt pooled_width) {
110
110
auto result = PSROIPoolFunction::apply (
111
111
input, rois, spatial_scale, pooled_height, pooled_width);
112
112
@@ -118,12 +118,12 @@ at::Tensor ps_roi_pool_backward_autograd(
118
118
const at::Tensor& rois,
119
119
const at::Tensor& channel_mapping,
120
120
double spatial_scale,
121
- int64_t pooled_height,
122
- int64_t pooled_width,
123
- int64_t batch_size,
124
- int64_t channels,
125
- int64_t height,
126
- int64_t width) {
121
+ c10::SymInt pooled_height,
122
+ c10::SymInt pooled_width,
123
+ c10::SymInt batch_size,
124
+ c10::SymInt channels,
125
+ c10::SymInt height,
126
+ c10::SymInt width) {
127
127
return PSROIPoolBackwardFunction::apply (
128
128
grad,
129
129
rois,
0 commit comments