
Commit b1c31b1

nkaretnikov authored and pytorchmergebot committed
[pt2] metas and SymInt support for max_pool ops (pytorch#103951)
Pull Request resolved: pytorch#103951
Approved by: https://github.com/Chillee, https://github.com/kulinseth
1 parent c4a6f86 · commit b1c31b1

11 files changed (+285 -28 lines)
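As a quick illustration of what the new metas enable (a minimal sketch, assuming a PyTorch build that includes this commit): the max_pool ops can now run on meta tensors, where only output metadata (shape, dtype) is computed.

import torch
import torch.nn.functional as F

# Meta tensors carry no data; pooling them exercises the new meta
# kernels, which compute output shapes without touching real memory.
x1 = torch.empty(2, 3, 10, device="meta")
print(F.max_pool1d(x1, kernel_size=2).shape)   # torch.Size([2, 3, 5])

x3 = torch.empty(2, 3, 8, 8, 8, device="meta")
print(F.max_pool3d(x3, kernel_size=2).shape)   # torch.Size([2, 3, 4, 4, 4])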

Diff for: aten/src/ATen/native/DilatedMaxPool2d.cpp (+2 -2)

@@ -39,7 +39,7 @@ bool ceil_mode) {
     stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
 
   TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
-    "max_pool2d: padding must be either be a single int, or a tuple of two ints");
+    "max_pool2d: padding must either be a single int, or a tuple of two ints");
   const int padH = safe_downcast<int, int64_t>(padding[0]);
   const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

@@ -112,7 +112,7 @@ const Tensor& indices) {
     stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
 
   TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
-    "max_pool2d: padding must be either be a single int, or a tuple of two ints");
+    "max_pool2d: padding must either be a single int, or a tuple of two ints");
   const int padH = safe_downcast<int, int64_t>(padding[0]);
   const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
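The wording fix above (repeated in the 3d, CUDA, and MPS files below) only drops the duplicated "be"; the accepted forms are unchanged. For reference, a small sketch of the two forms the message describes (illustrative, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
# A single int broadcasts to both spatial dims; a 2-tuple is (padH, padW).
a = F.max_pool2d(x, kernel_size=3, padding=1)
b = F.max_pool2d(x, kernel_size=3, padding=(1, 1))
assert torch.equal(a, b)
print(a.shape)  # torch.Size([1, 3, 3, 3])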

Diff for: aten/src/ATen/native/DilatedMaxPool3d.cpp (+2 -2)

@@ -173,7 +173,7 @@ void max_pool3d_with_indices_out_cpu_template(
     stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);
 
   TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
-    "max_pool3d: padding must be either be a single int, or a tuple of three ints");
+    "max_pool3d: padding must either be a single int, or a tuple of three ints");
   const int pT = safe_downcast<int, int64_t>(padding[0]);
   const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
   const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

@@ -381,7 +381,7 @@ Tensor& max_pool3d_with_indices_backward_out_cpu_template(
     stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);
 
   TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
-    "max_pool3d: padding must be either be a single int, or a tuple of three ints");
+    "max_pool3d: padding must either be a single int, or a tuple of three ints");
   const int pT = safe_downcast<int, int64_t>(padding[0]);
   const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
   const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

Diff for: aten/src/ATen/native/MaxPooling.cpp (+5 -5)

@@ -34,7 +34,7 @@ static void check_max_pool1d(
 
   TORCH_CHECK(
       self.dim() == 2 || self.dim() == 3,
-      "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sizes());
+      "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
   TORCH_CHECK(
       kernel_size.size() == 1,
       "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",

@@ -74,7 +74,7 @@ static void check_max_pool1d(
   TORCH_CHECK(
       dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
 
-  const int64_t OW = pooling_output_shape(self.size(-1), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
+  const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
   TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
 }

@@ -132,10 +132,10 @@ Tensor max_pool1d(
 
   auto ndim = self.ndimension();
   TORCH_CHECK(
-      (ndim == 2 && self.size(0) != 0 && self.size(1) != 0) ||
-      (ndim == 3 && self.size(1) != 0 && self.size(2) != 0),
+      (ndim == 2 && self.sym_size(0) != 0 && self.sym_size(1) != 0) ||
+      (ndim == 3 && self.sym_size(1) != 0 && self.sym_size(2) != 0),
       "max_pool1d: Expected 2D or 3D (batch mode) tensor with optional 0 dim batch size for input, but got:",
-      self.sizes());
+      self.sym_sizes());
 
   if (self.is_quantized()) {
     return at::quantized_max_pool1d(
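The substantive change here is SymInt support: sizes are fetched via sym_size()/sym_sizes() so the checks stay traceable under symbolic shapes, and guard_int(__FILE__, __LINE__) pins the last dim to a concrete value only where pooling_output_shape needs one. For context, the arithmetic pooling_output_shape performs looks roughly like this (a Python rendering for illustration, not the C++ source):

def pooling_output_shape(input_size, kernel_size, pad, stride, dilation, ceil_mode):
    # Number of valid window placements, floor- (or ceil-) divided by stride.
    numerator = input_size + 2 * pad - dilation * (kernel_size - 1) - 1
    if ceil_mode:
        out = (numerator + stride - 1) // stride + 1  # ceiling division
        # The last window must start inside the input, not in the padding.
        if (out - 1) * stride >= input_size + pad:
            out -= 1
    else:
        out = numerator // stride + 1
    return out

print(pooling_output_shape(10, 2, 0, 2, 1, False))  # 5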

Diff for: aten/src/ATen/native/cuda/DilatedMaxPool3d.cu (+2 -2)

@@ -313,7 +313,7 @@ void max_pool3d_with_indices_out_cuda_template(
     stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);
 
   TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
-    "max_pool3d: padding must be either be a single int, or a tuple of three ints");
+    "max_pool3d: padding must either be a single int, or a tuple of three ints");
   const int pT = safe_downcast<int, int64_t>(padding[0]);
   const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
   const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

@@ -443,7 +443,7 @@ void max_pool3d_with_indices_backward_out_cuda_template(
     stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);
 
   TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
-    "max_pool3d: padding must be either be a single int, or a tuple of three ints");
+    "max_pool3d: padding must either be a single int, or a tuple of three ints");
   const int pT = safe_downcast<int, int64_t>(padding[0]);
   const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
   const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

Diff for: aten/src/ATen/native/mps/operations/Pooling.mm (+1 -1)

@@ -70,7 +70,7 @@ static void pool2d_template(const Tensor& input,
               ": stride must either be omitted, a single int, or a tuple of two ints")
   TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
               op_name,
-              ": padding must be either be a single int, or a tuple of two ints");
+              ": padding must either be a single int, or a tuple of two ints");
   TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
               op_name,
               ": dilation must be either a single int, or a tuple of two ints");

Diff for: test/functorch/test_aotdispatch.py (-5)

@@ -2834,8 +2834,6 @@ def forward(self, x):
     xfail('nn.functional.interpolate', 'area'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'linear'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'trilinear'),  # Cannot call sizes() on tensor with symbolic sizes/st...
-    xfail('nn.functional.max_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic m...
     xfail('nn.functional.multi_margin_loss', ''),  # could not find kernel
     xfail('nn.functional.multilabel_margin_loss', ''),  # could not find kernel
     xfail('nn.functional.nll_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides

@@ -2982,9 +2980,6 @@ def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
     # TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
     torch.nn.FractionalMaxPool2d,  # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
     torch.nn.FractionalMaxPool3d,  # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
-    torch.nn.MaxPool1d,  # Cannot call sizes() on tensor with symbolic sizes/strides
-    torch.nn.MaxPool3d,  # torch._subclasses.fake_tensor.UnsupportedOperatorException:
-                         # aten.max_pool3d_with_indices.default
 }
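These xfail and skip entries can be deleted because the ops now trace under symbolic shapes. A minimal sketch of the newly passing behavior (assumes a build with this commit):

import torch

pool = torch.nn.MaxPool1d(kernel_size=2)
compiled = torch.compile(pool, dynamic=True)

# One graph traced with a symbolic length should serve both inputs;
# this used to fail with "Cannot call sizes() on tensor with symbolic
# sizes/strides" for max_pool1d.
print(compiled(torch.randn(2, 3, 8)).shape)   # torch.Size([2, 3, 4])
print(compiled(torch.randn(2, 3, 12)).shape)  # torch.Size([2, 3, 6])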

Diff for: test/test_meta.py (-6)

@@ -631,8 +631,6 @@ def run_meta_crossref(
     torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
     torch.nn.functional.ctc_loss : {f64, f32},
     torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
-    torch.nn.functional.max_pool3d : {f64, f32},
-    torch.nn.functional.max_pool3d_with_indices : {f64, f32},
     torch.nn.functional.multi_margin_loss : {f64, f32},
     torch.nn.functional.multilabel_margin_loss : {f64, f32},
     torch.nn.functional.one_hot : {i64},

@@ -722,8 +720,6 @@ def run_meta_crossref(
     torch.histc: {i16, i32, i64, i8},  # aten::histc, aten::histc.out
     torch.kthvalue: {f16},  # aten::kthvalue.values
     torch.median: {f16},  # aten::median, aten::median.dim_values
-    torch.nn.functional.max_pool3d: {bf16, f16},  # aten::max_pool3d_with_indices
-    torch.nn.functional.max_pool3d_with_indices: {bf16, f16},  # aten::max_pool3d_with_indices
     torch.nn.functional.multi_margin_loss: {bf16, f16},  # aten::multi_margin_loss
     torch.nn.functional.multilabel_margin_loss: {bf16, f16},  # aten::multilabel_margin_loss_forward
     torch.ormqr: {f32, f64},  # aten::ormqr, aten::ormqr.out

@@ -847,7 +843,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.histogram.bin_ct : {f32, f64},
     aten.histogram.bins_tensor : {f32, f64},
     aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
-    aten.max_pool3d_with_indices.default : {f32, f64},
     aten.median.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
     aten.median.dim : {i8, f64, i64, bf16, f32, i32, i16, u8},
     aten.mode.default : {f16, i8, f64, i64, bf16, f32, i32, b8, i16, u8},

@@ -907,7 +902,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
     aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
     aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
-    aten.max_pool3d_with_indices.default: {bf16, f16},  # aten::max_pool3d_with_indices
     aten.median.default: {f16},  # aten::median
     aten.median.dim: {f16},  # aten::median.dim_values
     aten.multi_margin_loss.default: {bf16, f16},  # aten::multi_margin_loss
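These expected-failure entries go away because aten.max_pool3d_with_indices.default now has a meta implementation covering these dtypes. A quick sketch of what the crossref checks (assumes a build with this commit):

import torch
import torch.nn.functional as F

# Both outputs get their metadata computed on the meta device,
# including the previously skipped bf16/f16 cases.
x = torch.empty(2, 3, 8, 8, 8, device="meta", dtype=torch.bfloat16)
values, indices = F.max_pool3d_with_indices(x, kernel_size=2)
print(values.shape, values.dtype)    # torch.Size([2, 3, 4, 4, 4]) torch.bfloat16
print(indices.shape, indices.dtype)  # torch.Size([2, 3, 4, 4, 4]) torch.int64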

Diff for: test/test_proxy_tensor.py (-2)

@@ -1541,8 +1541,6 @@ def f(a, b, c, d, e):
     xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
     xfail('nn.functional.interpolate', 'linear'),  # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
     xfail('nn.functional.interpolate', 'trilinear'),  # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
-    xfail('nn.functional.max_pool1d', ''),  # Trying to call aten.size on a tensor with symbolic shapes.
-    xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.multi_margin_loss', ''),  # Could not run 'aten::multi_margin_loss' with arguments from the...
     xfail('nn.functional.multilabel_margin_loss', ''),  # Could not run 'aten::multilabel_margin_loss_forward' with ...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
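A sketch of the symbolic tracing these two removed entries used to mark as failing (assumes a build with this commit); make_fx with tracing_mode="symbolic" allocates SymInt sizes for the input:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.max_pool1d(x, kernel_size=2)

# Tracing records SymInt-based shape arithmetic instead of calling
# sizes() on the fake input.
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 3, 8))
print(gm.graph)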
