Commit d3efab9

[TOSA] Fix Tensor.hacked_twin to support diff size indexes (llvm#3547)
- Broadcasts index list tensors
- Adds torch.nn.Unfold test

Signed-off-by: Suraj Sudhir <[email protected]>
1 parent 8bd1b97 commit d3efab9
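
For context, "diff size indexes" refers to aten.index.Tensor_hacked_twin receiving index tensors whose shapes differ and must be broadcast against each other, as in PyTorch advanced indexing. A minimal eager-mode illustration of the case the TOSA path previously rejected (the shapes and values below are made up for this note, not taken from the change):

import torch

x = torch.arange(12.0).reshape(3, 4)
rows = torch.tensor([[0], [2]])  # shape (2, 1)
cols = torch.tensor([1, 3])      # shape (2,)

# The two index tensors broadcast to a common shape (2, 2); the lowering
# must produce a gather whose result has that broadcast shape.
y = x[rows, cols]
print(y.shape)  # torch.Size([2, 2])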

File tree: 3 files changed, +146 −6 lines
lib/Conversion/TorchToTosa/TorchToTosa.cpp

+118 −5
@@ -3797,13 +3797,126 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
       indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
     }
 
-    // Right now only support multiple indexes with same shape
-    // TODO for different shape multiple indexes, add broadcast_to for small
-    // shape
+    auto getRankExtendedShape =
+        [](SmallVector<int64_t> inputShape,
+           SmallVector<int64_t> maxRank1DimShape) -> SmallVector<int64_t> {
+      SmallVector<int64_t> rankExtendedShape(maxRank1DimShape);
+      auto inputRank = inputShape.size();
+      auto maxRank = maxRank1DimShape.size();
+      auto startIdx = maxRank - inputRank;
+      for (size_t i = startIdx; i < maxRank; i++) {
+        rankExtendedShape[i] = inputShape[i - startIdx];
+      }
+      return rankExtendedShape;
+    };
+
+    bool hasDiffShapedIndexes = false;
     for (auto indexShapeOneDim : indexesShape) {
       if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: Only support multi indexes with same shape");
+        hasDiffShapedIndexes = true;
+        break;
+      }
+    }
+
+    if (hasDiffShapedIndexes) {
+      int64_t maxRank = 1;
+      for (auto idxRank : indexesRank) {
+        if (idxRank > maxRank)
+          maxRank = idxRank;
+      }
+      // Tensor shape of max rank, each dim being 1
+      SmallVector<int64_t> maxRank1DimShape;
+      for (int i = 0; i < maxRank; i++)
+        maxRank1DimShape.push_back(1);
+      // Tensor shape of max rank, each dim being the max dim.
+      SmallVector<int64_t> maxRankMaxDimShape(maxRank1DimShape);
+
+      auto updateMaxRankMaxDimShape =
+          [&](SmallVector<int64_t> broadcastedShape) -> LogicalResult {
+        for (size_t i = 0; i < maxRankMaxDimShape.size(); i++) {
+          // check for malformed index tensors
+          if (broadcastedShape[i] != 1 && maxRankMaxDimShape[i] != 1 &&
+              maxRankMaxDimShape[i] != broadcastedShape[i]) {
+            return failure();
+          }
+          if (broadcastedShape[i] > maxRankMaxDimShape[i])
+            maxRankMaxDimShape[i] = broadcastedShape[i];
+        }
+        return success();
+      };
+
+      for (size_t i = 0; i < indexesRank.size(); i++) {
+        // Reshape all index tensors to same maxRank
+        auto idxRank = indexesRank[i];
+        auto unreshapedIdxTensor = indicesTfConcatTensors[i];
+        SmallVector<int64_t> broadcastedShape =
+            getRankExtendedShape(indexesShape[i], maxRank1DimShape);
+
+        if (idxRank < maxRank) {
+          auto idxType =
+              dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+          // indicesTfConcatTensors has a trailing [1] dim for the final concat.
+          auto broadcastedShapeTf(broadcastedShape);
+          broadcastedShapeTf.push_back(1);
+          auto reshapeOutputTy = RankedTensorType::get(
+              broadcastedShapeTf, idxType.getElementType());
+          // Update the tensor array with the max rank-extended form
+          indicesTfConcatTensors[i] = rewriter.create<tosa::ReshapeOp>(
+              op->getLoc(), reshapeOutputTy, unreshapedIdxTensor,
+              rewriter.getDenseI64ArrayAttr(broadcastedShapeTf));
+        }
+
+        // Construct the max rank broadcasted form of all index tensors with
+        // each index tensor.
+        if (updateMaxRankMaxDimShape(broadcastedShape).failed()) {
+          return rewriter.notifyMatchFailure(
+              op, "Malformed index tensors that have mismatched dim shapes");
+        }
+
+        // Every index now has the same rank but not yet same shape until
+        // tosa.tile below.
+        indexesShape[i] = broadcastedShape;
+        indexesRank[i] = maxRank;
+      }
+
+      auto getTileOpShape = [&](SmallVector<int64_t> indexShape,
+                                SmallVector<int64_t> &tileOpShape) -> bool {
+        bool needsTiling = false;
+        for (size_t i = 0; i < indexShape.size(); i++) {
+          if (1 == indexShape[i]) {
+            tileOpShape.push_back(maxRankMaxDimShape[i]);
+            needsTiling = true;
+          } else {
+            tileOpShape.push_back(1);
+          }
+        }
+        return needsTiling;
+      };
+
+      // Use tosa.tile to broadcast in multiple dims so all index tensors have
+      // the same shape. This materializes new tensors.
+      for (size_t i = 0; i < indexesRank.size(); i++) {
+        SmallVector<int64_t> tileOpShape;
+        bool needsTiling = getTileOpShape(indexesShape[i], tileOpShape);
+
+        if (needsTiling) {
+          auto idxType =
+              dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+          // indicesTfConcatTensors has a trailing [1] dim for the final concat.
+          auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
+          maxRankMaxDimShapeTf.push_back(1);
+          auto tileOpShapeTf(tileOpShape);
+          tileOpShapeTf.push_back(1);
+          auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
+                                                    idxType.getElementType());
+          auto reshapedIdxTensor = indicesTfConcatTensors[i];
+          indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
+              op->getLoc(), tileOutputTy, reshapedIdxTensor,
+              rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
+        }
+
+        // Every index tensor now has the same rank and shape
+        indexesShape[i] = maxRankMaxDimShape;
       }
     }
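
The shape bookkeeping in the new code amounts to: rank-extend every index shape with leading 1s to the maximum rank, fold the per-dimension maximum into a common target shape (rejecting mismatched non-unit dims), and tile any remaining unit dims up to that target. A rough Python sketch of the same arithmetic, written only to illustrate the approach (the helper names below are invented for this note, not part of the commit):

def rank_extend(shape, max_rank):
    # Prepend 1s so every index shape reaches the maximum rank.
    return [1] * (max_rank - len(shape)) + list(shape)

def broadcast_shapes(shapes):
    max_rank = max(len(s) for s in shapes)
    extended = [rank_extend(s, max_rank) for s in shapes]
    target = [1] * max_rank
    for s in extended:
        for i, d in enumerate(s):
            # Mirrors updateMaxRankMaxDimShape's malformed-index check.
            if d != 1 and target[i] != 1 and target[i] != d:
                raise ValueError("mismatched index tensor dims")
            target[i] = max(target[i], d)
    return extended, target

def tile_multiples(extended, target):
    # Mirrors getTileOpShape: unit dims are tiled up to the target dim.
    return [t if e == 1 else 1 for e, t in zip(extended, target)]

# Index shapes [2, 1] and [2] broadcast to a common [2, 2]:
ext, target = broadcast_shapes([[2, 1], [2]])
print(target)                                    # [2, 2]
print([tile_multiples(e, target) for e in ext])  # [[1, 2], [2, 1]]

The C++ above performs the same per-dimension work with getRankExtendedShape, updateMaxRankMaxDimShape, and getTileOpShape, then materializes the results with tosa.reshape and tosa.tile on each index tensor.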

projects/pt1/e2e_testing/xfail_sets.py

+4 −1
@@ -30,6 +30,7 @@
     # this is added to check the torch.onnx.export -> import_onnx -> torch path
     "DeformConv2D_basic",
     "ReduceAnyDimFloatModule_basic",
+    "UnfoldModule_basic",
 }
 
 LINALG_CRASHING_SET = {
@@ -1983,6 +1984,8 @@
     "TorchPrimLoopForLikeTensorArgModule_basic",
     "RenormModuleFloat32NegativeDim_basic",
     "RenormModuleFloat32_basic",
+    "IndexTensorStaticContiguousWithNoneModule_basic",
+    "IndexTensorStaticNonContiguousWithNoneModule_basic",
 }
 
 MAKE_FX_TOSA_PASS_SET = (
@@ -2750,6 +2753,7 @@
     "ReduceAnyFloatModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
+    "UnfoldModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3189,7 +3193,6 @@
     "IndexSelectWholeTensorModule_basic",
     "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
     "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
-    "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "IndexTensorMultiInputContiguousCenter_basic",
     "IndexTensorMultiInputContiguousOneDimDynamic_basic",
     "IndexTensorMultiInputNonContiguousDynamic_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

+24 −0
@@ -5646,3 +5646,27 @@ def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
     module.forward(
         torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
     )
+
+
+# ==============================================================================
+
+
+class UnfoldModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.unfold = torch.nn.Unfold(kernel_size=(2, 3))
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, input):
+        return self.unfold(input)
+
+
+@register_test_case(module_factory=lambda: UnfoldModule())
+def UnfoldModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 3, 4))
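
As a sanity check of what the new test exercises (run in eager PyTorch, outside the e2e harness, which compares the lowered module against eager execution like this): nn.Unfold with kernel_size=(2, 3), unit stride and no padding turns a (2, 5, 3, 4) input into patches of shape (2, 5*2*3, (3-2+1)*(4-3+1)) = (2, 30, 4).

import torch

# Eager-mode shape check only.
out = torch.nn.Unfold(kernel_size=(2, 3))(torch.rand(2, 5, 3, 4))
print(out.shape)  # torch.Size([2, 30, 4])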
