diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index 3dbc95149068..72443724b6a5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -45,7 +45,8 @@ void TensorDimTrackingRewriter::notifyOperationErased(Operation *op) {
 void TensorDimTrackingRewriter::notifyOperationInserted(Operation *op,
                                                         InsertPoint previous) {
   IRRewriter::Listener::notifyOperationInserted(op, previous);
-  if (isa<tensor::DimOp>(op))
+  auto dimOp = dyn_cast<tensor::DimOp>(op);
+  if (dimOp && isa<OpResult>(dimOp.getSource()))
     dimOps.insert(op);
 }
 
@@ -60,16 +61,21 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter,
     std::optional<int64_t> idx = dimOp.getConstantIndex();
     if (!idx.has_value())
       continue;
+
+    if (isa<BlockArgument>(dimOp.getSource())) {
+      continue;
+    }
+
     // Only DimOps with ranked tensors are supported.
     auto tensorType =
         llvm::dyn_cast<RankedTensorType>(dimOp.getSource().getType());
     if (!tensorType)
       continue;
 
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(dimOp);
     if (!tensorType.isDynamicDim(*idx)) {
       // Rewrite static dimension with constant.
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(dimOp);
       int64_t size = tensorType.getShape()[*idx];
       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(dimOp, size);
       continue;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 3fb9bdb3214c..7974609f8f32 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -266,18 +266,8 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
   // Value is an OpResult.
   Operation *op = value.getDefiningOp();
   OpResult opResult = llvm::cast<OpResult>(value);
-  b.setInsertionPoint(op);
 
-  // Case 3: Value is tied. Reify the dimensions of the tied operand.
-  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
-  if (tiedOp) {
-    Value tiedOperand = tiedOp.getTiedResultOperand(value);
-    if (tiedOperand && tiedOperand.getType() == value.getType())
-      return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
-                                        createTensorDimOps);
-  }
-
-  // Case 4: Query ShapeAwareOpInterface.
+  // Case 3: Query ShapeAwareOpInterface.
   auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op);
   if (shapeAwareOp) {
     ValueRange dims =
@@ -286,6 +276,15 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
     return success();
   }
 
+  // Case 4: Value is tied. Reify the dimensions of the tied operand.
+  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
+  if (tiedOp) {
+    Value tiedOperand = tiedOp.getTiedResultOperand(value);
+    if (tiedOperand && tiedOperand.getType() == value.getType())
+      return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
+                                        /*createTensorDimOps=*/true);
+  }
+
   // Case 5: Query ReifyRankedShapedTypeOpInterface.
   auto reifyShapeOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
   if (reifyShapeOp) {
@@ -308,8 +307,14 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
 }
 
 /// Reify the dynamic dimensions of the given value.
+/// Deprecated. Use `getOptimizedDynamicResultDims` instead.
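+/// (The replacement is expected to reuse dims already available from
+/// shape-aware ops instead of materializing new `tensor.dim` ops; the call
+/// sites migrated below rely on that behavior.)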
 LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value,
                                      SmallVectorImpl<Value> &dynamicDims) {
+
+  OpBuilder::InsertionGuard g(b);
+  if (auto op = value.getDefiningOp()) {
+    b.setInsertionPoint(op);
+  }
   return reifyDynamicResultDimsImpl(b, value, dynamicDims,
                                     /*createTensorDimOps=*/true);
 }
@@ -473,7 +478,7 @@ movePrecedingOpsIntoDispatchRegion(RewriterBase &rewriter,
     rewriter.setInsertionPoint(target);
     SmallVector<Value> &dims = dispatchOpNewResultsDynamicDims.emplace_back();
-    if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+    if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
       return target->emitOpError(
           "failed to reify dynamic dims of result to be yielded from "
           "dispatch region");
@@ -554,9 +559,10 @@ moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
   for (auto [index, result] : llvm::enumerate(target->getResults())) {
     replacedValues.push_back(result);
     yieldedResults.push_back(clonedTarget->getResult(index));
-    rewriter.setInsertionPoint(target);
+    OpBuilder::InsertionGuard g1(rewriter);
+    rewriter.setInsertionPoint(regionOp);
     SmallVector<Value> &dims = dispatchOpNewResultsDynamicDims.emplace_back();
-    if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+    if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
       return target->emitOpError(
           "failed to reify dynamic dims of result to be yielded from "
           "dispatch region");
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index c8896967b7f8..3e8ae66d680b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -692,7 +692,12 @@ isContractionOpSequence(Value yielded) {
 /// Recognize an operation that is a horizontally fused contraction.
 /// TODO: The logic below is quite convoluted. Might be better
 /// off having a dedicated operation for this.
-bool isaHorizontallyFusedContraction(linalg::LinalgOp linalgOp) {
+bool isaHorizontallyFusedContraction(Operation *op) {
+  auto linalgOp = dyn_cast_or_null<linalg::LinalgOp>(op);
+  if (!linalgOp) {
+    return false;
+  }
+
   if (linalgOp->getNumResults() == 1) {
     return false;
   }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index c02699f80c25..48ecf9202d95 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -214,7 +214,7 @@ bool isGatherlikeOp(Operation *op);
 /// Check if a given operation is a horizontally fused contraction operation.
 /// The expectation is that the LHS is common, and all the operands are
 /// different RHS.
-bool isaHorizontallyFusedContraction(linalg::LinalgOp genericOp);
+bool isaHorizontallyFusedContraction(Operation *op);
 
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 
 #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index 562c779ba71d..b2de618764fd 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/FusionUtils.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -24,6 +25,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -39,6 +41,7 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "iree-dispatch-creation-form-dispatch-regions"
 
@@ -330,36 +333,34 @@ static bool hasCompatibleOuterParallelLoops(
       consumerIndexingMap, rootOuterParallelLoops);
 }
 
-/// For all uses of an operation, finds the use that dominates all other uses.
-static std::optional<OpOperand *>
-getFusableUse(Operation *op, DominanceInfo const &dominanceInfo,
-              bool aggressiveFusion) {
+/// For all uses of an operation, return the uses that could be fused.
+/// The returned vector contains the uses in dominance order.
+static SmallVector<OpOperand *>
+getFusableUses(MLIRContext *context, Operation *op,
+               DominanceInfo const &dominanceInfo, bool aggressiveFusion) {
   if (!aggressiveFusion && llvm::count_if(op->getUses(), [](OpOperand &use) {
         return !isa<tensor::DimOp>(use.getOwner());
       }) != 1) {
-    return std::nullopt;
+    return {};
   }
 
-  // Collect non-dim users.
-  SmallVector<Operation *> nonDimUsers;
-  for (Operation *user : op->getUsers()) {
-    if (isa<tensor::DimOp>(user))
-      continue;
-    nonDimUsers.push_back(user);
-  }
-
-  // Find the use in a non-dim user that dominates all other non-dim users.
-  for (auto &use : op->getUses()) {
+  // Collect all fusable user candidates.
+  SetVector<OpOperand *> fusableUses;
+  for (OpOperand &use : op->getUses()) {
     Operation *user = use.getOwner();
-    if (isa<tensor::DimOp>(user))
+    if (isa<tensor::DimOp>(user)) {
       continue;
-    if (llvm::all_of(nonDimUsers, [&](Operation *c) {
-          return dominanceInfo.dominates(user, c);
-        })) {
-      return &use;
     }
+    fusableUses.insert(&use);
   }
-  return std::nullopt;
+
+  SmallVector<OpOperand *> usesVec = fusableUses.takeVector();
+  llvm::sort(usesVec, [&](OpOperand *lhsUse, OpOperand *rhsUse) {
+    return dominanceInfo.properlyDominates(lhsUse->getOwner(),
+                                           rhsUse->getOwner());
+  });
+
+  return usesVec;
 }
 
 /// Returns true if the operands are fusable.
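
To make the new contract concrete, a minimal usage sketch (identifiers such as
`context`, `producer`, `dominanceInfo`, `rootOuterParallelLoops`, `options`,
and `groupNumber` are placeholders, not part of this patch): the returned uses
are sorted so that earlier owners dominate later ones, so `front()` reproduces
the old single-use behavior of `getFusableUse`, and walking the vector visits
consumers in execution order:

    SmallVector<OpOperand *> uses =
        getFusableUses(context, producer, dominanceInfo,
                       /*aggressiveFusion=*/false);
    for (OpOperand *use : uses) {
      Operation *consumer = use->getOwner();
      // Stop at the first non-fusable use, exactly as the consumer-fusion
      // loop below does, keeping the fused set contiguous in dominance order.
      if (!isFusableWithConsumer(*use, rootOuterParallelLoops, options))
        break;
      appendToFusionGroup(consumer, groupNumber);
    }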
@@ -576,38 +577,39 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
   for (Operation *root : roots) {
     SmallVector<Operation *> workList;
     llvm::SmallBitVector rootOuterParallelLoops = getOuterParallelLoops(root);
+    int64_t rootNumber = getRootNumber(root);
     workList.push_back(root);
     while (!workList.empty()) {
       Operation *currRoot = workList.pop_back_val();
-      assert(hasRootOpAttribute(currRoot) &&
-             "unexpected non-root op in worklist");
-
-      // Helper function to make the consumer the root instead of the producer
-      // when they are to be fused.
-      auto updateRootTo = [&context, &currRoot](Operation *newRoot) {
-        int64_t rootNumber = getRootNumber(currRoot);
-        setRootAttribute(context, newRoot, rootNumber);
-        removeRootOpAttribute(currRoot);
-        appendToFusionGroup(currRoot, rootNumber);
-      };
-
-      std::optional<OpOperand *> fusableUse =
-          getFusableUse(currRoot, dominanceInfo,
-                        /*aggressiveFusion=*/options.aggressiveFusion);
-      if (!fusableUse)
-        continue;
-
-      // Analyse the use to see if it is fusable.
-      Operation *consumerOp = fusableUse.value()->getOwner();
-      if (hasRootOpAttribute(consumerOp) ||
-          hasFusionGroupsAttribute(consumerOp)) {
+      SmallVector<OpOperand *> fusableUses =
+          getFusableUses(context, currRoot, dominanceInfo,
+                         /*aggressiveFusion=*/options.aggressiveFusion);
+      if (fusableUses.empty())
         continue;
+
+      // For now, disable fusion with multiple consumers for all operations
+      // other than horizontally fused gemms. This should work in general,
+      // but is causing time-outs on some CI examples.
+      if (!IREE::LinalgExt::isaHorizontallyFusedContraction(root)) {
+        fusableUses = {fusableUses.front()};
       }
 
-      if (isFusableWithConsumer(*(fusableUse.value()), rootOuterParallelLoops,
-                                options)) {
-        updateRootTo(consumerOp);
-        workList.push_back(consumerOp);
+      // Analyse the uses to see if they are fusable.
+      for (OpOperand *fusableUse : fusableUses) {
+        Operation *consumerOp = fusableUse->getOwner();
+        if (hasRootOpAttribute(consumerOp) ||
+            hasFusionGroupsAttribute(consumerOp)) {
+          continue;
+        }
+
+        if (isFusableWithConsumer(*fusableUse, rootOuterParallelLoops,
+                                  options)) {
+          appendToFusionGroup(consumerOp, rootNumber);
+          workList.push_back(consumerOp);
+        } else {
+          break;
+        }
       }
     }
   }
@@ -701,16 +703,16 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum,
       continue;
     }
 
-    std::optional<OpOperand *> fusableUse =
-        getFusableUse(producer, dominanceInfo,
-                      /*aggressiveFusion=*/options.aggressiveFusion);
-    if (!fusableUse || fusableUse.value()->getOwner() != candidate)
-      continue;
-
     if (!isFusableWithProducer(operand, rootOuterParallelLoops, options)) {
       continue;
    }
 
+    SmallVector<OpOperand *> fusableUses =
+        getFusableUses(context, producer, dominanceInfo,
+                       /*aggressiveFusion=*/options.aggressiveFusion);
+    if (fusableUses.empty() || fusableUses.front()->getOwner() != candidate)
+      continue;
+
     appendToFusionGroup(producer, groupNum);
     worklist.push_back(producer);
   }
@@ -814,14 +816,14 @@ decideFusableLinalgOps(Region &region, DominanceInfo const &dominanceInfo,
 static LogicalResult
 createFusionGroups(TensorDimTrackingRewriter &rewriter,
                    mlir::FunctionOpInterface funcOp,
-                   DominanceInfo const &dominanceInfo,
+                   DominanceInfo &dominanceInfo,
                    FormDispatchRegionsPassOptions const &options) {
   // Step 1: Decide fusion groups (heuristic). This marks rootOps with an
   // attribute.
   unsigned numRoots =
       decideFusableLinalgOps(funcOp.getFunctionBody(), dominanceInfo, options);
   SmallVector<Operation *> roots(numRoots, nullptr);
-  DenseMap<unsigned, SmallVector<Operation *>> producers;
+  DenseMap<unsigned, SmallVector<Operation *>> fusedOperations;
 
   LLVM_DEBUG({
     llvm::dbgs() << "\n--- After deciding fusion groups ---\n";
@@ -834,11 +836,12 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
   funcOp.walk([&](Operation *op) {
     if (hasRootOpAttribute(op)) {
       roots[getRootNumber(op)] = op;
+      fusedOperations[getRootNumber(op)].push_back(op);
       removeRootOpAttribute(op);
     }
     if (hasFusionGroupsAttribute(op)) {
       assert(getFusionGroups(op).size() == 1 && "expected exactly one group");
-      producers[getFusionGroups(op).front()].push_back(op);
+      fusedOperations[getFusionGroups(op).front()].push_back(op);
       removeFusionGroupsAttribute(op);
     }
   });
@@ -846,7 +849,32 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
   // Step 2. Create a DispatchRegionOp for every fusion group.
   OpBuilder::InsertionGuard g(rewriter);
   SmallVector<IREE::Flow::DispatchRegionOp> regionOps;
-  for (const auto &it : llvm::enumerate(roots)) {
+  for (auto [rootIndex, root] : llvm::enumerate(roots)) {
+
+    // Sort producers and consumers topologically. All fused ops must be in
+    // the same block as the root.
+    SmallVector<Operation *> &currFusedOperations = fusedOperations[rootIndex];
+    bool sortResult = mlir::computeTopologicalSorting(currFusedOperations);
+    (void)sortResult;
+    assert(sortResult && "could not compute topological sorting");
+
+    int rootPos = 0;
+    for (auto [index, fusedOperation] : llvm::enumerate(currFusedOperations)) {
+      if (fusedOperation == root) {
+        rootPos = index;
+        break;
+      }
+    }
+    SmallVector<Operation *> producers, consumers;
+    if (rootPos > 0) {
+      producers = llvm::to_vector(
+          ArrayRef(currFusedOperations).take_front(rootPos));
+    }
+    if (rootPos < currFusedOperations.size() - 1) {
+      consumers = llvm::to_vector(
+          ArrayRef(currFusedOperations).drop_front(rootPos + 1));
+    }
+
     // Simplify tensor::DimOps.
     {
       SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
@@ -857,20 +885,14 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
 
     // Create fusion group.
     IREE::Flow::DispatchRegionOp regionOp;
-    auto maybeRegionOp =
-        IREE::Flow::wrapOpInDispatchRegion(rewriter, it.value());
-    if (failed(maybeRegionOp))
-      return failure();
+    auto maybeRegionOp = IREE::Flow::wrapOpInDispatchRegion(rewriter, root);
+    if (failed(maybeRegionOp)) {
+      return root->emitOpError("failed to move root into dispatch");
+    }
     regionOp = *maybeRegionOp;
 
-    // Sort producers topologically. All producers must be in the same block
-    // as the root.
-    bool sortResult = mlir::computeTopologicalSorting(producers[it.index()]);
-    (void)sortResult;
-    assert(sortResult && "could not compute topological sorting");
-
     // Move ops into the region.
-    for (Operation *producer : llvm::reverse(producers[it.index()])) {
+    for (Operation *producer : llvm::reverse(producers)) {
       // Simplify tensor::DimOps.
       {
         SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
@@ -881,8 +903,31 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
 
       auto newRegionOp =
          movePrecedingOpsIntoDispatchRegion(rewriter, producer, regionOp);
-      if (failed(newRegionOp))
-        return failure();
+      if (failed(newRegionOp)) {
+        return producer->emitOpError("failed to move producer into region");
+      }
       regionOp = *newRegionOp;
     }
+
+    for (Operation *consumer : consumers) {
+      // Simplify tensor::DimOps.
+      {
+        SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
+        if (failed(IREE::Flow::simplifyDimOps(rewriter, dimOps))) {
+          return failure();
+        }
+      }
+
+      if (failed(moveOperandDefs(rewriter, consumer, regionOp, dominanceInfo,
+                                 regionOp.getOperation()))) {
+        continue;
+      }
+
+      auto newRegionOp = IREE::Flow::moveFollowingOpIntoDispatchRegion(
+          rewriter, consumer, regionOp);
+      if (failed(newRegionOp)) {
+        return consumer->emitOpError("failed to move consumer into region");
+      }
       regionOp = *newRegionOp;
     }
     // Simplify tensor::DimOps.
@@ -916,7 +961,7 @@ struct FormDispatchRegionsPass final
 /// Create dispatch.region Ops based on a fusion heuristic.
 void FormDispatchRegionsPass::runOnOperation() {
   mlir::FunctionOpInterface funcOp = getOperation();
-  DominanceInfo const &dominanceInfo = getAnalysis<DominanceInfo>();
+  DominanceInfo &dominanceInfo = getAnalysis<DominanceInfo>();
   TensorDimTrackingRewriter rewriter(funcOp);
   FormDispatchRegionsPassOptions options{aggressiveFusion,
                                          fusePadWithConsumers,
                                          fusePadWithProducers};
@@ -924,5 +969,18 @@ void FormDispatchRegionsPass::runOnOperation() {
     funcOp->emitOpError("failed to create fusion groups");
     return signalPassFailure();
   }
+
+  // Canonicalize all the dispatch regions to remove unused operands.
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+  IREE::Flow::DispatchRegionOp::getCanonicalizationPatterns(patterns, context);
+  GreedyRewriteConfig config;
+  config.maxIterations = GreedyRewriteConfig::kNoLimit;
+  config.fold = true;
+  if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) {
+    funcOp.emitOpError("failed in cleanup patterns");
+    return signalPassFailure();
+  }
 }
 
 } // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index c428091f6cf8..1a5e5f51c720 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -10,7 +10,10 @@
 
 #include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
 
 #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 namespace mlir::iree_compiler::DispatchCreation {
 
@@ -97,4 +100,44 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
   return true;
 }
 
+LogicalResult moveOperandDefs(RewriterBase &rewriter,
+                              ArrayRef<Operation *> operations,
+                              Operation *insertionPoint,
+                              DominanceInfo &dominanceInfo,
+                              ArrayRef<Operation *> ignoreOperations) {
+  BackwardSliceOptions options;
+  llvm::DenseSet<Operation *> ignoreOperationsSet;
+  ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
+  options.filter = [&](Operation *op) {
+    return !dominanceInfo.properlyDominates(op, insertionPoint) &&
+           !ignoreOperationsSet.contains(op);
+  };
+  options.omitUsesFromAbove = false;
+  // Set inclusive to true because the slice is computed from the operand, and
+  // we want to include the defining op (which is the point here).
+  options.inclusive = true;
+
+  llvm::SetVector<Operation *> slice;
+  for (auto op : operations) {
+    for (auto operand : op->getOperands()) {
+      getBackwardSlice(operand, &slice, options);
+    }
+    auto regions = op->getRegions();
+    if (regions.empty()) {
+      continue;
+    }
+
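+    // An op may also implicitly capture values from above inside its regions
+    // (e.g. a linalg.generic body reading a scalar defined outside the op);
+    // those captured values must be part of the backward slice as well.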
+    llvm::SetVector<Value> capturedVals;
+    mlir::getUsedValuesDefinedAbove(regions, capturedVals);
+    for (auto value : capturedVals) {
+      getBackwardSlice(value, &slice, options);
+    }
+  }
+
+  slice = mlir::topologicalSort(slice);
+  for (auto op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
 } // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
index 1d9c9306f7ae..f99d7ae6ff24 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -10,7 +10,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir::iree_compiler::DispatchCreation {
 
@@ -19,4 +21,11 @@ namespace mlir::iree_compiler::DispatchCreation {
 bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
                                 bool fuseMultiReduction);
 
+/// Move the definitions of the operands of `operations` before
+/// `insertionPoint`.
+LogicalResult moveOperandDefs(RewriterBase &rewriter,
+                              ArrayRef<Operation *> operations,
+                              Operation *insertionPoint,
+                              DominanceInfo &dominanceInfo,
+                              ArrayRef<Operation *> ignoreOperations = {});
+
 } // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
index 16beb5b01599..1b961104b882 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
@@ -1431,9 +1431,9 @@ util.func public @multi_use_producer_fusion(%arg0 : tensor, %arg1 : ten
 // CHECK:        %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[MATMUL]], %[[BIAS]] :
 // CHECK-SAME:       outs(%[[INIT]] :
-// CHECK-DAG:    flow.dispatch.tensor.store %[[GENERIC]], %[[RESULT0]]
-// CHECK-DAG:    flow.dispatch.tensor.store %[[MATMUL]], %[[RESULT1]]
-// CHECK:        util.return %[[DISPATCH]]#1, %[[DISPATCH]]#0
+// CHECK-DAG:    flow.dispatch.tensor.store %[[MATMUL]], %[[RESULT0]]
+// CHECK-DAG:    flow.dispatch.tensor.store %[[GENERIC]], %[[RESULT1]]
+// CHECK:        util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1
 
 // -----
 
@@ -1535,9 +1535,9 @@ util.func public @fuse_conv2d_with_multiple_uses(%input: tensor<1x225x225x16xf32
 // CHECK-SAME:     %[[OUT2:.+]]: !flow.dispatch.tensor<writeonly:tensor<1x112x112x16xf32>>
 // CHECK:          %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
 // CHECK:          %[[GENERIC:.+]] = linalg.generic
-// CHECK-DAG:      flow.dispatch.tensor.store %[[GENERIC]], %[[OUT1]]
-// CHECK-DAG:      flow.dispatch.tensor.store %[[CONV]], %[[OUT2]]
-// CHECK:          util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1
+// CHECK-DAG:      flow.dispatch.tensor.store %[[CONV]], %[[OUT1]]
+// CHECK-DAG:      flow.dispatch.tensor.store %[[GENERIC]], %[[OUT2]]
+// CHECK:          util.return %[[DISPATCH]]#1, %[[DISPATCH]]#0
 
 // -----
 
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index 9934972ca0f5..4a1b808a9127 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -1008,3 +1008,206 @@ util.func @scatter_index_producer_fusion(%arg0 : tensor,
 // CHECK-SAME:       ins(%{{.+}}, %[[GENERIC]] :
 // CHECK:          flow.return %[[SCATTER]]
 // CHECK:          util.return %[[DISPATCH]]
+
+// -----
+
+util.func @move_captured_from_above_ops(%arg0 : tensor<1x1x2x4xi32>,
+    %arg1 : f64, %arg2 : f64) -> tensor<2x3xi8> {
+  %empty = tensor.empty() : tensor<2x3xi32>
+  %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1]
+      inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %empty
+      : tensor<1x1x2x4xi32> -> tensor<2x3xi32>
+  %0 = arith.mulf %arg1, %arg2 : f64
+  %1 = tensor.empty() : tensor<2x3xi8>
+  %2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%unpack : tensor<2x3xi32>) outs(%1 : tensor<2x3xi8>) {
+    ^bb0(%in: i32, %out: i8):
+      %3 = arith.sitofp %in : i32 to f32
+      %4 = arith.truncf %0 : f64 to f32
+      %5 = arith.mulf %3, %4 : f32
+      %48 = arith.fptosi %5 : f32 to i8
+      linalg.yield %48 : i8
+  } -> tensor<2x3xi8>
+  util.return %2 : tensor<2x3xi8>
+}
+// CHECK-LABEL: func public @move_captured_from_above_ops
+//       CHECK:   %[[OP:.+]] = arith.mulf
+//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:     %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[UNPACK]] :
+//       CHECK:       %[[TRUNCF:.+]] = arith.truncf %[[OP]]
+//       CHECK:       linalg.yield
+//       CHECK:     flow.return %[[GENERIC]]
+//       CHECK:   util.return %[[DISPATCH]]
+
+// -----
+
+util.func @horizontal_fusion1(%lhs : tensor<2x4096x640xf16>,
+    %rhs0 : tensor<10x64x640xf16>, %rhs1 : tensor<10x64x640xf16>,
+    %rhs2 : tensor<10x64x640xf16>) ->
+    (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>,
+     tensor<2x10x4096x64xf16>) {
+  %4 = tensor.empty() : tensor<2x10x4096x64xf32>
+  %cst = arith.constant 0.0 : f32
+  %5 = linalg.fill ins(%cst : f32)
+      outs(%4 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
+  %6:3 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel",
+                        "reduction"]}
+      ins(%lhs, %rhs0, %rhs1, %rhs2
+          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>,
+            tensor<10x64x640xf16>, tensor<10x64x640xf16>)
+      outs(%5, %5, %5
+          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>,
+            tensor<2x10x4096x64xf32>) {
+    ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32,
+         %out_3: f32, %out_4: f32):
+      %14 = arith.extf %in : f16 to f32
+      %15 = arith.extf %in_0 : f16 to f32
+      %16 = arith.mulf %14, %15 : f32
+      %17 = arith.addf %out, %16 : f32
+      %18 = arith.extf %in_1 : f16 to f32
+      %19 = arith.mulf %14, %18 : f32
+      %20 = arith.addf %out_3, %19 : f32
+      %21 = arith.extf %in_2 : f16 to f32
+      %22 = arith.mulf %14, %21 : f32
+      %23 = arith.addf %out_4, %22 : f32
+      linalg.yield %17, %20, %23 : f32, f32, f32
+  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>,
+        tensor<2x10x4096x64xf32>)
+  %7 = tensor.empty() : tensor<2x10x4096x64xf16>
+  %8 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%6#0 : tensor<2x10x4096x64xf32>)
+      outs(%7 : tensor<2x10x4096x64xf16>) {
+    ^bb0(%in: f32, %out: f16):
+      %14 = arith.truncf %in : f32 to f16
+      linalg.yield %14 : f16
+  } -> tensor<2x10x4096x64xf16>
+  %9 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%6#1 : tensor<2x10x4096x64xf32>)
+      outs(%7 : tensor<2x10x4096x64xf16>) {
+    ^bb0(%in: f32, %out: f16):
+      %14 = arith.truncf %in : f32 to f16
+      linalg.yield %14 : f16
+  } -> tensor<2x10x4096x64xf16>
+  %10 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%6#2 : tensor<2x10x4096x64xf32>)
+      outs(%7 : tensor<2x10x4096x64xf16>) {
+    ^bb0(%in: f32, %out: f16):
+      %14 = arith.truncf %in : f32 to f16
+      linalg.yield %14 : f16
+  } -> tensor<2x10x4096x64xf16>
+  util.return %8, %9, %10
+      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>,
+        tensor<2x10x4096x64xf16>
+}
+// CHECK-LABEL: func public @horizontal_fusion1
+//       CHECK:   %[[DISPATCH:.+]]:3 = flow.dispatch.region
+//       CHECK:     %[[GENERIC:.+]]:3 = linalg.generic
+//       CHECK:     %[[TRUNC0:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC]]#0 :
+//       CHECK:     %[[TRUNC1:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC]]#1 :
+//       CHECK:     %[[TRUNC2:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC]]#2 :
+//       CHECK:     flow.return %[[TRUNC0]], %[[TRUNC1]], %[[TRUNC2]]
+//       CHECK:   util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2
+
+// -----
+
+util.func @horizontal_fusion2(%lhs : tensor<2x4096x640xi8>,
+    %rhs0 : tensor<2x640x640xi8>, %rhs1 : tensor<2x640x640xi8>)
+    -> tensor<2x4096x640xf16> {
+  %c0_i32 = arith.constant 0 : i32
+  %0 = tensor.empty() : tensor<2x4096x640xf16>
+  %1 = tensor.empty() : tensor<2x4096x640xi32>
+  %2 = linalg.fill ins(%c0_i32 : i32)
+      outs(%1 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
+  %3:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+      ins(%lhs, %rhs0, %rhs1
+          : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>)
+      outs(%2, %2 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
+    ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32):
+      %4 = arith.extsi %in : i8 to i32
+      %5 = arith.extsi %in_0 : i8 to i32
+      %6 = arith.muli %4, %5 : i32
+      %7 = arith.addi %out, %6 : i32
+      %8 = arith.extsi %in_1 : i8 to i32
+      %9 = arith.muli %7, %8 : i32
+      %10 = arith.addi %out_2, %9 : i32
+      linalg.yield %7, %10 : i32, i32
+  } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
+  %4 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%3#1, %3#0 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
+      outs(%0 : tensor<2x4096x640xf16>) {
+    ^bb0(%in: i32, %in_0: i32, %out: f16):
+      %5 = arith.sitofp %in : i32 to f32
+      %6 = arith.truncf %5 : f32 to f16
+      %7 = arith.sitofp %in_0 : i32 to f32
+      %8 = arith.truncf %7 : f32 to f16
+      %9 = arith.addf %6, %8 : f16
+      linalg.yield %9 : f16
+  } -> tensor<2x4096x640xf16>
+  util.return %4 : tensor<2x4096x640xf16>
+}
+// CHECK-LABEL: func public @horizontal_fusion2
+//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:     %[[GENERIC:.+]]:2 = linalg.generic
+//       CHECK:     %[[TRUNC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC]]#1, %[[GENERIC]]#0 :
+//       CHECK:     flow.return %[[TRUNC]]
+//       CHECK:   util.return %[[DISPATCH]]
+
+// -----
+
+util.func @avoid_use_def_violation_on_consumer_fusion(%arg0 : tensor<?xf32>,
+    %arg1 : tensor<f32>) -> tensor<f32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+      iterator_types = ["reduction"]}
+      ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.addf %b0, %b1 : f32
+      linalg.yield %1 : f32
+  } -> tensor<f32>
+  %1 = util.optimization_barrier %0 : tensor<f32>
+  %2 = tensor.empty() : tensor<f32>
+  %3 = linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>,
+                       affine_map<() -> ()>],
+      iterator_types = []}
+      ins(%0, %1 : tensor<f32>, tensor<f32>) outs(%2 : tensor<f32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %4 = arith.mulf %b0, %b1 : f32
+      linalg.yield %4 : f32
+  } -> tensor<f32>
+  util.return %3 : tensor<f32>
+}
+// CHECK-LABEL: func public @avoid_use_def_violation_on_consumer_fusion
+//       CHECK:   %[[DISPATCH1:.+]] = flow.dispatch.region
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//       CHECK:     flow.return %[[GENERIC1]]
+//       CHECK:   %[[BARRIER:.+]] = util.optimization_barrier %[[DISPATCH1]]
+//       CHECK:   %[[DISPATCH2:.+]] = flow.dispatch.region
+//       CHECK:     %[[GENERIC2:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[DISPATCH1]], %[[BARRIER]] :
+//       CHECK:     flow.return %[[GENERIC2]]
+//       CHECK:   util.return %[[DISPATCH2]]
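
For reference, a minimal sketch of how the new helper is driven from consumer
fusion (using `rewriter`, `consumer`, `regionOp`, and `dominanceInfo` as in
`createFusionGroups` above). The dispatch region op itself is passed as an
ignore-op so the backward slice never attempts to hoist the very op the
consumer is being fused into:

    // Hoist transitive defs of `consumer`'s operands above the dispatch
    // region so that cloning `consumer` into the region keeps use-def order.
    if (succeeded(moveOperandDefs(rewriter, /*operations=*/consumer,
                                  /*insertionPoint=*/regionOp, dominanceInfo,
                                  /*ignoreOperations=*/regionOp.getOperation()))) {
      // Safe to clone the consumer into the region and yield its results, as
      // moveFollowingOpIntoDispatchRegion does.
    }

When the operand defs cannot be moved (e.g. the `util.optimization_barrier` in
the last test above, which consumes the region's own result), the pass skips
consumer fusion for that use and forms a separate dispatch instead.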