@@ -8256,6 +8256,198 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
   return success();
 }

+// Legalization for aten.unfold
+template <>
+LogicalResult ConvertAtenOp<AtenUnfoldOp>::matchAndRewrite(
+    AtenUnfoldOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Approach: Use GatherOp to retrieve target elements from target dim and then
+  // reshape the output into slices according to the output shape
+  //
+  // Lowering steps:
+  // 1. Create PyTorch-style indices tensor corresponding to target elements and
+  //    reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1))
+  //    with d_x being the dimension size of the input at dim x.
+  //    The indices vector will be calculated using the following formula:
+  //      for i in range(d_0 * d_1 * ... * d_(target_dim - 1)):
+  //        for window in range(nWindows):
+  //          for elementIndex in range(size):
+  //            for j in range(d_(target_dim + 1) * ... * d_(rank - 1)):
+  //              indices_vec.push_back(elementIndex + window * step)
+  // 2. Convert PyTorch-style indices and target dim to TensorFlow-style indices
+  // 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
+  //    target elements
+  // 4. Reshape result from above to correct output shape
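+  //
+  // For reference (illustrative note, not exhaustive): aten.unfold returns all
+  // slices of length 'size' taken every 'step' elements along 'dim', and
+  // appends an extra dimension of length 'size' at the end of the result.
+  // E.g. unfolding an input of shape (2, 4, 3) at dim = 1 with size = 3,
+  // step = 1 yields a result of shape (2, 2, 3, 3).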
+  auto self = adaptor.getSelf();
+
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  auto selfShape = selfType.getShape();
+  auto selfRank = selfType.getRank();
+  auto selfElemTy = selfType.getElementType();
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+  auto resultElemTy = resultType.getElementType();
+
+  int64_t dim;
+  if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Only constant int dims are supported");
+
+  int64_t size;
+  if (!matchPattern(op.getSize(), m_TorchConstantInt(&size)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Only constant int sizes are supported");
+
+  int64_t step;
+  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Only constant int steps are supported");
+
+  if (step <= 0)
+    return rewriter.notifyMatchFailure(op, "Step value must be greater than 0");
+
+  // Handle rank zero
+  if (selfRank == 0) {
+    if (dim != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported dim value for rank zero input");
+
+    if (size != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported size value for rank zero input");
+
+    auto result = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(), RankedTensorType::get({1}, selfElemTy), self,
+        rewriter.getDenseI64ArrayAttr({1}));
+
+    rewriter.replaceOp(op, {result.getResult()});
+    return success();
+  }
+
+  dim = toPositiveDim(dim, selfRank);
+  if (!isValidDim(dim, selfRank))
+    return rewriter.notifyMatchFailure(op, "Dim value is invalid");
+
+  // Size of dimension 'dim' in the returned tensor (or number of windows within
+  // the dimension that got sliced)
+  int64_t nWindows = (selfShape[dim] - size) / step + 1;
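+  // e.g. (illustrative) selfShape[dim] = 4, size = 3, step = 1
+  //      -> nWindows = (4 - 3) / 1 + 1 = 2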
+
+  // Find the number of times that each base index value gets repeated for the
+  // target dim, based on the dimension sizes before and after the target dim:
+  //   preDimAccumulate  = d_0 * d_1 * ... * d_(target_dim - 1)
+  //   postDimAccumulate = d_(target_dim + 1) * ... * d_(rank - 1)
+  int64_t preDimAccumulate =
+      std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1,
+                      std::multiplies<int64_t>());
+  int64_t postDimAccumulate =
+      std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1,
+                      std::multiplies<int64_t>());
+
+  // Calculate PyTorch-style gather indices vector
+  // Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1
+  // -> preDimAccumulate = 2, postDimAccumulate = 3, nWindows = 2
+  // pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                          1, 1, 1, 2, 2, 2, 3, 3, 3]
+  // pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                      1, 1, 1, 2, 2, 2, 3, 3, 3,
+  //                      0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                      1, 1, 1, 2, 2, 2, 3, 3, 3]
+  SmallVector<int32_t> pyTorchIndicesBaseVec;
+  SmallVector<int32_t> pyTorchIndicesVec;
+
+  for (int64_t window = 0; window < nWindows; window++) {
+    for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) {
+      int32_t baseIndex = static_cast<int32_t>(elementIndex + window * step);
+      for (int64_t i = 0; i < postDimAccumulate; i++)
+        pyTorchIndicesBaseVec.push_back(baseIndex);
+    }
+  }
+
+  for (int64_t i = 0; i < preDimAccumulate; i++)
+    pyTorchIndicesVec.insert(pyTorchIndicesVec.end(),
+                             pyTorchIndicesBaseVec.begin(),
+                             pyTorchIndicesBaseVec.end());
+
+  // Create the PyTorch-style indices tensor
+  // Continuing with the previous example:
+  // pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3)
+  // pyTorchIndices = tensor([[[0, 0, 0],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [3, 3, 3]],
+  //                          [[0, 0, 0],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [3, 3, 3]]])
+  SmallVector<int64_t> pyTorchIndicesShape(selfShape);
+  pyTorchIndicesShape[dim] = nWindows * size;
+  auto pyTorchIndices =
+      tosa::getConstTensor<int32_t>(rewriter, op, pyTorchIndicesVec,
+                                    pyTorchIndicesShape)
+          .value();
+
+  // Convert PyTorch-style indices to TensorFlow-style indices
+  auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self,
+                                                      pyTorchIndices, dim);
+  if (!tfIndices)
+    return rewriter.notifyMatchFailure(op,
+                                       "Convert PyTorch-style indices and dim "
+                                       "to TensorFlow-style indices failed");
+
+  // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
+  // target elements
+  auto gatherNdOp = tosa::convertGatherNdOp(
+      rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy),
+      self, tfIndices.value());
+  if (!gatherNdOp)
+    return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
+
+  // Reshape to an intermediary shape where the gathered elements in dimension
+  // 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size'
+  SmallVector<int64_t> intermediaryShape;
+  for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) {
+    if (currentDim == dim) {
+      intermediaryShape.push_back(nWindows);
+      intermediaryShape.push_back(size);
+    } else {
+      intermediaryShape.push_back(pyTorchIndicesShape[currentDim]);
+    }
+  }
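+  // Continuing the example: (2, 6, 3) becomes
+  // intermediaryShape = (2, nWindows, size, 3) = (2, 2, 3, 3).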
+
+  auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy),
+      gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape));
+
+  // Permute dims to the correct result order
+  SmallVector<int32_t> permutedDims;
+  for (int64_t currentDim = 0; currentDim < selfRank + 1; currentDim++) {
+    if (currentDim != dim + 1)
+      permutedDims.push_back(static_cast<int32_t>(currentDim));
+  }
+  permutedDims.push_back(static_cast<int32_t>(dim + 1));
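+  // Continuing the example (dim = 1, selfRank = 3): permutedDims = [0, 1, 3, 2],
+  // which moves the 'size' dimension to the end so the final result has shape
+  // (2, nWindows, 3, size) = (2, 2, 3, 3), matching aten.unfold's output layout.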
+
+  auto permutedDimsConst = tosa::getConstTensor<int32_t>(
+                               rewriter, op,
+                               /*vec=*/permutedDims,
+                               /*shape=*/{static_cast<int32_t>(selfRank + 1)})
+                               .value();
+
+  auto result = rewriter.create<tosa::TransposeOp>(
+      op->getLoc(), resultType, reshapeOp.getResult(), permutedDimsConst);
+
+  rewriter.replaceOp(op, {result.getResult()});
+
+  return success();
+}
+
 } // namespace

 // -----------------------------------------------------------------------------
@@ -8617,6 +8809,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenLog1pOp);
   INSERT_ATENOP_PATTERN(AtenLog10Op);
   INSERT_ATENOP_PATTERN(AtenTanOp);
+  INSERT_ATENOP_PATTERN(AtenUnfoldOp);
 #undef INSERT_ATENOP_PATTERN

 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
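As a sanity check on the index-generation formula in the lowering comments above, here is a minimal standalone C++ sketch (illustrative only, not part of the patch; plain std::vector stands in for SmallVector) that reproduces pyTorchIndicesVec for the running example of shape (2, 4, 3) with dim = 1, size = 3, step = 1:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Values precomputed for shape (2, 4, 3), dim = 1, size = 3, step = 1.
  const int64_t preDimAccumulate = 2;  // d_0
  const int64_t postDimAccumulate = 3; // d_2
  const int64_t nWindows = 2;          // (4 - 3) / 1 + 1
  const int64_t size = 3, step = 1;

  // Same triple loop as in the pattern above.
  std::vector<int32_t> base;
  for (int64_t window = 0; window < nWindows; ++window)
    for (int64_t elementIndex = 0; elementIndex < size; ++elementIndex)
      for (int64_t i = 0; i < postDimAccumulate; ++i)
        base.push_back(static_cast<int32_t>(elementIndex + window * step));

  // Repeat the base vector once per leading index.
  std::vector<int32_t> indices;
  for (int64_t i = 0; i < preDimAccumulate; ++i)
    indices.insert(indices.end(), base.begin(), base.end());

  // Prints the 36 values listed in the pyTorchIndicesVec comment.
  for (int32_t v : indices)
    std::cout << v << ' ';
  std::cout << '\n';
}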