Skip to content

Commit

Permalink
Merge pull request #9006 from matthias-springer/bufferize_linalg_ext
Browse files Browse the repository at this point in the history
Support bufferization of linalg_ext.fft and reverse
  • Loading branch information
matthias-springer authored Apr 29, 2022
2 parents 812f9c1 + 802f89b commit 2f01d36
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,42 @@ func.func @rank_reduced_slice() {
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[SRC_SUBVIEW]] :
// CHECK-SAME: outs(%[[DST_SUBVIEW]] :

// -----

// CHECK-LABEL: func @reverse_dim(
// CHECK-DAG: %[[alloc:.*]] = memref.alloc() : memref<2x3xf32>
// CHECK-DAG: %[[global:.*]] = memref.get_global
// CHECK: iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%[[global]] : memref<2x3xf32>) outs(%[[alloc]] : memref<2x3xf32>)
// CHECK: %[[load:.*]] = memref.load %[[alloc]]
// CHECK: memref.dealloc %[[alloc]]
// CHECK: return %[[load]]
// Bufferization test: reversing dim 0 of a constant 2x3 tensor must lower to
// an iree_linalg_ext.reverse on memrefs (expectations in the CHECK lines).
func.func @reverse_dim(%pos: index) -> f32 {
  // Constant input tensor; expected to bufferize to a read-only global
  // (memref.get_global) rather than a fresh allocation.
  %input = arith.constant dense<[[1.0, 2.0, 3.0],
                                 [4.0, 5.0, 6.0]]> : tensor<2x3xf32>

  // Destination tensor; expected to bufferize to a new memref.alloc that is
  // deallocated after the final load.
  %init = linalg.init_tensor [2, 3] : tensor<2x3xf32>
  %0 = iree_linalg_ext.reverse
         dimensions(dense<0> : tensor<1xi64>)
         ins(%input : tensor<2x3xf32>)
         outs(%init : tensor<2x3xf32>) : tensor<2x3xf32>

  // Extract a scalar so the reversed result is actually consumed (forces a
  // memref.load of the output buffer).
  %1 = tensor.extract %0[%pos, %pos] : tensor<2x3xf32>
  return %1 : f32
}

// -----

// CHECK-LABEL: func @fft_tensor(
// CHECK: memref.alloc
// CHECK: memref.alloc
// CHECK: iree_linalg_ext.fft ins(%{{.*}} : index) outs(%{{.*}}, %{{.*}} : memref<1024xf32>, memref<1024xf32>)
// Bufferization test: an iree_linalg_ext.fft with two tensor results must
// bufferize to two allocations feeding an fft op on memrefs.
// NOTE: use `func.func` (the namespaced form) for consistency with the rest
// of this file; the bare `func @...` top-level form is the pre-migration
// spelling. The CHECK-LABEL substring `func @fft_tensor(` still matches.
func.func @fft_tensor(%idx: index) -> (tensor<1024xf32>, tensor<1024xf32>) {
  %t0 = linalg.init_tensor [1024] : tensor<1024xf32>
  %t1 = linalg.init_tensor [1024] : tensor<1024xf32>
  // `%idx` is a scalar ("index") operand: it is forwarded unchanged by
  // bufferization; only the two tensor outs become memrefs.
  %0:2 = iree_linalg_ext.fft
           ins(%idx: index)
           outs(%t0, %t1: tensor<1024xf32>, tensor<1024xf32>)
         : tensor<1024xf32>, tensor<1024xf32>
  return %0#0, %0#1 : tensor<1024xf32>, tensor<1024xf32>
}
1 change: 1 addition & 0 deletions iree/compiler/Codegen/Interfaces/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ cc_library(
deps = [
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"@llvm-project//mlir:ArithmeticTransforms",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
Expand Down
123 changes: 123 additions & 0 deletions iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
Expand Down Expand Up @@ -160,6 +162,120 @@ struct DispatchTensorStoreOpInterface
};
} // namespace

/// Generic conversion for any LinalgExtOp on tensors: compute buffers for all
/// tensor operands, then replace the op with a clone that operates entirely
/// on memrefs (the clone has no results; callers of the old results are
/// redirected to the output buffers).
static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter,
                                          IREE::LinalgExt::LinalgExtOp op,
                                          BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasBufferSemantics()) return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasTensorSemantics())
    return op->emitError() << "op does not have tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumInputs());
  for (OpOperand *opOperand : op.getInputOperands()) {
    // Scalar operands are not bufferized; forward them unchanged.
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    // Input operands are never written to.
    FailureOr<Value> inputBuffer = state.getBuffer(
        rewriter, *opOperand,
        BufferizationState::ForceInPlacability::FORCE_INPLACE);
    // Propagate failure instead of dereferencing a failed FailureOr (the
    // output loop below already checks; inputs must too).
    if (failed(inputBuffer)) return failure();
    newInputBuffers.push_back(*inputBuffer);
  }

  // New output operands for the cloned op. Each OpResult aliases exactly one
  // "out" operand; its buffer becomes the replacement value for the result.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    SmallVector<OpOperand *> aliasingOpOperands =
        state.getAnalysisState().getAliasingOpOperand(opResult);
    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, *aliasingOpOperands.front());
    if (failed(resultBuffer)) return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  auto newOp = cast<IREE::LinalgExt::LinalgExtOp>(op.cloneWithoutRegions(
      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
  int64_t numRegions = op->getNumRegions();
  for (int64_t i = 0; i < numRegions; ++i) {
    rewriter.inlineRegionBefore(op->getRegion(i), newOp->getRegion(i),
                                newOp->getRegion(i).begin());
  }

  // Replace the results of the old op with the new output buffers.
  bufferization::replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}

/// Bufferization of ops that implement the LinalgExtOp interface. Replace with
/// a new op that operates entirely on memrefs.
///
/// External model: attached to concrete LinalgExt ops (e.g. ReverseOp, FftOp)
/// at dialect-registration time rather than implemented on the ops themselves.
template <typename OpTy>
struct LinalgExtOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgExtOpInterface<OpTy>,
                                                    OpTy> {
  /// Conservatively report every operand as read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // All operands (including outputs) may be read.
    return true;
  }

  /// An operand is written iff it is tied to a result (i.e., it is an "out").
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Operand is written to if it has an aliasing OpResult.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  /// Map a result back to the single "out" operand it aliases.
  SmallVector<OpOperand *> getAliasingOpOperand(
      Operation *op, OpResult opResult, const AnalysisState &state) const {
    auto linalgExtOp = cast<IREE::LinalgExt::LinalgExtOp>(op);

    // The i-th OpResult may alias with the i-th "out" tensor.
    return {linalgExtOp.getOutputOperand(opResult.getResultNumber())};
  }

  /// Map an "out" operand forward to its tied result; inputs alias nothing.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto linalgExtOp = cast<IREE::LinalgExt::LinalgExtOp>(op);

    // The i-th "out" tensor may alias with the i-th OpResult.
    if (linalgExtOp.isOutputTensor(&opOperand))
      return {linalgExtOp.getTiedOpResult(&opOperand)};
    return {};
  }

  /// Results are exactly their tied "out" buffers (no subview/offset), so the
  /// relation is Equivalent rather than None.
  bufferization::BufferRelation bufferRelation(
      Operation *op, OpResult opResult, const AnalysisState &state) const {
    return bufferization::BufferRelation::Equivalent;
  }

  /// Delegate the actual rewrite to the shared free function above.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    return bufferizeLinalgExtOp(rewriter,
                                cast<IREE::LinalgExt::LinalgExtOp>(op), state);
  }
};

//===----------------------------------------------------------------------===//
// IREE specific post analysis transformations.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -291,6 +407,13 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
IREE::Flow::DispatchTensorStoreOp::attachInterface<
DispatchTensorStoreOpInterface>(*ctx);
});
registry.addExtension(
+[](MLIRContext *ctx, IREE::LinalgExt::IREELinalgExtDialect *dialect) {
IREE::LinalgExt::ReverseOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ReverseOp>>(*ctx);
IREE::LinalgExt::FftOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::FftOp>>(*ctx);
});
}

void addPostAnalysisTransformations(OneShotBufferizationOptions &options) {
Expand Down
1 change: 1 addition & 0 deletions iree/compiler/Codegen/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iree_cc_library(
SRCS
"BufferizationInterfaces.cpp"
DEPS
IREELinalgExtDialect
MLIRArithmeticTransforms
MLIRBufferization
MLIRBufferizationTransforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,22 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
return result;
}]
>,
InterfaceMethod<
  /*desc=*/[{
    Return the result tied to `opOperand`.
  }],
  /*retTy=*/"OpResult",
  /*methodName=*/"getTiedOpResult",
  /*args=*/(ins "OpOperand*":$opOperand),
  /*methodBody=*/"",
  /*defaultImplementation=*/[{
    assert(opOperand->getOwner() == this->getOperation());
    // Outputs follow the inputs in the operand list, so the i-th operand
    // past the inputs is tied to the i-th OpResult.
    int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs();
    // Asserts (not errors): calling this on an input operand is a
    // programmer bug, mirroring the owner assert above.
    assert(resultIndex >= 0 &&
           resultIndex < this->getOperation()->getNumResults() );
    return this->getOperation()->getResult(resultIndex);
  }]
>,
//===------------------------------------------------------------------===//
// Input and Output arguments handling.
//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -423,6 +439,26 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
r.cloneInto(state.addRegion(), bvm);
return b.create(state);
}]
>,
InterfaceMethod<
  /*desc=*/[{
    Clone the current operation with the given location and operands but
    leave the regions empty. This is used to abstract away the optional
    underlying region creation. This does not change the balance between
    input, output_buffer and init_tensors operands.
  }],
  /*retTy=*/"Operation *",
  /*methodName=*/"cloneWithoutRegions",
  /*args=*/(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
    "ValueRange":$operands),
  /*methodBody=*/[{
    OperationState state(
        loc, ConcreteOp::getOperationName(), operands, resultTypes,
        $_op->getAttrs());
    // Add one empty region per region of the original op so the clone is
    // structurally valid; callers move or clone bodies in as needed.
    for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
      state.addRegion();
    return b.create(state);
  }]
>
];

Expand All @@ -431,7 +467,7 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
/// shape of the input operands where possible.
LogicalResult reifyResultShapes(OpBuilder &b,
mlir::ReifiedRankedShapedTypeDims &reifiedReturnShapes);

//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.
// These are useful when cloning and changing operand types.
Expand Down

0 comments on commit 2f01d36

Please sign in to comment.