diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index 82a9fb0d490882..e93b99b4f49866 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -91,6 +91,64 @@ struct AffineMaxOpInterface }; }; +struct AffineDelinearizeIndexOpInterface + : public ValueBoundsOpInterface::ExternalModel< + AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> { + void populateBoundsForIndexValue(Operation *rawOp, Value value, + ValueBoundsConstraintSet &cstr) const { + auto op = cast(rawOp); + auto result = cast(value); + assert(result.getOwner() == rawOp && + "bounded value isn't a result of this delinearize_index"); + unsigned resIdx = result.getResultNumber(); + + AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex()); + + SmallVector basis = op.getPaddedBasis(); + AffineExpr divisor = cstr.getExpr(1); + for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1)) + divisor = divisor * cstr.getExpr(basisElem); + + if (resIdx == 0) { + cstr.bound(value) == linearIdx.floorDiv(divisor); + if (!basis.front().isNull()) + cstr.bound(value) < cstr.getExpr(basis.front()); + return; + } + AffineExpr thisBasis = cstr.getExpr(basis[resIdx]); + cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor); + } +}; + +struct AffineLinearizeIndexOpInterface + : public ValueBoundsOpInterface::ExternalModel< + AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> { + void populateBoundsForIndexValue(Operation *rawOp, Value value, + ValueBoundsConstraintSet &cstr) const { + auto op = cast(rawOp); + assert(value == op.getResult() && + "value isn't the result of this linearize"); + + AffineExpr bound = cstr.getExpr(0); + AffineExpr stride = cstr.getExpr(1); + SmallVector basis = op.getPaddedBasis(); + OperandRange multiIndex = op.getMultiIndex(); + unsigned numArgs = multiIndex.size(); + for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) { + unsigned argNum = numArgs - (revArgNum + 1); + if (argNum == 0) + break; + OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]); + bound = bound + cstr.getExpr(indexAsFoldRes) * stride; + stride = stride * cstr.getExpr(length); + } + bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride; + cstr.bound(value) == bound; + if (op.getDisjoint() && !basis.front().isNull()) { + cstr.bound(value) < stride *cstr.getExpr(basis.front()); + } + } +}; } // namespace } // namespace mlir @@ -100,6 +158,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels( AffineApplyOp::attachInterface(*ctx); AffineMaxOp::attachInterface(*ctx); AffineMinOp::attachInterface(*ctx); + AffineDelinearizeIndexOp::attachInterface< + AffineDelinearizeIndexOpInterface>(*ctx); + AffineLinearizeIndexOp::attachInterface( + *ctx); }); } diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir index 935c08aceff548..5354eb38d7b039 100644 --- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -155,3 +155,84 @@ func.func @compare_maps(%a: index, %b: index) { : (index, index, index, index) -> () return } + +// ----- + +// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)> +// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)> +// CHECK-LABEL: func.func @delinearize_static +// CHECK-SAME: (%[[arg0:.+]]: index) +// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]] +// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]] +// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]] +// CHECK: return %[[v1]], %[[v2]], %[[v3]] +func.func @delinearize_static(%arg0: index) -> (index, index, index) { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index + %1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index) + %2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index) + %3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index) + // expected-remark @below{{true}} + "test.compare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> () + // expected-remark @below{{true}} + "test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> () + return %1, %2, %3 : index, index, index +} + +// ----- + +// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)> +// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)> +// CHECK-LABEL: func.func @delinearize_static_no_outer_bound +// CHECK-SAME: (%[[arg0:.+]]: index) +// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]] +// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]] +// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]] +// CHECK: return %[[v1]], %[[v2]], %[[v3]] +func.func @delinearize_static_no_outer_bound(%arg0: index) -> (index, index, index) { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index + %1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index) + %2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index) + %3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index) + "test.compaare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> () + // expected-remark @below{{true}} + "test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> () + return %1, %2, %3 : index, index, index +} + +// ----- + +// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)> +// CHECK-LABEL: func.func @linearize_static +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index) +// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]] +// CHECK: return %[[v1]] +func.func @linearize_static(%arg0: index, %arg1: index) -> index { + %c6 = arith.constant 6 : index + %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 3) : index + %1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index) + // expected-remark @below{{true}} + "test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> () + return %1 : index +} + +// ----- + +// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)> +// CHECK-LABEL: func.func @linearize_static_no_outer_bound +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index) +// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]] +// CHECK: return %[[v1]] +func.func @linearize_static_no_outer_bound(%arg0: index, %arg1: index) -> index { + %c6 = arith.constant 6 : index + %0 = affine.linearize_index disjoint [%arg0, %arg1] by (3) : index + %1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index) + // expected-error @below{{unknown}} + "test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> () + return %1 : index +}