[TOSA] Fix output size calculation for pool ops #4125
Conversation
TOSA requires (inputDim + padBefore + padAfter - kernel) to be fully divisible by stride. This update adds pad and input size modifications for the pooling ops (AvgPool2d and MaxPool2d) to satisfy that TOSA requirement.

Signed-off-by: Justin Ngo <[email protected]>
Change-Id: Iab4021f2dda87cb87e54e4e9ca20bd3688dc1c50
if (remainderDim > padAfter) {
  SmallVector<int64_t> startSlice(inputRank, 0);
  SmallVector<int64_t> sizeSlice(
      dyn_cast<TensorType>(input.getType()).getShape());
Can use `inputShape` directly instead of computing it again.
RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
SmallVectorImpl<int64_t> &dilationArray, bool ceilMode = false) {
  auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
  auto inputRank = inputTy.getRank();
  auto inputElemTy = inputTy.getElementType();
// PyTorch uses xCHW, so Height dim index is 2 and Width dim index is 3
For CHW input (rather than NCHW input), isn't the H/W dim index 1/2?
inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
padArray[0], dilationArray[0], ceilMode);
rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
/*dimIndex=*/2, inputShape[inputRank - 2], kernelSize[0], strideArray[0],
I think using `inputRank - 2` instead of `2` will honor both CHW and NCHW inputs.
static int64_t getOutputDim(PatternRewriter &rewriter, Value &input,
                            Location loc, int64_t inputRank,
                            ArrayRef<int64_t> inputShape, Type inputElemTy,
                            int64_t dimIndex, int64_t inputDim,
#nit: since both `dimIndex` and `inputShape` are inputs, can grab `inputDim = inputShape[dimIndex]` inside the function instead of passing it as an input.
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x1x56x56xf32> -> !torch.vtensor<[1,1,56,56],f32>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[1,1,56,56],f32>
// CHECK: }
func.func @torch.aten.max_pool2d$full_dim_indivisible_by_stride_without_sliced_input(%arg0: !torch.vtensor<[1,1,112,112],f32>) -> !torch.vtensor<[1,1,56,56],f32> {
As I am still getting familiar with the codebase, do you know if there are any guidelines on when to add e2e tests in addition to LIT tests? IMO, adding at least one e2e test to lock down the numerics would be beneficial.