@@ -3797,13 +3797,126 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
       indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
     }
 
-    // Right now only support multiple indexes with same shape
-    // TODO for different shape multiple indexes, add broadcast_to for small
-    // shape
+    auto getRankExtendedShape =
+        [](SmallVector<int64_t> inputShape,
+           SmallVector<int64_t> maxRank1DimShape) -> SmallVector<int64_t> {
+      SmallVector<int64_t> rankExtendedShape(maxRank1DimShape);
+      auto inputRank = inputShape.size();
+      auto maxRank = maxRank1DimShape.size();
+      auto startIdx = maxRank - inputRank;
+      for (size_t i = startIdx; i < maxRank; i++) {
+        rankExtendedShape[i] = inputShape[i - startIdx];
+      }
+      return rankExtendedShape;
+    };
+
+    bool hasDiffShapedIndexes = false;
     for (auto indexShapeOneDim : indexesShape) {
       if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: Only support multi indexes with same shape");
+        hasDiffShapedIndexes = true;
+        break;
+      }
+    }
+
+    if (hasDiffShapedIndexes) {
+      int64_t maxRank = 1;
+      for (auto idxRank : indexesRank) {
+        if (idxRank > maxRank)
+          maxRank = idxRank;
+      }
+      // Tensor shape of max rank, each dim being 1
+      SmallVector<int64_t> maxRank1DimShape;
+      for (int i = 0; i < maxRank; i++)
+        maxRank1DimShape.push_back(1);
+      // Tensor shape of max rank, each dim being the max dim.
+      SmallVector<int64_t> maxRankMaxDimShape(maxRank1DimShape);
+
+      auto updateMaxRankMaxDimShape =
+          [&](SmallVector<int64_t> broadcastedShape) -> LogicalResult {
+        for (size_t i = 0; i < maxRankMaxDimShape.size(); i++) {
+          // check for malformed index tensors
+          if (broadcastedShape[i] != 1 && maxRankMaxDimShape[i] != 1 &&
+              maxRankMaxDimShape[i] != broadcastedShape[i]) {
+            return failure();
+          }
+          if (broadcastedShape[i] > maxRankMaxDimShape[i])
+            maxRankMaxDimShape[i] = broadcastedShape[i];
+        }
+        return success();
+      };
+
+      for (size_t i = 0; i < indexesRank.size(); i++) {
+        // Reshape all index tensors to same maxRank
+        auto idxRank = indexesRank[i];
+        auto unreshapedIdxTensor = indicesTfConcatTensors[i];
+        SmallVector<int64_t> broadcastedShape =
+            getRankExtendedShape(indexesShape[i], maxRank1DimShape);
+
+        if (idxRank < maxRank) {
+          auto idxType =
+              dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+          // indicesTfConcatTensors has a trailing [1] dim for the final concat.
+          auto broadcastedShapeTf(broadcastedShape);
+          broadcastedShapeTf.push_back(1);
+          auto reshapeOutputTy = RankedTensorType::get(
+              broadcastedShapeTf, idxType.getElementType());
+          // Update the tensor array with the max rank-extended form
+          indicesTfConcatTensors[i] = rewriter.create<tosa::ReshapeOp>(
+              op->getLoc(), reshapeOutputTy, unreshapedIdxTensor,
+              rewriter.getDenseI64ArrayAttr(broadcastedShapeTf));
+        }
+
+        // Construct the max rank broadcasted form of all index tensors with
+        // each index tensor.
+        if (updateMaxRankMaxDimShape(broadcastedShape).failed()) {
+          return rewriter.notifyMatchFailure(
+              op, "Malformed index tensors that have mismatched dim shapes");
+        }
+
+        // Every index now has the same rank but not yet same shape until
+        // tosa.tile below.
+        indexesShape[i] = broadcastedShape;
+        indexesRank[i] = maxRank;
+      }
+
+      auto getTileOpShape = [&](SmallVector<int64_t> indexShape,
+                                SmallVector<int64_t> &tileOpShape) -> bool {
+        bool needsTiling = false;
+        for (size_t i = 0; i < indexShape.size(); i++) {
+          if (1 == indexShape[i]) {
+            tileOpShape.push_back(maxRankMaxDimShape[i]);
+            needsTiling = true;
+          } else {
+            tileOpShape.push_back(1);
+          }
+        }
+        return needsTiling;
+      };
+
+      // Use tosa.tile to broadcast in multiple dims so all index tensors have
+      // the same shape. This materializes new tensors.
+      for (size_t i = 0; i < indexesRank.size(); i++) {
+        SmallVector<int64_t> tileOpShape;
+        bool needsTiling = getTileOpShape(indexesShape[i], tileOpShape);
+
+        if (needsTiling) {
+          auto idxType =
+              dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+          // indicesTfConcatTensors has a trailing [1] dim for the final concat.
+          auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
+          maxRankMaxDimShapeTf.push_back(1);
+          auto tileOpShapeTf(tileOpShape);
+          tileOpShapeTf.push_back(1);
+          auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
+                                                    idxType.getElementType());
+          auto reshapedIdxTensor = indicesTfConcatTensors[i];
+          indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
+              op->getLoc(), tileOutputTy, reshapedIdxTensor,
+              rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
+        }
+
+        // Every index tensor now has the same rank and shape
+        indexesShape[i] = maxRankMaxDimShape;
       }
     }
 
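A minimal standalone sketch of the shape logic this hunk adds, to make the broadcasting scheme easier to follow: rank-extend each index shape with leading 1s, accumulate the broadcast target shape, and derive the per-dimension tile multiples that the tosa.tile step would use. It uses plain std::vector instead of llvm::SmallVector, omits the MLIR op creation entirely, and the example shapes [3] and [2, 1] are made-up inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical example: two index tensors of shape [3] and [2, 1].
  std::vector<std::vector<int64_t>> indexesShape = {{3}, {2, 1}};

  size_t maxRank = 0;
  for (const auto &shape : indexesShape)
    maxRank = std::max(maxRank, shape.size());

  // Rank-extend: prepend 1s so every shape has maxRank dims.
  // [3] -> [1, 3], [2, 1] -> [2, 1].
  for (auto &shape : indexesShape)
    shape.insert(shape.begin(), maxRank - shape.size(), 1);

  // Accumulate the broadcast target shape dim by dim, rejecting
  // mismatched non-1 dims (the "malformed index tensors" case).
  std::vector<int64_t> targetShape(maxRank, 1);
  for (const auto &shape : indexesShape)
    for (size_t i = 0; i < maxRank; i++) {
      assert((shape[i] == 1 || targetShape[i] == 1 ||
              shape[i] == targetShape[i]) &&
             "malformed index tensors with mismatched dim shapes");
      targetShape[i] = std::max(targetShape[i], shape[i]);
    }

  // Tile multiples: repeat a size-1 dim targetShape[i] times, else 1.
  for (const auto &shape : indexesShape) {
    std::cout << "tile multiples:";
    for (size_t i = 0; i < maxRank; i++)
      std::cout << ' ' << (shape[i] == 1 ? targetShape[i] : 1);
    std::cout << '\n';
  }
  // Prints "tile multiples: 2 1" for [1, 3] and "tile multiples: 1 3"
  // for [2, 1]; both index tensors end up with shape [2, 3].
}
```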