Skip to content

Commit

Permalink
[mlir][affine] Add ValueBoundsOpInterface to [de]linearize_index (llv…
Browse files Browse the repository at this point in the history
…m#121833)

Since a need for it came up dowstream (in proving that loops run at
least once), this commit implements the ValueBoundsOpInterface for
affine.delinearize_index and affine.linearize_index, using affine map
representations of the operations they perform.

These implementations also use information from outer bounds to impose
additional constraints when those are available.
  • Loading branch information
krzysz00 authored Jan 7, 2025
1 parent 2015c0a commit c6f67b8
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 0 deletions.
62 changes: 62 additions & 0 deletions mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,64 @@ struct AffineMaxOpInterface
};
};

struct AffineDelinearizeIndexOpInterface
: public ValueBoundsOpInterface::ExternalModel<
AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
void populateBoundsForIndexValue(Operation *rawOp, Value value,
ValueBoundsConstraintSet &cstr) const {
auto op = cast<AffineDelinearizeIndexOp>(rawOp);
auto result = cast<OpResult>(value);
assert(result.getOwner() == rawOp &&
"bounded value isn't a result of this delinearize_index");
unsigned resIdx = result.getResultNumber();

AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());

SmallVector<OpFoldResult> basis = op.getPaddedBasis();
AffineExpr divisor = cstr.getExpr(1);
for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1))
divisor = divisor * cstr.getExpr(basisElem);

if (resIdx == 0) {
cstr.bound(value) == linearIdx.floorDiv(divisor);
if (!basis.front().isNull())
cstr.bound(value) < cstr.getExpr(basis.front());
return;
}
AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
}
};

struct AffineLinearizeIndexOpInterface
: public ValueBoundsOpInterface::ExternalModel<
AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
void populateBoundsForIndexValue(Operation *rawOp, Value value,
ValueBoundsConstraintSet &cstr) const {
auto op = cast<AffineLinearizeIndexOp>(rawOp);
assert(value == op.getResult() &&
"value isn't the result of this linearize");

AffineExpr bound = cstr.getExpr(0);
AffineExpr stride = cstr.getExpr(1);
SmallVector<OpFoldResult> basis = op.getPaddedBasis();
OperandRange multiIndex = op.getMultiIndex();
unsigned numArgs = multiIndex.size();
for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) {
unsigned argNum = numArgs - (revArgNum + 1);
if (argNum == 0)
break;
OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]);
bound = bound + cstr.getExpr(indexAsFoldRes) * stride;
stride = stride * cstr.getExpr(length);
}
bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride;
cstr.bound(value) == bound;
if (op.getDisjoint() && !basis.front().isNull()) {
cstr.bound(value) < stride *cstr.getExpr(basis.front());
}
}
};
} // namespace
} // namespace mlir

Expand All @@ -100,6 +158,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
AffineDelinearizeIndexOp::attachInterface<
AffineDelinearizeIndexOpInterface>(*ctx);
AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
*ctx);
});
}

Expand Down
81 changes: 81 additions & 0 deletions mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,84 @@ func.func @compare_maps(%a: index, %b: index) {
: (index, index, index, index) -> ()
return
}

// -----

// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
// CHECK-LABEL: func.func @delinearize_static
// CHECK-SAME: (%[[arg0:.+]]: index)
// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]]
// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]]
// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
// CHECK: return %[[v1]], %[[v2]], %[[v3]]
func.func @delinearize_static(%arg0: index) -> (index, index, index) {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index
%1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
%2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
%3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
// expected-remark @below{{true}}
"test.compare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
// expected-remark @below{{true}}
"test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
return %1, %2, %3 : index, index, index
}

// -----

// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
// CHECK-LABEL: func.func @delinearize_static_no_outer_bound
// CHECK-SAME: (%[[arg0:.+]]: index)
// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]]
// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]]
// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
// CHECK: return %[[v1]], %[[v2]], %[[v3]]
func.func @delinearize_static_no_outer_bound(%arg0: index) -> (index, index, index) {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index
%1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
%2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
%3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
"test.compaare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
// expected-remark @below{{true}}
"test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
return %1, %2, %3 : index, index, index
}

// -----

// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
// CHECK-LABEL: func.func @linearize_static
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]]
// CHECK: return %[[v1]]
func.func @linearize_static(%arg0: index, %arg1: index) -> index {
%c6 = arith.constant 6 : index
%0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 3) : index
%1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
// expected-remark @below{{true}}
"test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> ()
return %1 : index
}

// -----

// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
// CHECK-LABEL: func.func @linearize_static_no_outer_bound
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]]
// CHECK: return %[[v1]]
func.func @linearize_static_no_outer_bound(%arg0: index, %arg1: index) -> index {
%c6 = arith.constant 6 : index
%0 = affine.linearize_index disjoint [%arg0, %arg1] by (3) : index
%1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
// expected-error @below{{unknown}}
"test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> ()
return %1 : index
}

0 comments on commit c6f67b8

Please sign in to comment.