diff --git a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index 391d10c5980a..174c242f5ad9 100644
--- a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -233,3 +233,42 @@ func.func @rank_reduced_slice() {
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[SRC_SUBVIEW]] :
 // CHECK-SAME: outs(%[[DST_SUBVIEW]] :
+
+// -----
+
+// CHECK-LABEL: func @reverse_dim(
+// CHECK-DAG: %[[alloc:.*]] = memref.alloc() : memref<2x3xf32>
+// CHECK-DAG: %[[global:.*]] = memref.get_global
+// CHECK: iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%[[global]] : memref<2x3xf32>) outs(%[[alloc]] : memref<2x3xf32>)
+// CHECK: %[[load:.*]] = memref.load %[[alloc]]
+// CHECK: memref.dealloc %[[alloc]]
+// CHECK: return %[[load]]
+func.func @reverse_dim(%pos: index) -> f32 {
+  %input = arith.constant dense<[[1.0, 2.0, 3.0],
+                                 [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+
+  %init = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+  %0 = iree_linalg_ext.reverse
+         dimensions(dense<0> : tensor<1xi64>)
+         ins(%input : tensor<2x3xf32>)
+         outs(%init : tensor<2x3xf32>) : tensor<2x3xf32>
+
+  %1 = tensor.extract %0[%pos, %pos] : tensor<2x3xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @fft_tensor(
+// CHECK: memref.alloc
+// CHECK: memref.alloc
+// CHECK: iree_linalg_ext.fft ins(%{{.*}} : index) outs(%{{.*}}, %{{.*}} : memref<1024xf32>, memref<1024xf32>)
+func.func @fft_tensor(%idx: index) -> (tensor<1024xf32>, tensor<1024xf32>) {
+  %t0 = linalg.init_tensor [1024] : tensor<1024xf32>
+  %t1 = linalg.init_tensor [1024] : tensor<1024xf32>
+  %0:2 = iree_linalg_ext.fft
+    ins(%idx: index)
+    outs(%t0, %t1: tensor<1024xf32>, tensor<1024xf32>)
+    : tensor<1024xf32>, tensor<1024xf32>
+  return %0#0, %0#1 : tensor<1024xf32>, tensor<1024xf32>
+}
diff --git a/iree/compiler/Codegen/Interfaces/BUILD b/iree/compiler/Codegen/Interfaces/BUILD
index 965159f135a4..1dd1126ff0b2 100644
--- a/iree/compiler/Codegen/Interfaces/BUILD
+++ b/iree/compiler/Codegen/Interfaces/BUILD
@@ -58,6 +58,7 @@ cc_library(
     deps = [
         "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/HAL/IR",
+        "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "@llvm-project//mlir:ArithmeticTransforms",
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:BufferizationTransforms",
diff --git a/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index 1fdbbe580757..10d3fed93ad5 100644
--- a/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -6,6 +6,8 @@

 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"

+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
@@ -160,6 +162,120 @@ struct DispatchTensorStoreOpInterface
 };
 }  // namespace

+/// Generic conversion for any LinalgExtOp on tensors.
+static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter,
+                                          IREE::LinalgExt::LinalgExtOp op,
+                                          BufferizationState &state) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Nothing to do. This op is already bufferized.
+  if (op.hasBufferSemantics()) return success();
+
+  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
+  // basis.
+  if (!op.hasTensorSemantics())
+    return op->emitError() << "op does not have tensor semantics";
+
+  // New input operands for the cloned op.
+  SmallVector<Value> newInputBuffers;
+  newInputBuffers.reserve(op.getNumInputs());
+  for (OpOperand *opOperand : op.getInputOperands()) {
+    if (op.isScalar(opOperand)) {
+      newInputBuffers.push_back(opOperand->get());
+      continue;
+    }
+    // Input operands are never written to.
+    newInputBuffers.push_back(*state.getBuffer(
+        rewriter, *opOperand,
+        BufferizationState::ForceInPlacability::FORCE_INPLACE));
+  }
+
+  // New output operands for the cloned op.
+  SmallVector<Value> newOutputBuffers;
+  for (OpResult opResult : op->getOpResults()) {
+    SmallVector<OpOperand *> aliasingOpOperands =
+        state.getAnalysisState().getAliasingOpOperand(opResult);
+    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
+    FailureOr<Value> resultBuffer =
+        state.getBuffer(rewriter, *aliasingOpOperands.front());
+    if (failed(resultBuffer)) return failure();
+    newOutputBuffers.push_back(*resultBuffer);
+  }
+
+  // Merge input/output operands.
+  SmallVector<Value> newOperands = newInputBuffers;
+  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
+
+  // Set insertion point now that potential alloc/dealloc are introduced.
+  rewriter.setInsertionPoint(op);
+  // Clone the op, but use the new operands. Move the existing block into the
+  // new op. Since the new op does not have any tensor results, it does not
+  // return anything.
+  auto newOp = cast<IREE::LinalgExt::LinalgExtOp>(op.cloneWithoutRegions(
+      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
+  int64_t numRegions = op->getNumRegions();
+  for (int64_t i = 0; i < numRegions; ++i) {
+    rewriter.inlineRegionBefore(op->getRegion(i), newOp->getRegion(i),
+                                newOp->getRegion(i).begin());
+  }
+
+  // Replace the results of the old op with the new output buffers.
+  bufferization::replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
+
+  return success();
+}
+
+/// Bufferization of ops that implement the LinalgExtOp interface. Replace with
+/// a new op that operates entirely on memrefs.
+template <typename OpTy>
+struct LinalgExtOpInterface
+    : public BufferizableOpInterface::ExternalModel<LinalgExtOpInterface<OpTy>,
+                                                    OpTy> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // All operands (including outputs) may be read.
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // Operand is written to if it has an aliasing OpResult.
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(
+      Operation *op, OpResult opResult, const AnalysisState &state) const {
+    auto linalgExtOp = cast<IREE::LinalgExt::LinalgExtOp>(op);
+
+    // The i-th OpResult may alias with the i-th "out" tensor.
+    return {linalgExtOp.getOutputOperand(opResult.getResultNumber())};
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    auto linalgExtOp = cast<IREE::LinalgExt::LinalgExtOp>(op);
+
+    // The i-th "out" tensor may alias with the i-th OpResult.
+    if (linalgExtOp.isOutputTensor(&opOperand))
+      return {linalgExtOp.getTiedOpResult(&opOperand)};
+    return {};
+  }
+
+  bufferization::BufferRelation bufferRelation(
+      Operation *op, OpResult opResult, const AnalysisState &state) const {
+    return bufferization::BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          BufferizationState &state) const {
+    return bufferizeLinalgExtOp(rewriter,
+                                cast<IREE::LinalgExt::LinalgExtOp>(op), state);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // IREE specific post analysis transformations.
 //===----------------------------------------------------------------------===//
@@ -291,6 +407,13 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
     IREE::Flow::DispatchTensorStoreOp::attachInterface<
         DispatchTensorStoreOpInterface>(*ctx);
   });
+  registry.addExtension(
+      +[](MLIRContext *ctx, IREE::LinalgExt::IREELinalgExtDialect *dialect) {
+        IREE::LinalgExt::ReverseOp::attachInterface<
+            LinalgExtOpInterface<IREE::LinalgExt::ReverseOp>>(*ctx);
+        IREE::LinalgExt::FftOp::attachInterface<
+            LinalgExtOpInterface<IREE::LinalgExt::FftOp>>(*ctx);
+      });
 }

 void addPostAnalysisTransformations(OneShotBufferizationOptions &options) {
diff --git a/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index 8bb02ad9eb87..7e197ec62a57 100644
--- a/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -32,6 +32,7 @@ iree_cc_library(
   SRCS
     "BufferizationInterfaces.cpp"
   DEPS
+    IREELinalgExtDialect
     MLIRArithmeticTransforms
     MLIRBufferization
     MLIRBufferizationTransforms
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index d0bc200dd7b5..6a525ded7c8d 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -259,6 +259,22 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
         return result;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the result tied to `opOperand`.
+      }],
+      /*retTy=*/"OpResult",
+      /*methodName=*/"getTiedOpResult",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs();
+        assert(resultIndex >= 0 &&
+               resultIndex < this->getOperation()->getNumResults());
+        return this->getOperation()->getResult(resultIndex);
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Input and Output arguments handling.
     //===------------------------------------------------------------------===//
@@ -423,6 +439,26 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
         r.cloneInto(state.addRegion(), bvm);
         return b.create(state);
       }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Clone the current operation with the given location and operands but
+        leave the regions empty. This is used to abstract away the optional
+        underlying region creation. This does not change the balance between
+        input, output_buffer and init_tensors operands.
+ }], + /*retTy=*/"Operation *", + /*methodName=*/"cloneWithoutRegions", + (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, + "ValueRange":$operands), + [{ + OperationState state( + loc, ConcreteOp::getOperationName(), operands, resultTypes, + $_op->getAttrs()); + for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt) + state.addRegion(); + return b.create(state); + }] > ]; @@ -431,7 +467,7 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> { /// shape of the input operands where possible. LogicalResult reifyResultShapes(OpBuilder &b, mlir::ReifiedRankedShapedTypeDims &reifiedReturnShapes); - + //========================================================================// // Helper functions to mutate the `operand_segment_sizes` attribute. // These are useful when cloning and changing operand types.