From b02683dd3e790a3f1457e9481dd8d0b87298b00e Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Thu, 13 Feb 2025 12:15:53 +0000 Subject: [PATCH] Use sizeBounds instead of domainSizes --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 11 ++--- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 9 ++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 43 ++++++++++--------- .../SCF/Transforms/TileUsingInterface.cpp | 3 +- .../transform-op-fuse-into-containing.mlir | 3 +- 5 files changed, 38 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 1fdc55d7e58c9..1e4f3004dec7e 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -143,11 +143,12 @@ struct SliceParameters { /// /// `omitPartialTileCheck` controls whether to omit the partial/boundary tile /// condition check in cases where we statically know that it is unnecessary. -SliceParameters computeSliceParameters( - OpBuilder &builder, Location loc, Value valueToTile, - ArrayRef tileSizes, AffineMap map, ArrayRef lbs, - ArrayRef ubs, ArrayRef subShapeSizes, - bool omitPartialTileCheck, ArrayRef domainSizes = {}); +SliceParameters +computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, + ArrayRef tileSizes, AffineMap map, + ArrayRef lbs, ArrayRef ubs, + ArrayRef subShapeSizes, + bool omitPartialTileCheck); /// Computes SliceParamaters for all `valuesToTile` of the given `linalgOp`, /// assuming `linalgOp` is being fused into a loop nest. Calls diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index f86715a94b268..8044405645d44 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -115,13 +115,16 @@ struct LinalgOpTilingInterface getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { - // Leave the `sizeBounds` value empty. That is only needed when the `sizes` - // specified could lead to out of bounds accesses. Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); + SmallVector allShapeSizes = + linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc()); + SmallVector sizeBounds = + mlir::affine::makeComposedFoldedMultiResultAffineApply( + b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes); SmallVector valuesToTile = linalgOp->getOperands(); SmallVector tiledOperands = makeTiledShapes( - b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); + b, loc, linalgOp, valuesToTile, offsets, sizes, sizeBounds, true); SmallVector generatedSlices = llvm::map_to_vector( llvm::make_filter_range( tiledOperands, diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index c185e0831d9bc..8e898904d87c2 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -56,8 +56,8 @@ namespace { // `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0] // struct TileCheck : public AffineExprVisitor { - TileCheck(ArrayRef tileSizes, ArrayRef domainSizes) - : tileSizes(tileSizes), domainSizes(domainSizes) {} + TileCheck(ArrayRef tileSizes, ArrayRef sizeBounds) + : tileSizes(tileSizes), sizeBounds(sizeBounds) {} void visitDimExpr(AffineDimExpr expr) { unsigned pos = expr.getPosition(); @@ -65,8 +65,9 @@ struct TileCheck : public AffineExprVisitor { // This dimension is tiled if the tile size is larger than zero and not // equal to its domain size (if statically known). std::optional tileSize = getConstantIntValue(tileSizes[pos]); - if (tileSize && !domainSizes.empty()) { - if (domainSizes[pos] == *tileSize) { + if (tileSize && !sizeBounds.empty()) { + std::optional sizeBound = getConstantIntValue(sizeBounds[pos]); + if (sizeBound && *sizeBound == *tileSize) { return; } } @@ -82,27 +83,27 @@ struct TileCheck : public AffineExprVisitor { } bool isTiled = false; ArrayRef tileSizes; - ArrayRef domainSizes; + ArrayRef sizeBounds; }; } // namespace static bool isTiled(AffineExpr expr, ArrayRef tileSizes, - ArrayRef domainSizes) { + ArrayRef sizeBounds) { if (!expr) return false; - TileCheck t(tileSizes, domainSizes); + TileCheck t(tileSizes, sizeBounds); t.visit(expr); return t.isTiled; } // Checks whether the `map varies with respect to a non-zero `tileSize`. static bool isTiled(AffineMap map, ArrayRef tileSizes, - ArrayRef domainSizes) { + ArrayRef sizeBounds) { if (!map) return false; for (unsigned r = 0; r < map.getNumResults(); ++r) - if (isTiled(map.getResult(r), tileSizes, domainSizes)) + if (isTiled(map.getResult(r), tileSizes, sizeBounds)) return true; return false; } @@ -571,19 +572,19 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef lbs, ArrayRef ubs, ArrayRef subShapeSizes, - bool omitPartialTileCheck, - ArrayRef domainSizes) { - SliceParameters sliceParams = computeSliceParameters( - builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes, - omitPartialTileCheck, domainSizes); + bool omitPartialTileCheck) { + SliceParameters sliceParams = + computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs, + ubs, subShapeSizes, omitPartialTileCheck); return materializeTiledShape(builder, loc, valueToTile, sliceParams); } -SliceParameters computeSliceParameters( - OpBuilder &builder, Location loc, Value valueToTile, - ArrayRef tileSizes, AffineMap map, ArrayRef lbs, - ArrayRef ubs, ArrayRef subShapeSizes, - bool omitPartialTileCheck, ArrayRef domainSizes) { +SliceParameters +computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, + ArrayRef tileSizes, AffineMap map, + ArrayRef lbs, ArrayRef ubs, + ArrayRef subShapeSizes, + bool omitPartialTileCheck) { auto shapedType = dyn_cast(valueToTile.getType()); assert(shapedType && "only shaped types can be tiled"); ArrayRef shape = shapedType.getShape(); @@ -600,7 +601,7 @@ SliceParameters computeSliceParameters( // The offset & size computation below only handles the case when // the map is monotonically increasing, i.e. the min and max values are // attained at the lower and upper bounds of the iteration domain. - if (!isTiled(m, tileSizes, domainSizes)) { + if (!isTiled(m, tileSizes, ubs)) { sliceParams.offsets.push_back(builder.getIndexAttr(0)); OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r); sliceParams.sizes.push_back(dim); @@ -811,7 +812,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, allSliceParams.push_back(computeSliceParameters( builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes, - omitPartialTileCheck, linalgOp.getStaticLoopRanges())); + omitPartialTileCheck)); } return allSliceParams; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 47a9462eb3a05..846c2064d87b4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -984,7 +984,8 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, // 4a. Clone the operation. { - auto clonedOp = cast(rewriter.clone(*op)); + auto clonedOp = cast( + cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); // 4b. Tile the cloned operation. FailureOr partialTilingResult = diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index 2cea815ac2b04..a677079cec0cf 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -555,11 +555,12 @@ module { // CHECK: %[[T1:.*]] = linalg.generic {{.*}} // CHECK: %[[T2:.*]] = linalg.generic {{.*}} + // CHECK: %[[T3:.*]] = linalg.generic {{.*}} %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor to tensor %8 = linalg.elemwise_unary ins(%7 : tensor) outs(%6 : tensor) -> tensor scf.forall.in_parallel { - // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor into tensor + // CHECK: tensor.parallel_insert_slice %[[T3]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor into tensor tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor into tensor } }