Skip to content

Commit 38a0a5a

Browse files
authored
Fix output size computation for MaxPool2D for ceil_model = true. (llvm#3890)
This PR fixes the output size computation as per https://github.com/pytorch/pytorch/blob/d8c14838f164ee02b88b6e37471b71bb0373f865/torch/_meta_registrations.py#L3847 ``` if ceil_mode: if (outputSize - 1) * stride >= inputSize + pad_l: outputSize -= 1 return outputSize ```
1 parent a6179c0 commit 38a0a5a

File tree

4 files changed

+66
-3
lines changed

4 files changed

+66
-3
lines changed

lib/Conversion/TorchToLinalg/Utils.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,22 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
116116
else
117117
division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
118118
Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
119+
120+
if (ceilMode) {
121+
Value outMinusOneTimesStride =
122+
b.createOrFold<arith::MulIOp>(loc, division, strideInt);
123+
Value inAddLeftPadding = b.createOrFold<arith::AddIOp>(
124+
loc, castIndexToInt64(b, loc, in), paddingInt);
125+
126+
auto reduceOutputDimCond =
127+
b.createOrFold<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
128+
outMinusOneTimesStride, inAddLeftPadding);
129+
130+
auto reducedDim = b.createOrFold<arith::SelectOp>(loc, reduceOutputDimCond,
131+
division, out);
132+
return castIntToIndex(b, loc, reducedDim);
133+
}
134+
119135
return castIntToIndex(b, loc, out);
120136
}
121137

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -5398,9 +5398,11 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
53985398
} else {
53995399
int64_t dimSize =
54005400
inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
5401-
if (ceilMode && (dimSize % stride != 0))
5402-
return dimSize / stride + 2;
5403-
return dimSize / stride + 1;
5401+
int64_t outputDim = dimSize / stride + 1;
5402+
if (ceilMode && (dimSize % stride != 0) &&
5403+
(outputDim * stride < inputDim + padBefore))
5404+
outputDim++;
5405+
return outputDim;
54045406
}
54055407
}
54065408

projects/pt1/e2e_testing/xfail_sets.py

+16
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@
735735
"LenStrModule_basic",
736736
"MaxPool2dCeilModeTrueModule_basic",
737737
"MaxPool2dStaticCeilModeTrueModule_basic",
738+
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
738739
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
739740
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
740741
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
@@ -2255,6 +2256,7 @@
22552256
"MatmulStaticBroadcast_basic",
22562257
"MaxPool2dEmptyStrideStaticModule_basic",
22572258
"MaxPool2dStaticCeilModeTrueModule_basic",
2259+
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
22582260
"MaxPool2dStaticModule_basic",
22592261
"MeanModule_basic",
22602262
"MmDagModule_basic",
@@ -3380,6 +3382,13 @@
33803382
"ScaledDotProductAttentionBoolMaskModule_basic",
33813383
}
33823384

3385+
if torch_version_for_comparison() > version.parse("2.5.1"):
3386+
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
3387+
# error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible
3388+
# torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3
3389+
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
3390+
}
3391+
33833392
if torch_version_for_comparison() < version.parse("2.4.0.dev"):
33843393
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
33853394
"AtenIntMM_basic",
@@ -4932,3 +4941,10 @@
49324941
"_LogSoftmaxModule_basic",
49334942
"_SoftmaxModule_basic",
49344943
}
4944+
4945+
if torch_version_for_comparison() > version.parse("2.5.1"):
4946+
ONNX_TOSA_XFAIL_SET = ONNX_TOSA_XFAIL_SET | {
4947+
# error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible
4948+
# torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3
4949+
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
4950+
}

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

+29
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
420420
module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0))
421421

422422

423+
class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module):
424+
def __init__(self):
425+
super().__init__()
426+
self.mp2d = torch.nn.MaxPool2d(
427+
kernel_size=6,
428+
stride=6,
429+
padding=3,
430+
dilation=1,
431+
ceil_mode=True,
432+
)
433+
434+
@export
435+
@annotate_args(
436+
[
437+
None,
438+
([2, 6, 20, 10], torch.float32, True),
439+
]
440+
)
441+
def forward(self, x):
442+
return self.mp2d(x)
443+
444+
445+
@register_test_case(
446+
module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule()
447+
)
448+
def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils):
449+
module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0))
450+
451+
423452
# ==============================================================================
424453

425454

0 commit comments

Comments
 (0)