
Commit 13ee7c2

[TOSA] Add legalization for torch.aten.unfold (llvm#3922)
* Add Torch to TOSA legalization for torch.aten.unfold
* Update e2e results in xfail_sets.py
* Fix a minor detail in one of the unfold e2e tests
* Add LIT tests for aten.unfold

Change-Id: I6583019d1c2569bdaf9f0b67cf44b33067448af7
Signed-off-by: Justin Ngo <[email protected]>
1 parent 2f8dbca commit 13ee7c2

File tree: 4 files changed, +251 -12 lines changed

* lib/Conversion/TorchToTosa/TorchToTosa.cpp
* projects/pt1/e2e_testing/xfail_sets.py
* projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
* test/Conversion/TorchToTosa/basic.mlir
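For context (not part of the commit), a minimal PyTorch sketch of the op being legalized: x.unfold(dim, size, step) slices dimension dim into windows of length size taken every step elements, and the elements of each window become a new trailing dimension. The shapes below match the new torch.aten.unfold$basic LIT test.

import torch

x = torch.arange(24, dtype=torch.float32).reshape(6, 4)
y = x.unfold(0, 2, 2)                 # dim=0, size=2, step=2, as in the LIT test
assert y.shape == (3, 4, 2)           # (6 - 2) // 2 + 1 = 3 windows of length 2
assert torch.equal(y[0, :, 1], x[1])  # element 1 of window 0 comes from row 1 of x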

lib/Conversion/TorchToTosa/TorchToTosa.cpp (+193)
@@ -8256,6 +8256,198 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
   return success();
 }

+// Legalization for aten.unfold
+template <>
+LogicalResult ConvertAtenOp<AtenUnfoldOp>::matchAndRewrite(
+    AtenUnfoldOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Approach: Use GatherOp to retrieve target elements from target dim and then
+  // reshape the output into slices according to the output shape
+  //
+  // Lowering steps:
+  // 1. Create PyTorch-style indices tensor corresponding to target elements and
+  //    reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1))
+  //    with d_x being the dimension size of the input at dim x.
+  //    The indices vector will be calculated using the following formula:
+  //      for i in range(d_0 * d_1 * ... * d_(target_dim - 1)):
+  //        for window in range(nWindows):
+  //          for elementIndex in range(size):
+  //            for j in range(d_(target_dim + 1) * ... * d_(rank-1)):
+  //              indices_vec.push_back(elementIndex + window * step)
+  // 2. Convert PyTorch-style indices and target dim to TensorFlow-style indices
+  // 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
+  //    target elements
+  // 4. Reshape result from above to correct output shape
+  auto self = adaptor.getSelf();
+
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  auto selfShape = selfType.getShape();
+  auto selfRank = selfType.getRank();
+  auto selfElemTy = selfType.getElementType();
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+  auto resultElemTy = resultType.getElementType();
+
+  int64_t dim;
+  if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Only constant int dims are supported");
+
+  int64_t size;
+  if (!matchPattern(op.getSize(), m_TorchConstantInt(&size)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Only constant int sizes are supported");
+
+  int64_t step;
+  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Only constant int steps are supported");
+
+  if (step <= 0)
+    return rewriter.notifyMatchFailure(op, "Step value must be greater than 0");
+
+  // Handle rank zero
+  if (selfRank == 0) {
+    if (dim != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported dim value for rank zero input");
+
+    if (size != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported size value for rank zero input");
+
+    auto result = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(), RankedTensorType::get({1}, selfElemTy), self,
+        rewriter.getDenseI64ArrayAttr({1}));
+
+    rewriter.replaceOp(op, {result.getResult()});
+    return success();
+  }
+
+  dim = toPositiveDim(dim, selfRank);
+  if (!isValidDim(dim, selfRank))
+    return rewriter.notifyMatchFailure(op, "Dim value is invalid");
+
+  // Size of dimension 'dim' in the returned tensor (or number of windows within
+  // the dimension that got sliced)
+  int64_t nWindows = (selfShape[dim] - size) / step + 1;
+
+  // Find number of times that each base index value gets repeated for target
+  // dim based on dim values before and after target dim i.e. preDimAccumulate =
+  // d_0 * d_1 * ... * d_(target_dim - 1)
+  // postDimAccumulate = d_(target_dim + 1) * ... * d_(rank - 1)
+  int64_t preDimAccumulate =
+      std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1,
+                      std::multiplies<int64_t>());
+  int64_t postDimAccumulate =
+      std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1,
+                      std::multiplies<int64_t>());
+
+  // Calculate PyTorch-style gather indices vector
+  // Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1
+  // -> preDimAccumulate = 2, postDimAccumulate = 3, nWindows = 2
+  // pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                          1, 1, 1, 2, 2, 2, 3, 3, 3]
+  // pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                      1, 1, 1, 2, 2, 2, 3, 3, 3,
+  //                      0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                      1, 1, 1, 2, 2, 2, 3, 3, 3]
+  SmallVector<int32_t> pyTorchIndicesBaseVec;
+  SmallVector<int32_t> pyTorchIndicesVec;
+
+  for (int64_t window = 0; window < nWindows; window++) {
+    for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) {
+      int32_t baseIndex = static_cast<int32_t>(elementIndex + window * step);
+      for (int64_t i = 0; i < postDimAccumulate; i++)
+        pyTorchIndicesBaseVec.push_back(baseIndex);
+    }
+  }
+
+  for (int64_t i = 0; i < preDimAccumulate; i++)
+    pyTorchIndicesVec.insert(pyTorchIndicesVec.end(),
+                             pyTorchIndicesBaseVec.begin(),
+                             pyTorchIndicesBaseVec.end());
+
+  // Create the PyTorch-style indices tensor
+  // Continuing with the previous example:
+  // pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3)
+  // pyTorchIndices = tensor([[[0, 0, 0],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [3, 3, 3]],
+  //                          [[0, 0, 0],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [3, 3, 3]]])
+  SmallVector<int64_t> pyTorchIndicesShape(selfShape);
+  pyTorchIndicesShape[dim] = nWindows * size;
+  auto pyTorchIndices =
+      tosa::getConstTensor<int32_t>(rewriter, op, pyTorchIndicesVec,
+                                    pyTorchIndicesShape)
+          .value();
+
+  // Convert PyTorch-style indices to TensorFlow-style indices
+  auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self,
+                                                      pyTorchIndices, dim);
+  if (!tfIndices)
+    return rewriter.notifyMatchFailure(op,
+                                       "Convert PyTorch-style indices and dim "
+                                       "to TensorFlow-style indices failed");
+
+  // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
+  // target elements
+  auto gatherNdOp = tosa::convertGatherNdOp(
+      rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy),
+      self, tfIndices.value());
+  if (!gatherNdOp)
+    return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
+
+  // Reshape to an intermediary shape where the gathered elements in dimension
+  // 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size'
+  SmallVector<int64_t> intermediaryShape;
+  for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) {
+    if (currentDim == dim) {
+      intermediaryShape.push_back(nWindows);
+      intermediaryShape.push_back(size);
+    } else {
+      intermediaryShape.push_back(pyTorchIndicesShape[currentDim]);
+    }
+  }
+
+  auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy),
+      gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape));
+
+  // Permute dims to the correct result order
+  SmallVector<int32_t> permutedDims;
+  for (int64_t currentDim = 0; currentDim < selfRank + 1; currentDim++) {
+    if (currentDim != dim + 1)
+      permutedDims.push_back(static_cast<int32_t>(currentDim));
+  }
+  permutedDims.push_back(static_cast<int32_t>(dim + 1));
+
+  auto permutedDimsConst = tosa::getConstTensor<int32_t>(
+                               rewriter, op,
+                               /*vec=*/permutedDims,
+                               /*shape=*/{static_cast<int32_t>(selfRank + 1)})
+                               .value();
+
+  auto result = rewriter.create<tosa::TransposeOp>(
+      op->getLoc(), resultType, reshapeOp.getResult(), permutedDimsConst);
+
+  rewriter.replaceOp(op, {result.getResult()});
+
+  return success();
+}
+
 } // namespace

 // -----------------------------------------------------------------------------
@@ -8617,6 +8809,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenLog1pOp);
   INSERT_ATENOP_PATTERN(AtenLog10Op);
   INSERT_ATENOP_PATTERN(AtenTanOp);
+  INSERT_ATENOP_PATTERN(AtenUnfoldOp);
 #undef INSERT_ATENOP_PATTERN

 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
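To make the lowering steps above easier to follow, here is a hedged Python model of the same index-build / gather / reshape / transpose sequence. It is an illustration only: torch.gather and expand stand in for the explicit pre/post index repetition and the gather_nd path used in the C++ pattern, and unfold_via_gather is a hypothetical helper name. The model is checked against PyTorch's own unfold.

import torch

def unfold_via_gather(x, dim, size, step):
    shape = list(x.shape)
    n_windows = (shape[dim] - size) // step + 1
    # Step 1: PyTorch-style indices along `dim`: window-by-window element offsets
    base = torch.tensor([e + w * step for w in range(n_windows) for e in range(size)])
    idx_shape = [1] * len(shape)
    idx_shape[dim] = n_windows * size
    indices = base.view(idx_shape).expand(
        [n_windows * size if d == dim else s for d, s in enumerate(shape)])
    # Steps 2-3: gather the target elements along `dim`
    gathered = torch.gather(x, dim, indices)
    # Step 4a: split `dim` back into (nWindows, size)
    inter_shape = shape[:dim] + [n_windows, size] + shape[dim + 1:]
    gathered = gathered.reshape(inter_shape)
    # Step 4b: move the within-window dimension to the end
    perm = [d for d in range(len(inter_shape)) if d != dim + 1] + [dim + 1]
    return gathered.permute(perm)

# Same example as the comments above: shape (2, 4, 3), dim = 1, size = 3, step = 1
x = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
assert torch.equal(unfold_via_gather(x, 1, 3, 1), x.unfold(1, 3, 1))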

projects/pt1/e2e_testing/xfail_sets.py (+7 -11)
@@ -1698,6 +1698,8 @@
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleSumdims_basic",
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
+    "CrossEntropyLossModule_basic",
+    "CrossEntropyLossNoReductionModule_basic",
     "ScatterSrcModule_basic",
     "ScatterSrcStaticModule_basic",
     "HBC_basic",
@@ -1706,6 +1708,9 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "Unfold_Module_Rank_4",
+    "Unfold_Module_Rank_Zero_basic",
+    "Unfold_Module_basic",
     "ElementwiseErfIntModule_basic",
     "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseSigmoidIntModule_basic",
@@ -3441,6 +3446,8 @@
 }

 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "UniformModule_basic",
+    "UniformStaticShapeModule_basic",
     "AtenFftRfft2DLastDim_basic",
     "AtenFftRfft2DMiddleDim_basic",
     "IsInfiniteModule_basic",
@@ -3460,11 +3467,7 @@
     "MaxPool3dModule_basic",
     "MaxPool3dStaticModule_basic",
     "ViewDtypeStaticModule_basic",
-    "Unfold_Module_Dynamic_basic",
-    "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_Size_Zero_basic",
-    "Unfold_Module_Rank_Zero_basic",
-    "Unfold_Module_basic",
     "ArangeZeroElementOutputModule_basic",
     "NumpyTRank0Module_basic",
     "Permute0RankModule_basic",
@@ -3888,17 +3891,10 @@
     "AdaptiveAvgPool2dDynamic_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseRreluTrainModule_basic",
-    "ElementwiseRreluTrainStaticModule_basic",
-    "IndexPutImpl1DFloatNonAccumulateModule_basic",
-    "IndexPutImpl1DIntNonAccumulateModule_basic",
-    "IndexPutImpl2DFloatNonAccumulateModule_basic",
-    "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IouOfModule_basic",
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
-    "OneHotModule_basic",
     "ReduceFrobeniusNormKeepDimModule_basic",
     "ReduceFrobeniusNormModule_basic",
     "ScaledDotProductAttentionBoolMaskModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py (+1 -1)
@@ -1752,7 +1752,7 @@ def forward(self, x):
         return x.unfold(0, 0, 1)


-@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
+@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero())
 def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
     module.forward(tu.rand())

test/Conversion/TorchToTosa/basic.mlir (+50)
@@ -2943,3 +2943,53 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s
 }

 // -----
+
+// CHECK-LABEL: func.func @torch.aten.unfold$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,4],f32> -> tensor<6x4xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]]> : tensor<6x4xi32>}> : () -> tensor<6x4xi32>
+// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 6, 4, 1>} : (tensor<6x4xi32>) -> tensor<6x4x1xi32>
+// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32>
+// CHECK: %[[VAL_7:.*]] = tosa.concat %[[VAL_5]], %[[VAL_6]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32>
+// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 24, 1>} : (tensor<6x4xf32>) -> tensor<1x24x1xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 24, 2>} : (tensor<6x4x2xi32>) -> tensor<24x2xi32>
+// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_9]], %[[VAL_10]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<2xi32>) -> tensor<24x2xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32>
+// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
+// CHECK: %[[VAL_14:.*]] = tosa.gather %[[VAL_8]], %[[VAL_13]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32>
+// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 6, 4>} : (tensor<1x24x1xf32>) -> tensor<6x4xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 3, 2, 4>} : (tensor<6x4xf32>) -> tensor<3x2x4xf32>
+// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_17]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32>
+// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32>
+// CHECK: return %[[VAL_19]] : !torch.vtensor<[3,4,2],f32>
+// CHECK: }
+func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> {
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.unfold %arg0, %int0, %int2, %int2 : !torch.vtensor<[6,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4,2],f32>
+  return %0 : !torch.vtensor<[3,4,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unfold$rank_zero(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[1],f32>
+// CHECK: }
+func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.unfold %arg0, %int0, %int1, %int1 : !torch.vtensor<[],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+  return %0 : !torch.vtensor<[1],f32>
+}
+
+// -----
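As a quick cross-check of the result types asserted in the two LIT tests above, an eager-mode PyTorch sketch (not part of the diff):

import torch

assert torch.rand(6, 4).unfold(0, 2, 2).shape == (3, 4, 2)  # torch.aten.unfold$basic
assert torch.rand(()).unfold(0, 1, 1).shape == (1,)         # torch.aten.unfold$rank_zero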
