From 78ec7f2e5069a104c28a243218e27b5f5f653b73 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
Date: Wed, 12 Feb 2025 20:17:19 -0800
Subject: [PATCH] [DispatchCreation] Changes to dispatch region in preparation
 for horizontal fusion changes. (#19876)

Current dispatch region formation handles consumer fusion by making the
consumer the root of the DAG moved into dispatches. For cases where we have
more than one consumer and the consumers don't have a direct dependency on
each other, this approach does not work. This changes dispatch region
formation to keep the root operation as is and to move consumers into the
dispatch region iteratively.

This required a few additional changes:
1) Move the method `moveOperandDefs` into a utility function.
2) Change how the dynamic dims of the results of the created
   `flow.dispatch.region` ops are resolved.
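For example (a sketch with illustrative ops and types), a root with two
consumers that are independent of each other can now be fused into a single
dispatch:

  %root:2 = linalg.generic ... -> (tensor<8xf32>, tensor<8xf32>)
  %c0 = linalg.generic ins(%root#0 : tensor<8xf32>) ...
  %c1 = linalg.generic ins(%root#1 : tensor<8xf32>) ...

Neither %c0 nor %c1 dominates the other, so neither could have served as the
single root of the fused DAG; keeping %root as the root and pulling %c0 and
%c1 in one at a time forms one `flow.dispatch.region` that yields both
results.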
---------

Signed-off-by: MaheshRavishankar
Signed-off-by: Ian Wood
Co-authored-by: Ian Wood
---
 .../Flow/Transforms/FormDispatchRegions.cpp   |  12 +-
 .../Dialect/Flow/Transforms/RegionOpUtils.cpp |  34 +--
 .../Dialect/LinalgExt/Utils/Utils.cpp         |   7 +-
 .../compiler/Dialect/LinalgExt/Utils/Utils.h  |   2 +-
 .../DispatchCreation/FormDispatchRegions.cpp  | 198 +++++++++++------
 .../compiler/DispatchCreation/FusionUtils.cpp |  43 ++++
 .../compiler/DispatchCreation/FusionUtils.h   |   9 +
 .../test/dispatch_linalg_on_tensors.mlir      |  12 +-
 .../test/form_dispatch_regions.mlir           | 203 ++++++++++++++++++
 9 files changed, 425 insertions(+), 95 deletions(-)

diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index 3dbc95149068..72443724b6a5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -45,7 +45,8 @@ void TensorDimTrackingRewriter::notifyOperationErased(Operation *op) {
 void TensorDimTrackingRewriter::notifyOperationInserted(Operation *op,
                                                         InsertPoint previous) {
   IRRewriter::Listener::notifyOperationInserted(op, previous);
-  if (isa<tensor::DimOp>(op))
+  auto dimOp = dyn_cast<tensor::DimOp>(op);
+  if (dimOp && isa<OpResult>(dimOp.getSource()))
     dimOps.insert(op);
 }
 
@@ -60,16 +61,21 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter,
     std::optional<int64_t> idx = dimOp.getConstantIndex();
     if (!idx.has_value())
       continue;
+
+    if (isa<BlockArgument>(dimOp.getSource())) {
+      continue;
+    }
+
     // Only DimOps with ranked tensors are supported.
     auto tensorType =
         llvm::dyn_cast<RankedTensorType>(dimOp.getSource().getType());
     if (!tensorType)
       continue;
 
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(dimOp);
     if (!tensorType.isDynamicDim(*idx)) {
       // Rewrite static dimension with constant.
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(dimOp);
       int64_t size = tensorType.getShape()[*idx];
       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(dimOp, size);
       continue;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 3fb9bdb3214c..7974609f8f32 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -266,18 +266,8 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
   // Value is an OpResult.
   Operation *op = value.getDefiningOp();
   OpResult opResult = llvm::cast<OpResult>(value);
-  b.setInsertionPoint(op);
 
-  // Case 3: Value is tied. Reify the dimensions of the tied operand.
-  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
-  if (tiedOp) {
-    Value tiedOperand = tiedOp.getTiedResultOperand(value);
-    if (tiedOperand && tiedOperand.getType() == value.getType())
-      return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
-                                        createTensorDimOps);
-  }
-
-  // Case 4: Query ShapeAwareOpInterface.
+  // Case 3: Query ShapeAwareOpInterface.
   auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op);
   if (shapeAwareOp) {
     ValueRange dims =
@@ -286,6 +276,15 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
     return success();
   }
 
+  // Case 4: Value is tied. Reify the dimensions of the tied operand.
+  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
+  if (tiedOp) {
+    Value tiedOperand = tiedOp.getTiedResultOperand(value);
+    if (tiedOperand && tiedOperand.getType() == value.getType())
+      return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
+                                        /*createTensorDimOps=*/true);
+  }
+
   // Case 5: Query ReifyRankedShapedTypeOpInterface.
   auto reifyShapeOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
   if (reifyShapeOp) {
@@ -308,8 +307,14 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
 }
 
 /// Reify the dynamic dimensions of the given value.
+/// Deprecated. Use `getOptimizedDynamicResultDims` instead.
 LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value,
                                      SmallVectorImpl<Value> &dynamicDims) {
+
+  OpBuilder::InsertionGuard g(b);
+  if (auto op = value.getDefiningOp()) {
+    b.setInsertionPoint(op);
+  }
   return reifyDynamicResultDimsImpl(b, value, dynamicDims,
                                     /*createTensorDimOps=*/true);
 }
@@ -473,7 +478,7 @@ movePrecedingOpsIntoDispatchRegion(RewriterBase &rewriter,
       rewriter.setInsertionPoint(target);
       SmallVector<Value> &dims =
           dispatchOpNewResultsDynamicDims.emplace_back();
-      if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+      if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
         return target->emitOpError(
             "failed to reify dynamic dims of result to be yielded from "
             "dispatch region");
@@ -554,9 +559,10 @@ moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
   for (auto [index, result] : llvm::enumerate(target->getResults())) {
     replacedValues.push_back(result);
     yieldedResults.push_back(clonedTarget->getResult(index));
-    rewriter.setInsertionPoint(target);
+    OpBuilder::InsertionGuard g1(rewriter);
+    rewriter.setInsertionPoint(regionOp);
     SmallVector<Value> &dims = dispatchOpNewResultsDynamicDims.emplace_back();
-    if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+    if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
       return target->emitOpError(
           "failed to reify dynamic dims of result to be yielded from "
           "dispatch region");
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index c8896967b7f8..3e8ae66d680b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -692,7 +692,12 @@ isContractionOpSequence(Value yielded) {
 /// Recognize an operation that is horizontally fused contraction.
 /// TODO: The logic below is quite convoluted. Might be better
 /// off having a dedicated operation for this.
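+/// A minimal sketch of the expected form (illustrative shapes): two
+/// contractions that share an LHS, fused into one multi-result op:
+///   %0:2 = linalg.generic
+///       ins(%lhs, %rhs0, %rhs1 : ...) outs(%acc0, %acc1 : ...)
+/// where each result accumulates %lhs against a different RHS.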
-bool isaHorizontallyFusedContraction(linalg::LinalgOp linalgOp) {
+bool isaHorizontallyFusedContraction(Operation *op) {
+  auto linalgOp = dyn_cast_or_null<linalg::LinalgOp>(op);
+  if (!linalgOp) {
+    return false;
+  }
+
   if (linalgOp->getNumResults() == 1) {
     return false;
   }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index c02699f80c25..48ecf9202d95 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -214,7 +214,7 @@ bool isGatherlikeOp(Operation *op);
 /// Check if a given operation is a horizontally fused contraction operation.
 /// The expectation is that the LHS is common, and all the operands are
 /// different RHS.
-bool isaHorizontallyFusedContraction(linalg::LinalgOp genericOp);
+bool isaHorizontallyFusedContraction(Operation *op);
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index 562c779ba71d..b2de618764fd 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/FusionUtils.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -24,6 +25,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -39,6 +41,7 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "iree-dispatch-creation-form-dispatch-regions"
 
@@ -330,36 +333,34 @@ static bool hasCompatibleOuterParallelLoops(
                                          consumerIndexingMap,
                                          rootOuterParallelLoops);
 }
 
-/// For all uses of an operation, finds the use that dominates all other uses.
-static std::optional<OpOperand *>
-getFusableUse(Operation *op, DominanceInfo const &dominanceInfo,
-              bool aggressiveFusion) {
+/// For all uses of an operation, return the uses that could be fused.
+/// The returned vector contains the uses in dominance order.
+static SmallVector<OpOperand *>
+getFusableUses(MLIRContext *context, Operation *op,
+               DominanceInfo const &dominanceInfo, bool aggressiveFusion) {
   if (!aggressiveFusion && llvm::count_if(op->getUses(), [](OpOperand &use) {
         return !isa<tensor::DimOp>(use.getOwner());
       }) != 1) {
-    return std::nullopt;
+    return {};
   }
 
-  // Collect non-dim users.
-  SmallVector<Operation *> nonDimUsers;
-  for (Operation *user : op->getUsers()) {
-    if (isa<tensor::DimOp>(user))
-      continue;
-    nonDimUsers.push_back(user);
-  }
-
-  // Find the use in a non-dim user that dominates all other non-dim users.
-  for (auto &use : op->getUses()) {
+  // Collect all fusable user candidates.
+  SetVector<OpOperand *> fusableUses;
+  for (OpOperand &use : op->getUses()) {
     Operation *user = use.getOwner();
-    if (isa<tensor::DimOp>(user))
+    if (isa<tensor::DimOp>(user)) {
       continue;
-    if (llvm::all_of(nonDimUsers, [&](Operation *c) {
-          return dominanceInfo.dominates(user, c);
-        })) {
-      return &use;
     }
+    fusableUses.insert(&use);
   }
-  return std::nullopt;
+
+  SmallVector<OpOperand *> usesVec = fusableUses.takeVector();
+  llvm::sort(usesVec, [&](OpOperand *lhsUse, OpOperand *rhsUse) {
+    return dominanceInfo.properlyDominates(lhsUse->getOwner(),
+                                           rhsUse->getOwner());
+  });
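+  // E.g. a root with two independent consumers %c0 and %c1, where %c0 comes
+  // earlier in the block, yields [use in %c0, use in %c1]; callers then
+  // fuse consumers in dominance order. (Names are illustrative.)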
+
+  return usesVec;
 }
 
 /// Returns true if the operands are fusable.
@@ -576,38 +577,39 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
   for (Operation *root : roots) {
     SmallVector<Operation *> workList;
     llvm::SmallBitVector rootOuterParallelLoops = getOuterParallelLoops(root);
+    int64_t rootNumber = getRootNumber(root);
     workList.push_back(root);
     while (!workList.empty()) {
       Operation *currRoot = workList.pop_back_val();
-      assert(hasRootOpAttribute(currRoot) &&
-             "unexpected non-root op in worklist");
-
-      // Helper function to make the consumer the root instead of the producer
-      // when they are to be fused.
-      auto updateRootTo = [&context, &currRoot](Operation *newRoot) {
-        int64_t rootNumber = getRootNumber(currRoot);
-        setRootAttribute(context, newRoot, rootNumber);
-        removeRootOpAttribute(currRoot);
-        appendToFusionGroup(currRoot, rootNumber);
-      };
-
-      std::optional<OpOperand *> fusableUse =
-          getFusableUse(currRoot, dominanceInfo,
-                        /*aggressiveFusion=*/options.aggressiveFusion);
-      if (!fusableUse)
-        continue;
-
-      // Analyse the use to see if it is fusable.
-      Operation *consumerOp = fusableUse.value()->getOwner();
-      if (hasRootOpAttribute(consumerOp) ||
-          hasFusionGroupsAttribute(consumerOp)) {
+      SmallVector<OpOperand *> fusableUses =
+          getFusableUses(context, currRoot, dominanceInfo,
+                         /*aggressiveFusion=*/options.aggressiveFusion);
+      if (fusableUses.empty())
         continue;
+
+      // For now disable the fusing with multiple consumers for all
+      // operations other than horizontally fused gemms. This should
+      // work in general but is causing time-outs on some CI examples.
+      if (!IREE::LinalgExt::isaHorizontallyFusedContraction(root)) {
+        fusableUses = {fusableUses.front()};
       }
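+
+      // E.g. a horizontally fused gemm whose results each feed an
+      // independent truncate consumer keeps all of those uses as candidates
+      // (see the horizontal_fusion tests added in this patch); any other
+      // root keeps only the first use in dominance order.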
 
-      if (isFusableWithConsumer(*(fusableUse.value()), rootOuterParallelLoops,
-                                options)) {
-        updateRootTo(consumerOp);
-        workList.push_back(consumerOp);
+      // Analyse the use to see if it is fusable.
+      for (OpOperand *fusableUse : fusableUses) {
+        Operation *consumerOp = fusableUse->getOwner();
+        if (hasRootOpAttribute(consumerOp) ||
+            hasFusionGroupsAttribute(consumerOp)) {
+          continue;
+        }
+
+        if (isFusableWithConsumer(*fusableUse, rootOuterParallelLoops,
+                                  options)) {
+          appendToFusionGroup(consumerOp, rootNumber);
+          workList.push_back(consumerOp);
+        } else {
+          break;
+        }
       }
     }
   }
@@ -701,16 +703,16 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum,
       continue;
     }
 
-    std::optional<OpOperand *> fusableUse =
-        getFusableUse(producer, dominanceInfo,
-                      /*aggressiveFusion=*/options.aggressiveFusion);
-    if (!fusableUse || fusableUse.value()->getOwner() != candidate)
-      continue;
-
     if (!isFusableWithProducer(operand, rootOuterParallelLoops, options)) {
       continue;
     }
 
+    SmallVector<OpOperand *> fusableUses =
+        getFusableUses(context, producer, dominanceInfo,
+                       /*aggressiveFusion=*/options.aggressiveFusion);
+    if (fusableUses.empty() || fusableUses.front()->getOwner() != candidate)
+      continue;
+
     appendToFusionGroup(producer, groupNum);
     worklist.push_back(producer);
   }
@@ -814,14 +816,14 @@ decideFusableLinalgOps(Region &region, DominanceInfo const &dominanceInfo,
 static LogicalResult
 createFusionGroups(TensorDimTrackingRewriter &rewriter,
                    mlir::FunctionOpInterface funcOp,
-                   DominanceInfo const &dominanceInfo,
+                   DominanceInfo &dominanceInfo,
                    FormDispatchRegionsPassOptions const &options) {
   // Step 1: Decide fusion groups (heuristic). This marks rootOps with an
   // attribute
   unsigned numRoots =
       decideFusableLinalgOps(funcOp.getFunctionBody(), dominanceInfo, options);
   SmallVector<Operation *> roots(numRoots, nullptr);
-  DenseMap<unsigned, SmallVector<Operation *>> producers;
+  DenseMap<unsigned, SmallVector<Operation *>> fusedOperations;
 
   LLVM_DEBUG({
     llvm::dbgs() << "\n--- After deciding fusion groups ---\n";
@@ -834,11 +836,12 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
   funcOp.walk([&](Operation *op) {
     if (hasRootOpAttribute(op)) {
       roots[getRootNumber(op)] = op;
+      fusedOperations[getRootNumber(op)].push_back(op);
      removeRootOpAttribute(op);
     }
     if (hasFusionGroupsAttribute(op)) {
       assert(getFusionGroups(op).size() == 1 && "expected exactly one group");
-      producers[getFusionGroups(op).front()].push_back(op);
+      fusedOperations[getFusionGroups(op).front()].push_back(op);
       removeFusionGroupsAttribute(op);
     }
   });
@@ -846,7 +849,32 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
   // Step 2. Create a DispatchRegionOp for every fusion group.
   OpBuilder::InsertionGuard g(rewriter);
   SmallVector<IREE::Flow::DispatchRegionOp> regionOps;
-  for (const auto &it : llvm::enumerate(roots)) {
+  for (auto [rootIndex, root] : llvm::enumerate(roots)) {
+
+    // Sort producers and consumers topologically. All fused ops must be in the
+    // same block as the root.
+    SmallVector<Operation *> &currFusedOperations = fusedOperations[rootIndex];
+    bool sortResult = mlir::computeTopologicalSorting(currFusedOperations);
+    (void)sortResult;
+    assert(sortResult && "could not compute topological sorting");
+
+    int rootPos = 0;
+    for (auto [index, fusedOperation] : llvm::enumerate(currFusedOperations)) {
+      if (fusedOperation == root) {
+        rootPos = index;
+        break;
+      }
+    }
+
+    SmallVector<Operation *> producers, consumers;
+    if (rootPos > 0) {
+      producers = llvm::to_vector(
+          ArrayRef(currFusedOperations).take_front(rootPos));
+    }
+    if (rootPos < currFusedOperations.size() - 1) {
+      consumers = llvm::to_vector(
+          ArrayRef(currFusedOperations).drop_front(rootPos + 1));
+    }
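+    // E.g. a topologically sorted currFusedOperations = [p0, p1, root, c0,
+    // c1] gives producers = [p0, p1] and consumers = [c0, c1]. (Names are
+    // illustrative.)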
+
     // Simplify tensor::DimOps.
     {
       SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
@@ -857,20 +885,14 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
 
     // Create fusion group.
     IREE::Flow::DispatchRegionOp regionOp;
-    auto maybeRegionOp =
-        IREE::Flow::wrapOpInDispatchRegion(rewriter, it.value());
-    if (failed(maybeRegionOp))
-      return failure();
+    auto maybeRegionOp = IREE::Flow::wrapOpInDispatchRegion(rewriter, root);
+    if (failed(maybeRegionOp)) {
+      return root->emitOpError("failed to move root into dispatch");
+    }
     regionOp = *maybeRegionOp;
 
-    // Sort producers topologically. All producers must be in the same block
-    // as the root.
-    bool sortResult = mlir::computeTopologicalSorting(producers[it.index()]);
-    (void)sortResult;
-    assert(sortResult && "could not compute topological sorting");
-
     // Move ops into the region.
-    for (Operation *producer : llvm::reverse(producers[it.index()])) {
+    for (Operation *producer : llvm::reverse(producers)) {
       // Simplify tensor::DimOps.
       {
         SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
@@ -881,8 +903,31 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
 
       auto newRegionOp =
           movePrecedingOpsIntoDispatchRegion(rewriter, producer, regionOp);
-      if (failed(newRegionOp))
-        return failure();
+      if (failed(newRegionOp)) {
+        return producer->emitOpError("failed to move producer into region");
+      }
+      regionOp = *newRegionOp;
+    }
+
+    for (Operation *consumer : consumers) {
+      // Simplify tensor::DimOps.
+      {
+        SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
+        if (failed(IREE::Flow::simplifyDimOps(rewriter, dimOps))) {
+          return failure();
+        }
+      }
+
+      if (failed(moveOperandDefs(rewriter, consumer, regionOp, dominanceInfo,
+                                 regionOp.getOperation()))) {
+        continue;
+      }
+
+      auto newRegionOp = IREE::Flow::moveFollowingOpIntoDispatchRegion(
+          rewriter, consumer, regionOp);
+      if (failed(newRegionOp)) {
+        return consumer->emitOpError("failed to move consumer into region");
+      }
       regionOp = *newRegionOp;
     }
 
     // Simplify tensor::DimOps.
@@ -916,7 +961,7 @@ struct FormDispatchRegionsPass final
 /// Create dispatch.region Ops based on a fusion heuristic.
 void FormDispatchRegionsPass::runOnOperation() {
   mlir::FunctionOpInterface funcOp = getOperation();
-  DominanceInfo const &dominanceInfo = getAnalysis<DominanceInfo>();
+  DominanceInfo &dominanceInfo = getAnalysis<DominanceInfo>();
   TensorDimTrackingRewriter rewriter(funcOp);
   FormDispatchRegionsPassOptions options{aggressiveFusion,
                                          fusePadWithConsumers,
                                          fusePadWithProducers};
@@ -924,5 +969,18 @@ void FormDispatchRegionsPass::runOnOperation() {
     funcOp->emitOpError("failed to create fusion groups");
     return signalPassFailure();
   }
+
+  // Canonicalize all the dispatch regions to remove unused operands.
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+  IREE::Flow::DispatchRegionOp::getCanonicalizationPatterns(patterns, context);
+  GreedyRewriteConfig config;
+  config.maxIterations = GreedyRewriteConfig::kNoLimit;
+  config.fold = true;
+  if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) {
+    funcOp.emitOpError("failed in cleanup patterns");
+    return signalPassFailure();
+  }
 }
 } // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index c428091f6cf8..1a5e5f51c720 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -10,7 +10,10 @@
 #include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
 #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 namespace mlir::iree_compiler::DispatchCreation {
 
@@ -97,4 +100,44 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
   return true;
 }
 
+LogicalResult moveOperandDefs(RewriterBase &rewriter,
+                              ArrayRef<Operation *> operations,
+                              Operation *insertionPoint,
+                              DominanceInfo &dominanceInfo,
+                              ArrayRef<Operation *> ignoreOperations) {
+  BackwardSliceOptions options;
+  llvm::DenseSet<Operation *> ignoreOperationsSet;
+  ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
+  options.filter = [&](Operation *op) {
+    return !dominanceInfo.properlyDominates(op, insertionPoint) &&
+           !ignoreOperationsSet.contains(op);
+  };
+  // Set inclusive to true because the slice is computed from the operand, and
+  // we want to include the defining op (which is the point here).
+  options.omitUsesFromAbove = false;
+  options.inclusive = true;
+
+  llvm::SetVector<Operation *> slice;
+  for (auto op : operations) {
+    for (auto operand : op->getOperands()) {
+      getBackwardSlice(operand, &slice, options);
+    }
+    auto regions = op->getRegions();
+    if (regions.empty()) {
+      continue;
+    }
+    llvm::SetVector<Value> capturedVals;
+    mlir::getUsedValuesDefinedAbove(regions, capturedVals);
+    for (auto value : capturedVals) {
+      getBackwardSlice(value, &slice, options);
+    }
+  }
+
+  mlir::topologicalSort(slice);
+  for (auto op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
 } // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
index 1d9c9306f7ae..f99d7ae6ff24 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -10,7 +10,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir::iree_compiler::DispatchCreation {
 
@@ -19,4 +21,11 @@ namespace mlir::iree_compiler::DispatchCreation {
 bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
                                 bool fuseMultiReduction);
 
+/// Move the definition of operands of `operations` before `insertionPoint`.
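+/// Ops in `ignoreOperations` are excluded from the backward slice that gets
+/// moved. For example (mirroring the call in FormDispatchRegions.cpp), the
+/// dispatch region op itself is passed when hoisting a consumer's operand
+/// defs above it:
+///   moveOperandDefs(rewriter, consumer, regionOp, dominanceInfo,
+///                   regionOp.getOperation());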
+LogicalResult moveOperandDefs(RewriterBase &rewriter,
+                              ArrayRef<Operation *> operations,
+                              Operation *insertionPoint,
+                              DominanceInfo &dominanceInfo,
+                              ArrayRef<Operation *> ignoreOperations = {});
+
 } // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
index 16beb5b01599..1b961104b882 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
@@ -1431,9 +1431,9 @@ util.func public @multi_use_producer_fusion(%arg0 : tensor, %arg1 : ten
 // CHECK:      %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:     ins(%[[MATMUL]], %[[BIAS]] :
 // CHECK-SAME:     outs(%[[INIT]] :
-// CHECK-DAG:  flow.dispatch.tensor.store %[[GENERIC]], %[[RESULT0]]
-// CHECK-DAG:  flow.dispatch.tensor.store %[[MATMUL]], %[[RESULT1]]
-// CHECK:      util.return %[[DISPATCH]]#1, %[[DISPATCH]]#0
+// CHECK-DAG:  flow.dispatch.tensor.store %[[MATMUL]], %[[RESULT0]]
+// CHECK-DAG:  flow.dispatch.tensor.store %[[GENERIC]], %[[RESULT1]]
+// CHECK:      util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1
 
 // -----
 
@@ -1535,9 +1535,9 @@ util.func public @fuse_conv2d_with_multiple_uses(%input: tensor<1x225x225x16xf32
 // CHECK-SAME:     %[[OUT2:.+]]: !flow.dispatch.tensor>
 // CHECK:      %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
 // CHECK:      %[[GENERIC:.+]] = linalg.generic
-// CHECK-DAG:  flow.dispatch.tensor.store %[[GENERIC]], %[[OUT1]]
-// CHECK-DAG:  flow.dispatch.tensor.store %[[CONV]], %[[OUT2]]
-// CHECK:      util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1
+// CHECK-DAG:  flow.dispatch.tensor.store %[[CONV]], %[[OUT1]]
+// CHECK-DAG:  flow.dispatch.tensor.store %[[GENERIC]], %[[OUT2]]
+// CHECK:      util.return %[[DISPATCH]]#1, %[[DISPATCH]]#0
 
 // -----
 
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index 9934972ca0f5..4a1b808a9127 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -1008,3 +1008,206 @@ util.func @scatter_index_producer_fusion(%arg0 : tensor,
 // CHECK-SAME:     ins(%{{.+}}, %[[GENERIC]] :
 // CHECK:        flow.return %[[SCATTER]]
 // CHECK:      util.return %[[DISPATCH]]
+
+// -----
+
+util.func @move_captured_from_above_ops(%arg0 : tensor<1x1x2x4xi32>,
+    %arg1 : f64, %arg2 : f64) -> tensor<2x3xi8> {
+  %empty = tensor.empty() : tensor<2x3xi32>
+  %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1]
+      inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %empty
+      : tensor<1x1x2x4xi32> -> tensor<2x3xi32>
+  %0 = arith.mulf %arg1, %arg2 : f64
+  %1 = tensor.empty() : tensor<2x3xi8>
+  %2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%unpack : tensor<2x3xi32>) outs(%1 : tensor<2x3xi8>) {
+    ^bb0(%in: i32, %out: i8):
+      %3 = arith.sitofp %in : i32 to f32
+      %4 = arith.truncf %0 : f64 to f32
+      %5 = arith.mulf %3, %4 : f32
+      %48 = arith.fptosi %5 : f32 to i8
+      linalg.yield %48 : i8
+  } -> tensor<2x3xi8>
+  util.return %2 : tensor<2x3xi8>
+}
+// CHECK-LABEL: func public @move_captured_from_above_ops
+// CHECK:      %[[OP:.+]] = arith.mulf
+// CHECK:      %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:        %[[UNPACK:.+]] = tensor.unpack
+// CHECK:        %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[UNPACK]] :
+// CHECK:          %[[TRUNCF:.+]] = arith.truncf %[[OP]]
+// CHECK:          linalg.yield
+// CHECK:        flow.return %[[GENERIC]]
+// CHECK:      util.return %[[DISPATCH]]
+
+// -----
+
+util.func @horizontal_fusion1(%lhs : tensor<2x4096x640xf16>,
+    %rhs0 : tensor<10x64x640xf16>, %rhs1 : tensor<10x64x640xf16>,
+    %rhs2 : tensor<10x64x640xf16>) ->
+    (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>,
+     tensor<2x10x4096x64xf16>) {
+  %4 = tensor.empty() : tensor<2x10x4096x64xf32>
+  %cst = arith.constant 0.0 : f32
+  %5 = linalg.fill ins(%cst : f32)
+      outs(%4 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
+  %6:3 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel",
+                        "reduction"]}
+      ins(%lhs, %rhs0, %rhs1, %rhs2
+          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>,
+            tensor<10x64x640xf16>, tensor<10x64x640xf16>)
+      outs(%5, %5, %5
+          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>,
+            tensor<2x10x4096x64xf32>) {
+    ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32,
+         %out_3: f32, %out_4: f32):
+      %14 = arith.extf %in : f16 to f32
+      %15 = arith.extf %in_0 : f16 to f32
+      %16 = arith.mulf %14, %15 : f32
+      %17 = arith.addf %out, %16 : f32
+      %18 = arith.extf %in_1 : f16 to f32
+      %19 = arith.mulf %14, %18 : f32
+      %20 = arith.addf %out_3, %19 : f32
+      %21 = arith.extf %in_2 : f16 to f32
+      %22 = arith.mulf %14, %21 : f32
+      %23 = arith.addf %out_4, %22 : f32
+      linalg.yield %17, %20, %23 : f32, f32, f32
+  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>,
+        tensor<2x10x4096x64xf32>)
+  %7 = tensor.empty() : tensor<2x10x4096x64xf16>
+  %8 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%6#0 : tensor<2x10x4096x64xf32>)
+      outs(%7 : tensor<2x10x4096x64xf16>) {
+    ^bb0(%in: f32, %out: f16):
+      %14 = arith.truncf %in : f32 to f16
+      linalg.yield %14 : f16
+  } -> tensor<2x10x4096x64xf16>
+  %9 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%6#1 : tensor<2x10x4096x64xf32>)
+      outs(%7 : tensor<2x10x4096x64xf16>) {
+    ^bb0(%in: f32, %out: f16):
+      %14 = arith.truncf %in : f32 to f16
+      linalg.yield %14 : f16
+  } -> tensor<2x10x4096x64xf16>
+  %10 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%6#2 : tensor<2x10x4096x64xf32>)
+      outs(%7 : tensor<2x10x4096x64xf16>) {
+    ^bb0(%in: f32, %out: f16):
+      %14 = arith.truncf %in : f32 to f16
+      linalg.yield %14 : f16
+  } -> tensor<2x10x4096x64xf16>
+  util.return %8, %9, %10
+      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>,
+        tensor<2x10x4096x64xf16>
+}
+// CHECK-LABEL: func public @horizontal_fusion1
+// CHECK:      %[[DISPATCH:.+]]:3 = flow.dispatch.region
+// CHECK:        %[[GENERIC:.+]]:3 = linalg.generic
+// CHECK:        %[[TRUNC0:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[GENERIC]]#0 :
+// CHECK:        %[[TRUNC1:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[GENERIC]]#1 :
+// CHECK:        %[[TRUNC2:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[GENERIC]]#2 :
+// CHECK:        flow.return %[[TRUNC0]], %[[TRUNC1]], %[[TRUNC2]]
+// CHECK:      util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2
+
+// -----
+
+util.func @horizontal_fusion2(%lhs : tensor<2x4096x640xi8>,
+    %rhs0 : tensor<2x640x640xi8>, %rhs1 : tensor<2x640x640xi8>)
+    -> tensor<2x4096x640xf16> {
+  %c0_i32 = arith.constant 32 : i32
+  %0 = tensor.empty() : tensor<2x4096x640xf16>
+  %1 = tensor.empty() : tensor<2x4096x640xi32>
+  %2 = linalg.fill ins(%c0_i32 : i32)
+      outs(%1 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
+  %3:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+      ins(%lhs, %rhs0, %rhs1
+          : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>)
+      outs(%2, %2 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
+    ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32):
+      %4 = arith.extsi %in : i8 to i32
+      %5 = arith.extsi %in_0 : i8 to i32
+      %6 = arith.muli %4, %5 : i32
+      %7 = arith.addi %out, %6 : i32
+      %8 = arith.extsi %in_1 : i8 to i32
+      %9 = arith.muli %7, %8 : i32
+      %10 = arith.addi %out_2, %9 : i32
+      linalg.yield %7, %10 : i32, i32
+  } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
+  %4 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%3#1, %3#0 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
+      outs(%0 : tensor<2x4096x640xf16>) {
+    ^bb0(%in: i32, %in_0: i32, %out: f16):
+      %5 = arith.sitofp %in : i32 to f32
+      %6 = arith.truncf %5 : f32 to f16
+      %7 = arith.sitofp %in_0 : i32 to f32
+      %8 = arith.truncf %7 : f32 to f16
+      %9 = arith.addf %6, %8 : f16
+      linalg.yield %9 : f16
+  } -> tensor<2x4096x640xf16>
+  util.return %4 : tensor<2x4096x640xf16>
+}
+// CHECK-LABEL: func public @horizontal_fusion2
+// CHECK:      %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:        %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK:        %[[TRUNC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[GENERIC]]#1, %[[GENERIC]]#0 :
+// CHECK:        flow.return %[[TRUNC]]
+// CHECK:      util.return %[[DISPATCH]]
+
+// -----
+
+util.func @avoid_use_def_violation_on_consumer_fusion(%arg0 : tensor<?xf32>,
+    %arg1 : tensor<f32>) -> tensor<f32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+      iterator_types = ["reduction"]}
+      ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.addf %b0, %b1 : f32
+      linalg.yield %1 : f32
+  } -> tensor<f32>
+  %1 = util.optimization_barrier %0 : tensor<f32>
+  %2 = tensor.empty() : tensor<f32>
+  %3 = linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>,
+                       affine_map<() -> ()>],
+      iterator_types = []}
+      ins(%0, %1 : tensor<f32>, tensor<f32>) outs(%2 : tensor<f32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %4 = arith.mulf %b0, %b1 : f32
+      linalg.yield %4 : f32
+  } -> tensor<f32>
+  util.return %3 : tensor<f32>
+}
+// CHECK-LABEL: func public @avoid_use_def_violation_on_consumer_fusion
+// CHECK:      %[[DISPATCH1:.+]] = flow.dispatch.region
+// CHECK:        %[[GENERIC1:.+]] = linalg.generic
+// CHECK:        flow.return %[[GENERIC1]]
+// CHECK:      %[[BARRIER:.+]] = util.optimization_barrier %[[DISPATCH1]]
+// CHECK:      %[[DISPATCH2:.+]] = flow.dispatch.region
+// CHECK:        %[[GENERIC2:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[DISPATCH1]], %[[BARRIER]] :
+// CHECK:        flow.return %[[GENERIC2]]
+// CHECK:      util.return %[[DISPATCH2]]