Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use full slices when tiling by the full loop trip count (to support non-monotonic expressions) #468

Merged
merged 10 commits into from
Feb 13, 2025
11 changes: 5 additions & 6 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,11 @@ struct SliceParameters {
///
/// `omitPartialTileCheck` controls whether to omit the partial/boundary tile
/// condition check in cases where we statically know that it is unnecessary.
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck);
SliceParameters computeSliceParameters(
OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck, ArrayRef<int64_t> domainSizes = {});

/// Computes SliceParamaters for all `valuesToTile` of the given `linalgOp`,
/// assuming `linalgOp` is being fused into a loop nest. Calls
Expand Down
5 changes: 0 additions & 5 deletions mlir/include/mlir/IR/AffineExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,6 @@ class AffineExpr {
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
bool isPureAffine() const;

/// Returns true if this expression is monotonicically increasing with respect
/// to the AffineDimExprs, i.e. increasing the value of any AffineDimExpr will
/// never decrease the value of the result.
bool isMonotonicallyIncreasing() const;

/// Returns the greatest known integral divisor of this affine expression. The
/// result is always positive.
int64_t getLargestKnownDivisor() const;
Expand Down
4 changes: 0 additions & 4 deletions mlir/include/mlir/IR/AffineMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,6 @@ class AffineMap {
/// Returns true if the AffineMap represents a symbol-less permutation map.
bool isPermutation() const;

// Returns true if every result is monotonically increasing.
// See AffineExpr::isMonotonicallyIncreasing().
bool isComponentWiseMonotonicallyIncreasing() const;

/// Returns the map consisting of the `resultPos` subset.
AffineMap getSubMap(ArrayRef<unsigned> resultPos) const;

Expand Down
56 changes: 35 additions & 21 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,21 @@ namespace {
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<int64_t> domainSizes)
: tileSizes(tileSizes), domainSizes(domainSizes) {}

void visitDimExpr(AffineDimExpr expr) {
isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
unsigned pos = expr.getPosition();

// There is no tile if all tile sizes correspond to the domain size
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
if (tileSize && !domainSizes.empty()) {
if (domainSizes[pos] == *tileSize) {
return;
}
}

isTiled |= !isZeroIndex(tileSizes[pos]);
}
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
visit(expr.getLHS());
Expand All @@ -70,24 +81,27 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
}
bool isTiled = false;
ArrayRef<OpFoldResult> tileSizes;
ArrayRef<int64_t> domainSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> domainSizes) {
if (!expr)
return false;
TileCheck t(tileSizes);
TileCheck t(tileSizes, domainSizes);
t.visit(expr);
return t.isTiled;
}

// Checks whether the `map varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> domainSizes) {
if (!map)
return false;
for (unsigned r = 0; r < map.getNumResults(); ++r)
if (isTiled(map.getResult(r), tileSizes))
if (isTiled(map.getResult(r), tileSizes, domainSizes))
return true;
return false;
}
Expand Down Expand Up @@ -556,19 +570,19 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck) {
SliceParameters sliceParams =
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
ubs, subShapeSizes, omitPartialTileCheck);
bool omitPartialTileCheck,
ArrayRef<int64_t> domainSizes) {
SliceParameters sliceParams = computeSliceParameters(
builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes,
omitPartialTileCheck, domainSizes);
return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck) {
SliceParameters computeSliceParameters(
OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck, ArrayRef<int64_t> domainSizes) {
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
assert(shapedType && "only shaped types can be tiled");
ArrayRef<int64_t> shape = shapedType.getShape();
Expand All @@ -585,7 +599,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
// The offset & size computation below only handles the case when
// the map is monotonically increasing, i.e. the min and max values are
// attained at the lower and upper bounds of the iteration domain.
if (!isTiled(m, tileSizes) || !m.isComponentWiseMonotonicallyIncreasing()) {
if (!isTiled(m, tileSizes, domainSizes)) {
sliceParams.offsets.push_back(builder.getIndexAttr(0));
OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
sliceParams.sizes.push_back(dim);
Expand Down Expand Up @@ -784,10 +798,10 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
// transformations such as padding and bufferization since the
// extract/insert slice pairs make the accessed iteration argument
// subdomains explicit.

Type operandType = opOperand.get().getType();
if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
linalgOp.isDpsInit(&opOperand))) {
if (!isTiled(map, tileSizes, linalgOp.getStaticLoopRanges()) &&
!(isa<RankedTensorType>(operandType) &&
linalgOp.isDpsInit(&opOperand))) {
allSliceParams.push_back(std::nullopt);
LLVM_DEBUG(llvm::dbgs()
<< ": not tiled: use shape: " << operandType << "\n");
Expand All @@ -797,7 +811,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,

allSliceParams.push_back(computeSliceParameters(
builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
omitPartialTileCheck));
omitPartialTileCheck, linalgOp.getStaticLoopRanges()));
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
}

return allSliceParams;
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,8 +984,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,

// 4a. Clone the operation.
{
auto clonedOp = cast<PartialReductionOpInterface>(
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
auto clonedOp = cast<PartialReductionOpInterface>(rewriter.clone(*op));

// 4b. Tile the cloned operation.
FailureOr<TilingResult> partialTilingResult =
Expand Down
36 changes: 0 additions & 36 deletions mlir/lib/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,42 +239,6 @@ bool AffineExpr::isPureAffine() const {
llvm_unreachable("Unknown AffineExpr");
}

static bool isNonNegativeConstant(AffineExpr expr) {
auto constant = dyn_cast<AffineConstantExpr>(expr);
return constant && constant.getValue() >= 0;
}

bool AffineExpr::isMonotonicallyIncreasing() const {
switch (getKind()) {
case AffineExprKind::SymbolId:
case AffineExprKind::DimId:
case AffineExprKind::Constant:
return true;
case AffineExprKind::Add: {
auto op = llvm::cast<AffineBinaryOpExpr>(*this);
return op.getLHS().isMonotonicallyIncreasing() &&
op.getRHS().isMonotonicallyIncreasing();
}
case AffineExprKind::Mul: {
// One operand must be a non-negative constant.
auto op = llvm::cast<AffineBinaryOpExpr>(*this);
return op.getLHS().isMonotonicallyIncreasing() &&
op.getRHS().isMonotonicallyIncreasing() &&
(isNonNegativeConstant(op.getLHS()) ||
isNonNegativeConstant(op.getRHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
auto op = llvm::cast<AffineBinaryOpExpr>(*this);
return op.getLHS().isMonotonicallyIncreasing() &&
isNonNegativeConstant(op.getRHS());
}
case AffineExprKind::Mod:
return false;
}
llvm_unreachable("Unknown AffineExpr");
}

// Returns the greatest known integral divisor of this affine expression.
int64_t AffineExpr::getLargestKnownDivisor() const {
AffineBinaryOpExpr binExpr(nullptr);
Expand Down
5 changes: 0 additions & 5 deletions mlir/lib/IR/AffineMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,6 @@ bool AffineMap::isPermutation() const {
return isProjectedPermutation();
}

bool AffineMap::isComponentWiseMonotonicallyIncreasing() const {
return all_of(getResults(),
[](auto expr) { return expr.isMonotonicallyIncreasing(); });
}

AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const {
SmallVector<AffineExpr, 4> exprs;
exprs.reserve(resultPos.size());
Expand Down
26 changes: 13 additions & 13 deletions mlir/test/Dialect/Linalg/tile-tensors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -171,31 +171,31 @@ module attributes {transform.with_named_sequence} {
// -----

// CHECK-LABEL: func @non_monotonic_affine_expr
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
func.func @non_monotonic_affine_expr(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<7xf32>
func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.dim %arg0, %c0 : tensor<?xf32>
%empty = tensor.empty(%0) : tensor<?xf32>
%0 = tensor.dim %arg0, %c0 : tensor<7xf32>
%empty = tensor.empty() : tensor<7xf32>

// CHECK: scf.for
// CHECK: %[[SIZE:[a-zA-Z0-9_]+]] = tensor.dim %[[ARG0]],
// CHECK: tensor.extract_slice %[[ARG0]][0] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<7xf32>
// CHECK: scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) {
// CHECK: tensor.extract_slice %[[TC0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32>
%generic = linalg.generic
{indexing_maps = [affine_map<(d0) -> (d0 mod 3)>,
{indexing_maps = [affine_map<(d0) -> (d0 mod 4)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0: tensor<?xf32>)
outs(%empty : tensor<?xf32>) {
ins(%arg0: tensor<7xf32>)
outs(%empty : tensor<7xf32>) {
^bb0(%in : f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?xf32>
return %generic : tensor<?xf32>
} -> tensor<7xf32>
return %generic : tensor<7xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loop = transform.structured.tile_using_for %0 tile_sizes [100] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%1, %loop = transform.structured.tile_using_for %0 tile_sizes [7] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ func.func @pad_and_hoist_rhs(
%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
-> tensor<24x25xf32>
{
// expected-note @below {{target op}}
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
Expand All @@ -24,10 +25,11 @@ module attributes {transform.with_named_sequence} {

// In this case, the pad op is actually empty: we only tile the first dimension
// and it does not have an impact on the RHS operand.
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
// expected-error @below {{could not find a producer for operand number: 1}}
%pad = transform.get_producer_of_operand %matmul_padded[1]
: (!transform.any_op) -> !transform.any_op

// expected-error @below {{requires exactly 2 non-null handles}}
// We do not even reach this transform op.
transform.structured.hoist_pad.build_packing_loop_nest %pad above %loops_l1
: (!transform.any_op, !transform.any_op) -> !transform.any_op
transform.yield
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ func.func @pad_and_hoist_rhs(
%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
-> tensor<24x25xf32>
{
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
// expected-note @below {{payload operation}}
// expected-note @below {{target op}}
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
Expand All @@ -25,7 +25,7 @@ module attributes {transform.with_named_sequence} {

// In this case, the pad op is actually empty: we only tile the first dimension
// and it does not have an impact on the RHS operand.
// expected-error @below {{incompatible payload operation name}}
// expected-error @below {{could not find a producer for operand number: 1}}
%pad = transform.get_producer_of_operand %matmul_padded[1]
: (!transform.any_op) -> !transform.op<"tensor.pad">

Expand Down
Loading