Skip to content

Commit

Permalink
Change elementwise op fusion heuristics. (#8723)
Browse files Browse the repository at this point in the history
Current fusion heuristics seem to have degraded over time.
With the ops moved to different dialects, and changes to op semantics,
the control functions used seem to not really capture the original intent.
This PR revisits the control functions used for elementwise operation fusion.

Fusion of elementwise operations with reshapes by expansion.
This change also pulls in the fusion by collapse to clean up some
additional reshapes and replaces some of the one-off patterns that
were intending to achieve a similar effect.

See #8724 and discourse.llvm.org/t/rfc-next-iteration-of-fusion-of-elementwise-operations/59955/4 for discussion of impact of this change.
  • Loading branch information
MaheshRavishankar authored Apr 18, 2022
1 parent 24baca2 commit 61b2bb2
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 153 deletions.
2 changes: 2 additions & 0 deletions iree/compiler/Dialect/Flow/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ cc_library(
"ExpandTensorShapes.cpp",
"ExportBenchmarkFuncs.cpp",
"FusionOfTensorOps.cpp",
"FusionUtils.cpp",
"InferNumericNarrowing.cpp",
"InitializeEmptyTensors.cpp",
"InjectDispatchTracing.cpp",
Expand All @@ -59,6 +60,7 @@ cc_library(
"VerifyInputLegality.cpp",
],
hdrs = [
"FusionUtils.h",
"Passes.h",
"Passes.h.inc",
],
Expand Down
2 changes: 2 additions & 0 deletions iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ iree_cc_library(
NAME
Transforms
HDRS
"FusionUtils.h"
"Passes.h"
"Passes.h.inc"
SRCS
Expand All @@ -37,6 +38,7 @@ iree_cc_library(
"ExpandTensorShapes.cpp"
"ExportBenchmarkFuncs.cpp"
"FusionOfTensorOps.cpp"
"FusionUtils.cpp"
"InferNumericNarrowing.cpp"
"InitializeEmptyTensors.cpp"
"InjectDispatchTracing.cpp"
Expand Down
51 changes: 2 additions & 49 deletions iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.h"
#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -827,57 +828,9 @@ struct CreateDispatchRegionOp : Base<OpType> {
// Heuristics for fusing dispatchable ops with root ops using tile + fuse.
//===----------------------------------------------------------------------===//

/// For the fusion of root op -> elementwise operation to be bufferized
/// in-place without use of extra memory, the result of the root operation
/// must be able to reuse the buffer for the result of the elementwise
/// operation. This is possible if input and output are accessed using the same
/// indexing map.
///
/// \param insOperand an input operand of the elementwise (consumer) op whose
///        owner is expected to be a `linalg::LinalgOp`.
/// \returns true if `insOperand` can be tied to (share the buffer of) at
///          least one output operand of its owning op.
// TODO: This restriction can go away if we can vectorize always, but that has
// a long tail of tasks.
static bool canInsOperandTieWithOutsOperand(OpOperand *insOperand) {
  // Only Linalg ops expose the tied-indexing-map query used below.
  auto linalgOp = dyn_cast<linalg::LinalgOp>(insOperand->getOwner());
  if (!linalgOp) return false;
  AffineMap insOperandIndexingMap = linalgOp.getTiedIndexingMap(insOperand);
  auto canTieWithOutsOperand = [&](OpOperand *outsOperand) {
    // Buffer reuse requires the input and the output to be accessed through
    // the same indexing map.
    if (linalgOp.getTiedIndexingMap(outsOperand) != insOperandIndexingMap) {
      return false;
    }
    // TODO(#8411): Until ops are vectorized (always), we need
    // to check that the elementtype matches for the operands to be tied.
    // For now just doing this check for convolution ops since we expect
    // contraction ops to be vectorized.
    auto producerOp = insOperand->get().getDefiningOp();
    if (isa<linalg::GenericOp, linalg::ConvolutionOpInterface>(producerOp) &&
        insOperand->get().getType().cast<ShapedType>().getElementType() !=
            outsOperand->get().getType().cast<ShapedType>().getElementType()) {
      return false;
    }
    return true;
  };
  // Tying with any single output operand is enough for in-place
  // bufferization.
  return llvm::any_of(linalgOp.getOutputOperands(), canTieWithOutsOperand);
}

/// Checks if the producer and consumer LinalgOps of `use` can be fused.
///
/// \param use the operand of the consumer op that is defined by the producer.
/// \returns true only when all four conditions below hold.
static bool areFusableLinalgOps(OpOperand &use) {
  // Both sides of the use must be Linalg ops.
  auto producerOp = use.get().getDefiningOp<linalg::LinalgOp>();
  auto consumerOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
  if (!producerOp || !consumerOp) return false;

  // 1. Producer has a single result.
  if (producerOp->getNumResults() != 1) return false;

  // 2. Consumer is elementwise parallel (all loops are parallel loops).
  if (consumerOp.getNumLoops() != consumerOp.getNumParallelLoops())
    return false;

  // 3. In consumer the result of producer is accessed using identity indexing.
  AffineMap consumerIndexingMap = consumerOp.getTiedIndexingMap(&use);
  if (!consumerIndexingMap.isIdentity()) return false;

  // 4. In-place bufferization requirements (for now) require that the use in
  //    the consumer can re-use the buffer for a result.
  return canInsOperandTieWithOutsOperand(&use);
  // NOTE(review): the statement below is unreachable — it follows an
  // unconditional return. This looks like both the pre- and post-change
  // return of a diff were retained; confirm which one is intended (the
  // commit replaces the check with `areLinalgOpsFusableUsingTileAndFuse`)
  // and delete the other.
  return areLinalgOpsFusableUsingTileAndFuse(use);
}

/// Returns true if this is a fusable use.
Expand Down
200 changes: 96 additions & 104 deletions iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,116 +12,103 @@
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-flow-fusion-of-tensor-ops"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {

namespace {
/// Check if the producer generic op is fusable with the consumer generic op.
///
/// Fusion is allowed when (a) every producer result is `i1`-typed (fuse
/// aggressively to avoid materializing `i1` buffers), (b) the producer has a
/// single user, or (c) the producer's body is just a terminator (a pure
/// copy). Everything else is rejected.
static bool areFusableOps(MLIRContext *context, Operation *producerOp,
                          Operation *consumerOp) {
  // True for scalar i1 or any shaped type whose element type is i1.
  auto isI1Typed = [](Type type) {
    if (type.isInteger(1)) return true;
    auto shapedType = type.dyn_cast<ShapedType>();
    return shapedType && shapedType.getElementType().isInteger(1);
  };

  // Check for i1 return types, if so aggressively fuse to avoid `i1` buffers.
  if (llvm::all_of(producerOp->getResultTypes(), isI1Typed)) return true;

  // If producer has a single user, always fuse.
  if (producerOp->hasOneUse()) return true;

  // If the generic op is "just" a copy (body starts with its terminator),
  // then fuse always. All other cases don't fuse.
  Block &producerBody = producerOp->getRegion(0).front();
  return producerBody.front().hasTrait<OpTrait::IsTerminator>();
}

using linalg::LinalgOp;
namespace {

/// Pass to fuse linalg on tensor operations as well as fusion of hal.interface*
/// operations with linalg.tensor_reshape operation.
struct FusionOfTensorOpsPass
: public FusionOfTensorOpsBase<FusionOfTensorOpsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect>();
registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
}

void runOnOperation() override {
RewritePatternSet fusionPatterns(&getContext());
RewritePatternSet interfacePatterns(&getContext());
Operation *op = getOperation();
MLIRContext *context = op->getContext();

// Only fuse operations where all uses of the producer are generic
// operations. If an operation is used in a named op, it will be computed
// anyway, so the consumers can just use that value.
linalg::ControlFusionFn controlFn = [](const OpResult &producerResult,
OpOperand &consumerOperand) {
Operation *producer = producerResult.getOwner();
Operation *consumer = consumerOperand.getOwner();

// Limit the number of operands. We have hard limit (32) of bindings
// passing down to HAL. Set the number to be as same as the limit --
// IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT.
constexpr int64_t kIreeMaxOperandCount = 32;
DenseSet<Value> operands;
operands.insert(producer->operand_begin(), producer->operand_end());
operands.insert(consumer->operand_begin(),
std::next(consumer->operand_begin(),
consumerOperand.getOperandNumber()));
operands.insert(std::next(consumer->operand_begin(),
consumerOperand.getOperandNumber() + 1),
consumer->operand_end());
if (operands.size() >= kIreeMaxOperandCount) return false;

bool isBroadcast = false;
if (auto genericOp = dyn_cast<linalg::GenericOp>(producer)) {
// Detect op that only broadcast input as fusing them makes the new
// op cheaper.
if (genericOp.getNumParallelLoops() == genericOp.getNumLoops() &&
isa<linalg::YieldOp>(genericOp.getBody()->front())) {
for (OpOperand *opOperand : genericOp.getInputOperands()) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
if (indexingMap.isProjectedPermutation() &&
indexingMap.getNumDims() != indexingMap.getNumResults()) {
isBroadcast = true;
break;
}
}
}
}
// Only fuse if it has a single linalg generic user. It is a
// simplistic heuristic to avoid duplicating ops that may be
// expensive.
// TODO: Add a cost model to allow ops to be duplicated.
bool hasI1ReturnType =
llvm::any_of(producer->getResultTypes(), [](Type t) {
if (t.isInteger(1)) return true;
if (auto shapedType = t.dyn_cast<ShapedType>()) {
if (shapedType.getElementType().isInteger(1)) return true;
}
return false;
});
if (!isBroadcast && !isa<arith::ConstantOp>(producer) &&
!hasI1ReturnType &&
!llvm::hasSingleElement(producerResult.getUsers())) {
return false;
}
return llvm::all_of(producerResult.getUsers(), [](Operation *user) {
return isa<linalg::GenericOp>(user);
});
};
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
linalg::ControlFusionFn fuseElementwiseOpsControlFn =
[&](const OpResult &producerResult, OpOperand &consumerOperand) {
Operation *producer = producerResult.getOwner();
Operation *consumer = consumerOperand.getOwner();

// Limit the number of operands. We have hard limit (32) of bindings
// passing down to HAL. Set the number to be as same as the limit --
// IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT.
constexpr int64_t kIreeMaxOperandCount = 32;
DenseSet<Value> operands;
operands.insert(producer->operand_begin(), producer->operand_end());
operands.insert(consumer->operand_begin(),
std::next(consumer->operand_begin(),
consumerOperand.getOperandNumber()));
operands.insert(std::next(consumer->operand_begin(),
consumerOperand.getOperandNumber() + 1),
consumer->operand_end());
if (operands.size() >= kIreeMaxOperandCount) return false;

return areFusableOps(context, producer, consumer);
};
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
fuseElementwiseOpsControlFn);

// Simple heuristic to decide if a reshape should be folded into the linalg op.
// If the source of the reshape is a linalg op fold to potentially allow the
// two linalg ops to be fused. Otherwise leave it to avoid adding dimensions
// to the consumer linalg op.
linalg::ControlFusionFn foldReshapeBetweenLinalgFn =
// Always fold reshape by expansion.
linalg::ControlFusionFn fuseByExpansionControlFn =
[](const OpResult &producer, const OpOperand &consumer) {
auto collapseOp = producer.getDefiningOp<tensor::CollapseShapeOp>();
if (collapseOp) {
return collapseOp.src().getDefiningOp<LinalgOp>() != nullptr;
}
auto expandOp = producer.getDefiningOp<tensor::ExpandShapeOp>();
if (expandOp) {
return expandOp.src().getDefiningOp<LinalgOp>() != nullptr;
// Do not fuse producer generic op if it has more than one user.
if (auto producerGenericOp =
dyn_cast<linalg::GenericOp>(producer.getOwner())) {
return producerGenericOp->hasOneUse();
}
return false;
// Fuse in all other cases.
return true;
};
linalg::populateFoldReshapeOpsByExpansionPatterns(
fusionPatterns, foldReshapeBetweenLinalgFn);
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
fuseByExpansionControlFn);

// Constant fold Linalg operations.
auto constantFoldControlFn = [](const OpResult &producer,
Expand All @@ -145,36 +132,41 @@ struct FusionOfTensorOpsPass
return signalPassFailure();
}

RewritePatternSet reshapeCanonicalizations(&getContext());
linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
reshapeCanonicalizations);
LLVM_DEBUG({
llvm::dbgs() << "\n--- After first fixed point ---\n";
op->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});

// For fusion by collapsing, do so if the reshape is blocking tile and fuse.
linalg::ControlFusionFn fuseByCollapsingControlFn =
[](const OpResult &producer, const OpOperand &consumer) {
// Check the case where the producer is an expanding reshape op.
auto reshapeOp = dyn_cast<tensor::ExpandShapeOp>(producer.getOwner());
if (!reshapeOp) return true;

auto genericOp = cast<linalg::GenericOp>(consumer.getOwner());

// If the op can already be fused with a producer by tile + fuse, do
// nothing.
for (OpOperand *operand : genericOp.getInputOperands()) {
if (areLinalgOpsFusableUsingTileAndFuse(*operand)) {
return false;
}
}
return true;
};
RewritePatternSet collapsingReshapePatterns(&getContext());
linalg::populateFoldReshapeOpsByCollapsingPatterns(
collapsingReshapePatterns, fuseByCollapsingControlFn);
tensor::CollapseShapeOp::getCanonicalizationPatterns(
reshapeCanonicalizations, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(reshapeCanonicalizations,
context);
linalg::InitTensorOp::getCanonicalizationPatterns(reshapeCanonicalizations,
context);
linalg::FillOp::getCanonicalizationPatterns(reshapeCanonicalizations,
context);
memref::populateResolveRankedShapeTypeResultDimsPatterns(fusionPatterns);
collapsingReshapePatterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(
collapsingReshapePatterns, context);
memref::populateResolveRankedShapeTypeResultDimsPatterns(
collapsingReshapePatterns);
if (failed(applyPatternsAndFoldGreedily(
op->getRegions(), std::move(reshapeCanonicalizations)))) {
return signalPassFailure();
}

// Push the remaining reshapes down the graphs.
RewritePatternSet pushReshapePatterns(&getContext());
linalg::populatePushReshapeOpsPatterns(pushReshapePatterns);
tensor::CollapseShapeOp::getCanonicalizationPatterns(pushReshapePatterns,
context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(pushReshapePatterns,
context);
linalg::InitTensorOp::getCanonicalizationPatterns(pushReshapePatterns,
context);
linalg::FillOp::getCanonicalizationPatterns(pushReshapePatterns, context);
memref::populateResolveRankedShapeTypeResultDimsPatterns(fusionPatterns);
if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
std::move(pushReshapePatterns)))) {
op->getRegions(), std::move(collapsingReshapePatterns)))) {
return signalPassFailure();
}
}
Expand Down
Loading

0 comments on commit 61b2bb2

Please sign in to comment.