Skip to content

Commit

Permalink
Use sizeBounds instead of domainSizes
Browse files Browse the repository at this point in the history
  • Loading branch information
josel-amd committed Feb 13, 2025
1 parent 6d0b806 commit b02683d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 31 deletions.
11 changes: 6 additions & 5 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,12 @@ struct SliceParameters {
///
/// `omitPartialTileCheck` controls whether to omit the partial/boundary tile
/// condition check in cases where we statically know that it is unnecessary.
SliceParameters computeSliceParameters(
OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck, ArrayRef<int64_t> domainSizes = {});
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck);

/// Computes SliceParamaters for all `valuesToTile` of the given `linalgOp`,
/// assuming `linalgOp` is being fused into a loop nest. Calls
Expand Down
9 changes: 6 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,16 @@ struct LinalgOpTilingInterface
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
// Leave the `sizeBounds` value empty. That is only needed when the `sizes`
// specified could lead to out of bounds accesses.
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
SmallVector<OpFoldResult> allShapeSizes =
linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc());
SmallVector<OpFoldResult> sizeBounds =
mlir::affine::makeComposedFoldedMultiResultAffineApply(
b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes);
SmallVector<Value> valuesToTile = linalgOp->getOperands();
SmallVector<Value> tiledOperands = makeTiledShapes(
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
b, loc, linalgOp, valuesToTile, offsets, sizes, sizeBounds, true);
SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
llvm::make_filter_range(
tiledOperands,
Expand Down
43 changes: 22 additions & 21 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,18 @@ namespace {
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<int64_t> domainSizes)
: tileSizes(tileSizes), domainSizes(domainSizes) {}
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds)
: tileSizes(tileSizes), sizeBounds(sizeBounds) {}

void visitDimExpr(AffineDimExpr expr) {
unsigned pos = expr.getPosition();

// This dimension is tiled if the tile size is larger than zero and not
// equal to its domain size (if statically known).
std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
if (tileSize && !domainSizes.empty()) {
if (domainSizes[pos] == *tileSize) {
if (tileSize && !sizeBounds.empty()) {
std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
if (sizeBound && *sizeBound == *tileSize) {
return;
}
}
Expand All @@ -82,27 +83,27 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
}
bool isTiled = false;
ArrayRef<OpFoldResult> tileSizes;
ArrayRef<int64_t> domainSizes;
ArrayRef<OpFoldResult> sizeBounds;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> domainSizes) {
ArrayRef<OpFoldResult> sizeBounds) {
if (!expr)
return false;
TileCheck t(tileSizes, domainSizes);
TileCheck t(tileSizes, sizeBounds);
t.visit(expr);
return t.isTiled;
}

// Checks whether the `map varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> domainSizes) {
ArrayRef<OpFoldResult> sizeBounds) {
if (!map)
return false;
for (unsigned r = 0; r < map.getNumResults(); ++r)
if (isTiled(map.getResult(r), tileSizes, domainSizes))
if (isTiled(map.getResult(r), tileSizes, sizeBounds))
return true;
return false;
}
Expand Down Expand Up @@ -571,19 +572,19 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck,
ArrayRef<int64_t> domainSizes) {
SliceParameters sliceParams = computeSliceParameters(
builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes,
omitPartialTileCheck, domainSizes);
bool omitPartialTileCheck) {
SliceParameters sliceParams =
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
ubs, subShapeSizes, omitPartialTileCheck);
return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters computeSliceParameters(
OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck, ArrayRef<int64_t> domainSizes) {
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck) {
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
assert(shapedType && "only shaped types can be tiled");
ArrayRef<int64_t> shape = shapedType.getShape();
Expand All @@ -600,7 +601,7 @@ SliceParameters computeSliceParameters(
// The offset & size computation below only handles the case when
// the map is monotonically increasing, i.e. the min and max values are
// attained at the lower and upper bounds of the iteration domain.
if (!isTiled(m, tileSizes, domainSizes)) {
if (!isTiled(m, tileSizes, ubs)) {
sliceParams.offsets.push_back(builder.getIndexAttr(0));
OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
sliceParams.sizes.push_back(dim);
Expand Down Expand Up @@ -811,7 +812,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,

allSliceParams.push_back(computeSliceParameters(
builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
omitPartialTileCheck, linalgOp.getStaticLoopRanges()));
omitPartialTileCheck));
}

return allSliceParams;
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,8 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,

// 4a. Clone the operation.
{
auto clonedOp = cast<PartialReductionOpInterface>(rewriter.clone(*op));
auto clonedOp = cast<PartialReductionOpInterface>(
cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));

// 4b. Tile the cloned operation.
FailureOr<TilingResult> partialTilingResult =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,11 +555,12 @@ module {

// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
// CHECK: %[[T3:.*]] = linalg.generic {{.*}}
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

%8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
// CHECK: tensor.parallel_insert_slice %[[T3]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
}
}
Expand Down

0 comments on commit b02683d

Please sign in to comment.