diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 734e44e101bb..d59c4f20c21f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -42,6 +42,7 @@ cc_library(
         "ExpandTensorShapes.cpp",
        "ExportBenchmarkFuncs.cpp",
        "FusionOfTensorOps.cpp",
+        "FusionUtils.cpp",
        "InferNumericNarrowing.cpp",
        "InitializeEmptyTensors.cpp",
        "InjectDispatchTracing.cpp",
@@ -59,6 +60,7 @@ cc_library(
        "VerifyInputLegality.cpp",
    ],
    hdrs = [
+        "FusionUtils.h",
        "Passes.h",
        "Passes.h.inc",
    ],
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index d72e468d70f7..958f6111dc73 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -23,6 +23,7 @@ iree_cc_library(
   NAME
     Transforms
   HDRS
+    "FusionUtils.h"
     "Passes.h"
     "Passes.h.inc"
   SRCS
@@ -37,6 +38,7 @@ iree_cc_library(
     "ExpandTensorShapes.cpp"
     "ExportBenchmarkFuncs.cpp"
     "FusionOfTensorOps.cpp"
+    "FusionUtils.cpp"
     "InferNumericNarrowing.cpp"
     "InitializeEmptyTensors.cpp"
     "InjectDispatchTracing.cpp"
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 9ff93ed7af2a..1028b1d54a58 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -11,6 +11,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
 #include "iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.h"
+#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
 #include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "llvm/ADT/STLExtras.h"
@@ -827,57 +828,9 @@ struct CreateDispatchRegionOp : Base<OpType> {
 // Heuristics for fusing dispatchable ops with root ops using tile + fuse.
 //===----------------------------------------------------------------------===//
 
-/// For the fusion of root op -> elementwise operation to be bufferized
-/// in-place without use of extra memory, the result of the root operation
-/// must be able to reuse the buffer for the result of the elementwise
-/// operation. This is possible if input and output are accessed using the same
-/// indexing map.
-// TODO: This restriction can go away if we can vectorize always, but that has
-// a long tail of tasks.
-static bool canInsOperandTieWithOutsOperand(OpOperand *insOperand) {
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(insOperand->getOwner());
-  if (!linalgOp) return false;
-  AffineMap insOperandIndexingMap = linalgOp.getTiedIndexingMap(insOperand);
-  auto canTieWithOutsOperand = [&](OpOperand *outsOperand) {
-    if (linalgOp.getTiedIndexingMap(outsOperand) != insOperandIndexingMap) {
-      return false;
-    }
-    // TODO(#8411): Until ops are vectorized (always), we need
-    // to check that the element type matches for the operands to be tied.
-    // For now just doing this check for convolution ops since we expect
-    // contraction ops to be vectorized.
-    auto producerOp = insOperand->get().getDefiningOp();
-    if (isa<linalg::ConvolutionOpInterface>(producerOp) &&
-        insOperand->get().getType().cast<ShapedType>().getElementType() !=
-            outsOperand->get().getType().cast<ShapedType>().getElementType()) {
-      return false;
-    }
-    return true;
-  };
-  return llvm::any_of(linalgOp.getOutputOperands(), canTieWithOutsOperand);
-}
-
 /// Checks if the producer and consumer LinalgOps can be fused.
 static bool areFusableLinalgOps(OpOperand &use) {
-  auto producerOp = use.get().getDefiningOp<linalg::LinalgOp>();
-  auto consumerOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
-  if (!producerOp || !consumerOp) return false;
-
-  // 1. Producer has a single result.
-  if (producerOp->getNumResults() != 1) return false;
-
-  // 2. Consumer is elementwise parallel.
-  if (consumerOp.getNumLoops() != consumerOp.getNumParallelLoops())
-    return false;
-
-  // 3. In the consumer, the result of the producer is accessed using identity
-  // indexing.
-  AffineMap consumerIndexingMap = consumerOp.getTiedIndexingMap(&use);
-  if (!consumerIndexingMap.isIdentity()) return false;
-
-  // 4. In-place bufferization requirements (for now) require that the use in
-  // the consumer can re-use the buffer for a result.
-  return canInsOperandTieWithOutsOperand(&use);
+  return areLinalgOpsFusableUsingTileAndFuse(use);
 }
 
 /// Returns true if this is a fusable use.
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 70c7c3f5bccb..3bc2b620a819 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -12,116 +12,103 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
 #include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#define DEBUG_TYPE "iree-flow-fusion-of-tensor-ops"
+
 namespace mlir {
 namespace iree_compiler {
 namespace IREE {
 namespace Flow {
 
-namespace {
+/// Check if the producer generic op is fusable with the consumer generic op.
+static bool areFusableOps(MLIRContext *context, Operation *producerOp,
+                          Operation *consumerOp) {
+  // Check for `i1` return types; if all results are `i1`, fuse aggressively
+  // to avoid materializing `i1` buffers.
+  if (llvm::all_of(producerOp->getResultTypes(), [](Type t) {
+        if (t.isInteger(1)) return true;
+        if (auto shapedType = t.dyn_cast<ShapedType>()) {
+          if (shapedType.getElementType().isInteger(1)) return true;
+        }
+        return false;
+      })) {
+    return true;
+  }
+
+  // If the producer has a single user, always fuse.
+  if (producerOp->hasOneUse()) return true;
+
+  // If the generic op is "just" a copy (its body is only the terminator),
+  // then always fuse.
+  Block &body = producerOp->getRegion(0).front();
+  if (std::begin(body)->hasTrait<OpTrait::IsTerminator>()) return true;
+
+  // Do not fuse in all other cases.
+  return false;
+}
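+
+// For illustration (pseudo-IR, names hypothetical): a generic op whose body
+// is only the terminator is a pure copy, e.g.
+//
+//   %copy = linalg.generic {...identity indexing maps...}
+//       ins(%src : tensor<?xf32>) outs(%dst : tensor<?xf32>) {
+//     ^bb0(%in: f32, %out: f32):
+//       linalg.yield %in : f32
+//   } -> tensor<?xf32>
+//
+// Duplicating such an op into each of its consumers is essentially free,
+// which is why the IsTerminator check above fuses it unconditionally.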
 
-using linalg::LinalgOp;
+namespace {
 
 /// Pass to fuse linalg on tensor operations, as well as to fuse
 /// hal.interface* operations with linalg.tensor_reshape operations.
 struct FusionOfTensorOpsPass
     : public FusionOfTensorOpsBase<FusionOfTensorOpsPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, linalg::LinalgDialect,
-                    memref::MemRefDialect>();
+    registry.insert<AffineDialect, arith::ArithmeticDialect,
+                    linalg::LinalgDialect, math::MathDialect,
+                    memref::MemRefDialect>();
   }
 
   void runOnOperation() override {
     RewritePatternSet fusionPatterns(&getContext());
-    RewritePatternSet interfacePatterns(&getContext());
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
 
     // Only fuse operations where all uses of the producer are generic
     // operations. If an operation is used in a named op, it will be computed
     // anyway, so the consumers can just use that value.
-    linalg::ControlFusionFn controlFn = [](const OpResult &producerResult,
-                                           OpOperand &consumerOperand) {
-      Operation *producer = producerResult.getOwner();
-      Operation *consumer = consumerOperand.getOwner();
-
-      // Limit the number of operands. We have a hard limit (32) on the number
-      // of bindings passed down to the HAL; use the same number here --
-      // IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT.
-      constexpr int64_t kIreeMaxOperandCount = 32;
-      DenseSet<Value> operands;
-      operands.insert(producer->operand_begin(), producer->operand_end());
-      operands.insert(consumer->operand_begin(),
-                      std::next(consumer->operand_begin(),
-                                consumerOperand.getOperandNumber()));
-      operands.insert(std::next(consumer->operand_begin(),
-                                consumerOperand.getOperandNumber() + 1),
-                      consumer->operand_end());
-      if (operands.size() >= kIreeMaxOperandCount) return false;
-
-      bool isBroadcast = false;
-      if (auto genericOp = dyn_cast<linalg::GenericOp>(producer)) {
-        // Detect ops that only broadcast their input, as fusing them makes
-        // the new op cheaper.
-        if (genericOp.getNumParallelLoops() == genericOp.getNumLoops() &&
-            isa<linalg::YieldOp>(genericOp.getBody()->front())) {
-          for (OpOperand *opOperand : genericOp.getInputOperands()) {
-            AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
-            if (indexingMap.isProjectedPermutation() &&
-                indexingMap.getNumDims() != indexingMap.getNumResults()) {
-              isBroadcast = true;
-              break;
-            }
-          }
-        }
-      }
-      // Only fuse if the producer has a single linalg generic user. This is
-      // a simplistic heuristic to avoid duplicating ops that may be
-      // expensive.
-      // TODO: Add a cost model to allow ops to be duplicated.
-      bool hasI1ReturnType =
-          llvm::any_of(producer->getResultTypes(), [](Type t) {
-            if (t.isInteger(1)) return true;
-            if (auto shapedType = t.dyn_cast<ShapedType>()) {
-              if (shapedType.getElementType().isInteger(1)) return true;
-            }
-            return false;
-          });
-      if (!isBroadcast && !isa<arith::ConstantOp>(producer) &&
-          !hasI1ReturnType &&
-          !llvm::hasSingleElement(producerResult.getUsers())) {
-        return false;
-      }
-      return llvm::all_of(producerResult.getUsers(), [](Operation *user) {
-        return isa<linalg::GenericOp>(user);
-      });
-    };
-    linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
+    linalg::ControlFusionFn fuseElementwiseOpsControlFn =
+        [&](const OpResult &producerResult, OpOperand &consumerOperand) {
+          Operation *producer = producerResult.getOwner();
+          Operation *consumer = consumerOperand.getOwner();
+
+          // Limit the number of operands. We have a hard limit (32) on the
+          // number of bindings passed down to the HAL; use the same number
+          // here -- IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT.
+          constexpr int64_t kIreeMaxOperandCount = 32;
+          DenseSet<Value> operands;
+          operands.insert(producer->operand_begin(), producer->operand_end());
+          operands.insert(consumer->operand_begin(),
+                          std::next(consumer->operand_begin(),
+                                    consumerOperand.getOperandNumber()));
+          operands.insert(std::next(consumer->operand_begin(),
+                                    consumerOperand.getOperandNumber() + 1),
+                          consumer->operand_end());
+          if (operands.size() >= kIreeMaxOperandCount) return false;
+
+          return areFusableOps(context, producer, consumer);
+        };
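+    // For example (names illustrative): a producer with operands (%a, %b)
+    // feeding a consumer with operands (%p, %c), where %p is the producer's
+    // result, yields the operand set {%a, %b, %c} for the fused op -- well
+    // under the 32-binding limit, so fusion is not rejected on this count.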
+    linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
+                                                 fuseElementwiseOpsControlFn);
 
-    // Simple heuristic to decide if a reshape should be folded into the
-    // linalg op. If the source of the reshape is a linalg op, fold it to
-    // potentially allow the two linalg ops to be fused; otherwise leave it,
-    // to avoid adding dimensions to the consumer linalg op.
-    linalg::ControlFusionFn foldReshapeBetweenLinalgFn =
+    // Always fold reshape by expansion.
+    linalg::ControlFusionFn fuseByExpansionControlFn =
         [](const OpResult &producer, const OpOperand &consumer) {
-          auto collapseOp = producer.getDefiningOp<tensor::CollapseShapeOp>();
-          if (collapseOp) {
-            return collapseOp.src().getDefiningOp<linalg::LinalgOp>() !=
-                   nullptr;
-          }
-          auto expandOp = producer.getDefiningOp<tensor::ExpandShapeOp>();
-          if (expandOp) {
-            return expandOp.src().getDefiningOp<linalg::LinalgOp>() !=
-                   nullptr;
+          // Do not fuse the producer generic op if it has more than one user.
+          if (auto producerGenericOp =
+                  dyn_cast<linalg::GenericOp>(producer.getOwner())) {
+            return producerGenericOp->hasOneUse();
           }
-          return false;
+          // Fuse in all other cases.
+          return true;
         };
-    linalg::populateFoldReshapeOpsByExpansionPatterns(
-        fusionPatterns, foldReshapeBetweenLinalgFn);
+    linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
+                                                      fuseByExpansionControlFn);
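+
+    // For illustration (shapes hypothetical), fusion by expansion rewrites
+    //
+    //   %e = tensor.expand_shape %0 [[0, 1]]
+    //       : tensor<12xf32> into tensor<3x4xf32>
+    //   %1 = linalg.generic ... ins(%e : tensor<3x4xf32>) ...
+    //
+    // into a generic op on the expanded iteration space, so the reshape no
+    // longer sits between producers and consumers.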
 
     // Constant fold Linalg operations.
     auto constantFoldControlFn = [](const OpResult &producer,
@@ -145,36 +132,41 @@ struct FusionOfTensorOpsPass
       return signalPassFailure();
     }
 
-    RewritePatternSet reshapeCanonicalizations(&getContext());
-    linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-        reshapeCanonicalizations);
+    LLVM_DEBUG({
+      llvm::dbgs() << "\n--- After first fixed point ---\n";
+      op->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+
+    // For fusion by collapsing, do so only if the reshape is blocking tile
+    // and fuse.
+    linalg::ControlFusionFn fuseByCollapsingControlFn =
+        [](const OpResult &producer, const OpOperand &consumer) {
+          // Check the case where the producer is an expanding reshape op.
+          auto reshapeOp =
+              dyn_cast<tensor::ExpandShapeOp>(producer.getOwner());
+          if (!reshapeOp) return true;
+
+          auto genericOp = cast<linalg::GenericOp>(consumer.getOwner());
+
+          // If the op can already be fused with a producer by tile + fuse,
+          // do nothing.
+          for (OpOperand *operand : genericOp.getInputOperands()) {
+            if (areLinalgOpsFusableUsingTileAndFuse(*operand)) {
+              return false;
+            }
+          }
+          return true;
+        };
+    RewritePatternSet collapsingReshapePatterns(&getContext());
+    linalg::populateFoldReshapeOpsByCollapsingPatterns(
+        collapsingReshapePatterns, fuseByCollapsingControlFn);
     tensor::CollapseShapeOp::getCanonicalizationPatterns(
-        reshapeCanonicalizations, context);
-    tensor::ExpandShapeOp::getCanonicalizationPatterns(reshapeCanonicalizations,
-                                                       context);
-    linalg::InitTensorOp::getCanonicalizationPatterns(reshapeCanonicalizations,
-                                                      context);
-    linalg::FillOp::getCanonicalizationPatterns(reshapeCanonicalizations,
-                                                context);
-    memref::populateResolveRankedShapeTypeResultDimsPatterns(fusionPatterns);
+        collapsingReshapePatterns, context);
+    tensor::ExpandShapeOp::getCanonicalizationPatterns(
+        collapsingReshapePatterns, context);
+    memref::populateResolveRankedShapeTypeResultDimsPatterns(
+        collapsingReshapePatterns);
     if (failed(applyPatternsAndFoldGreedily(
-            op->getRegions(), std::move(reshapeCanonicalizations)))) {
-      return signalPassFailure();
-    }
-
-    // Push the remaining reshapes down the graphs.
-    RewritePatternSet pushReshapePatterns(&getContext());
-    linalg::populatePushReshapeOpsPatterns(pushReshapePatterns);
-    tensor::CollapseShapeOp::getCanonicalizationPatterns(pushReshapePatterns,
-                                                         context);
-    tensor::ExpandShapeOp::getCanonicalizationPatterns(pushReshapePatterns,
-                                                       context);
-    linalg::InitTensorOp::getCanonicalizationPatterns(pushReshapePatterns,
-                                                      context);
-    linalg::FillOp::getCanonicalizationPatterns(pushReshapePatterns, context);
-    memref::populateResolveRankedShapeTypeResultDimsPatterns(fusionPatterns);
-    if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
-                                            std::move(pushReshapePatterns)))) {
+            op->getRegions(), std::move(collapsingReshapePatterns)))) {
       return signalPassFailure();
     }
   }
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
new file mode 100644
index 000000000000..65ab4083892e
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
@@ -0,0 +1,79 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===--- FusionUtils.cpp - Utilities that are useful for fusion ----------===//
+//
+// Defines utility functions and analyses that are useful across passes
+// to help with fusion before dispatch region formation.
+//
+//===---------------------------------------------------------------------===//
+
+#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
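+// For illustration (pseudo-IR): in
+//
+//   %0 = linalg.matmul ins(%a, %b : ...) outs(%init : tensor<4x8xf32>)
+//   %1 = linalg.generic {identity indexing maps}
+//       ins(%0 : tensor<4x8xf32>) outs(%out : tensor<4x8xf32>) {...}
+//
+// the generic reads %0 with the same indexing map it uses for its result, so
+// after fusion the buffer allocated for %0 can also hold %1, matching the
+// requirement checked below.
+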
+/// For the fusion of root op -> elementwise operation to be bufferized
+/// in-place without use of extra memory, the result of the root operation
+/// must be able to reuse the buffer for the result of the elementwise
+/// operation. This is possible if input and output are accessed using the
+/// same indexing map.
+// TODO: This restriction can go away if we can vectorize always, but that has
+// a long tail of tasks.
+static bool canInsOperandTieWithOutsOperand(OpOperand *insOperand) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(insOperand->getOwner());
+  if (!linalgOp) return false;
+  AffineMap insOperandIndexingMap = linalgOp.getTiedIndexingMap(insOperand);
+  auto canTieWithOutsOperand = [&](OpOperand *outsOperand) {
+    if (linalgOp.getTiedIndexingMap(outsOperand) != insOperandIndexingMap) {
+      return false;
+    }
+    // TODO(#8411): Until ops are vectorized (always), we need
+    // to check that the element type matches for the operands to be tied.
+    // For now just doing this check for convolution ops since we expect
+    // contraction ops to be vectorized.
+    auto producer = insOperand->get().getDefiningOp();
+    if (isa<linalg::ConvolutionOpInterface>(producer) &&
+        insOperand->get().getType().cast<ShapedType>().getElementType() !=
+            outsOperand->get().getType().cast<ShapedType>().getElementType()) {
+      return false;
+    }
+    return true;
+  };
+  return llvm::any_of(linalgOp.getOutputOperands(), canTieWithOutsOperand);
+}
+
+bool areLinalgOpsFusableUsingTileAndFuse(OpOperand &use) {
+  auto producer = use.get().getDefiningOp<linalg::LinalgOp>();
+  auto consumer = dyn_cast<linalg::LinalgOp>(use.getOwner());
+  if (!producer || !consumer) return false;
+
+  // 1. Producer has a single result.
+  if (producer->getNumResults() != 1) return false;
+
+  // 2. Consumer is elementwise parallel.
+  if (consumer.getNumLoops() != consumer.getNumParallelLoops()) return false;
+
+  // 3. Producer has a single use.
+  // TODO(ravishankarm): Could be relaxed if dominance information
+  // is used to fuse with the consumer, and both results become outputs of the
+  // dispatch.
+  if (!producer->hasOneUse()) return false;
+
+  // 4. In the consumer, the result of the producer is accessed using identity
+  // indexing.
+  AffineMap consumerIndexingMap = consumer.getTiedIndexingMap(&use);
+  if (!consumerIndexingMap.isIdentity()) return false;
+
+  // 5. In-place bufferization requirements (for now) require that the use in
+  // the consumer can re-use the buffer for a result.
+  return canInsOperandTieWithOutsOperand(&use);
+}
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h b/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
new file mode 100644
index 000000000000..245ad754cb7a
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
@@ -0,0 +1,28 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===--- FusionUtils.h - Utilities that are useful for fusion ------------===//
+//
+// Declares utility functions and analyses that are useful across passes
+// to help with fusion before dispatch region formation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_FUSIONUTILS_H_
+#define IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_FUSIONUTILS_H_
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+/// Returns true if the `use` is from a producer linalg op that can be fused
+/// with the consumer linalg op using tile + fuse.
+bool areLinalgOpsFusableUsingTileAndFuse(OpOperand &use);
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_FUSIONUTILS_H_
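+
+// Example (as wired up elsewhere in this patch): DispatchLinalgOnTensors.cpp
+// forwards its fusable-use check to the utility declared above:
+//
+//   static bool areFusableLinalgOps(OpOperand &use) {
+//     return areLinalgOpsFusableUsingTileAndFuse(use);
+//   }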