From f474abbd57c43528fa6829bb617c0f1030395416 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Wed, 11 Dec 2024 15:26:00 +0100 Subject: [PATCH] moving bufferization of insertslice and subview from ndarraytolinalg to bufferizableinterface; add shardinginterface to reshape (incomplete) --- build_tools/llvm_version.txt | 2 +- .../Extensions/BufferizableOpInterfaceImpl.h | 25 +++ include/imex/Dialect/NDArray/IR/NDArrayOps.td | 11 +- .../NDArrayToLinalg/NDArrayToLinalg.cpp | 18 +- .../NDArray/Extensions/AllExtensions.cpp | 6 +- .../BufferizableOpInterfaceImpl.cpp | 168 ++++++++++++++++++ lib/Dialect/NDArray/Extensions/CMakeLists.txt | 4 +- .../Extensions/MeshShardingExtensions.cpp | 16 ++ 8 files changed, 237 insertions(+), 13 deletions(-) create mode 100644 include/imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h create mode 100644 lib/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.cpp diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index e002713fb..357f190a8 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -ecf1694333c05fc7180a2ad8fa80bbd709f35006 +0eeb79d76a8284fae3e5e3b4ebbbe98d02249235 diff --git a/include/imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h b/include/imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h new file mode 100644 index 000000000..a5d517a2f --- /dev/null +++ b/include/imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h @@ -0,0 +1,25 @@ +//===- BufferizableOpInterfaceImpl.h - --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_NDARRAY_EXTENSIONS_BUFFERIZABLEOPINTERFACEIMPL_H_ +#define MLIR_DIALECT_NDARRAY_EXTENSIONS_BUFFERIZABLEOPINTERFACEIMPL_H_ + +namespace mlir { +class DialectRegistry; +} + +namespace imex { +namespace ndarray { + +void registerBufferizableOpInterfaceExternalModels( + mlir::DialectRegistry ®istry); + +} // namespace ndarray +} // namespace imex + +#endif // MLIR_DIALECT_NDARRAY_EXTENSIONS_BUFFERIZABLEOPINTERFACEIMPL_H_ diff --git a/include/imex/Dialect/NDArray/IR/NDArrayOps.td b/include/imex/Dialect/NDArray/IR/NDArrayOps.td index fec30e3e7..856f1e9a5 100644 --- a/include/imex/Dialect/NDArray/IR/NDArrayOps.td +++ b/include/imex/Dialect/NDArray/IR/NDArrayOps.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" // include "mlir/Interfaces/ShapedOpInterfaces.td" // include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/IR/OpAsmInterface.td" @@ -293,6 +294,10 @@ def SubviewOp : NDArray_OpWithOffsetSizesAndStrides<"subview", [ def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [ AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, + DestinationStyleOpInterface, + Pure, + TypesMatchWith<"expected result type to match dest type", + "destination", "result", "$_self"> ]> { let summary = "Copy values from a array into a slice of another."; let description = [{ @@ -312,6 +317,7 @@ def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [ DenseI64ArrayAttr:$static_sizes, DenseI64ArrayAttr:$static_strides ); + let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $source `into` $destination `` @@ -377,6 +383,7 @@ def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [ /// and `strides` operands. static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; } + mlir::MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); } }]; let hasCanonicalizer = 1; @@ -408,8 +415,8 @@ def ReshapeOp : NDArray_Op<"reshape", []> { See Array API. }]; - let arguments = (ins AnyType:$source, Variadic:$shape, OptionalAttr:$copy); - let results = (outs AnyType); + let arguments = (ins AnyRankedTensor:$source, Variadic:$shape, OptionalAttr:$copy); + let results = (outs AnyRankedTensor); let assemblyFormat = [{ $source $shape attr-dict `:` qualified(type($source)) `->` qualified(type(results)) diff --git a/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp b/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp index 9dc00ab38..29016d7c3 100644 --- a/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp +++ b/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp @@ -430,20 +430,24 @@ struct ConvertNDArrayToLinalgPass }); ::mlir::ConversionTarget target(ctxt); - // We convert all NDArray stuff... - target.addIllegalDialect<::imex::ndarray::NDArrayDialect>(); // ...into Linalg, Affine, Tensor, Arith target.addLegalDialect< ::mlir::linalg::LinalgDialect, ::mlir::arith::ArithDialect, ::mlir::memref::MemRefDialect, ::mlir::tensor::TensorDialect, ::mlir::bufferization::BufferizationDialect, ::mlir::func::FuncDialect, ::imex::region::RegionDialect>(); - target.addLegalOp(); - + target.addLegalOp(); + + // We convert almost all NDArray stuff... + target.addDynamicallyLegalDialect<::imex::ndarray::NDArrayDialect>( + [&](mlir::Operation *op) { + return mlir::isa(op); + }); ::mlir::RewritePatternSet patterns(&ctxt); - patterns.insert(&ctxt); + patterns.insert(&ctxt); if (::mlir::failed(::mlir::applyPartialConversion(getOperation(), target, ::std::move(patterns)))) { diff --git a/lib/Dialect/NDArray/Extensions/AllExtensions.cpp b/lib/Dialect/NDArray/Extensions/AllExtensions.cpp index 9aa1f6eee..c52b1964f 100644 --- a/lib/Dialect/NDArray/Extensions/AllExtensions.cpp +++ b/lib/Dialect/NDArray/Extensions/AllExtensions.cpp @@ -1,4 +1,4 @@ -//===- AllExtensions.cpp - All NDArray Dialect Extensions ------------------===// +//===- AllExtensions.cpp - All NDArray Dialect Extensions -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,10 +7,12 @@ //===----------------------------------------------------------------------===// #include "imex/Dialect/NDArray/Extensions/AllExtensions.h" +#include "imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h" #include "imex/Dialect/NDArray/Extensions/MeshShardingExtensions.h" using namespace mlir; void imex::ndarray::registerAllExtensions(DialectRegistry ®istry) { registerShardingInterfaceExternalModels(registry); -} \ No newline at end of file + registerBufferizableOpInterfaceExternalModels(registry); +} diff --git a/lib/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.cpp b/lib/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 000000000..4ad7d8a8c --- /dev/null +++ b/lib/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,168 @@ +//===- ShardingInterfaceImpl.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h" +#include "imex/Dialect/NDArray/IR/NDArrayOps.h" + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" + +using namespace mlir; +using namespace bufferization; + +namespace imex { +namespace ndarray { +namespace { + +/// Bufferization of tensor.extract_slice. Replace with memref.subview. +struct SubviewOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {{op->getOpResult(0), BufferRelation::Unknown}}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + auto subviewOp = cast(op); + SmallVector mixedOffsets = subviewOp.getMixedOffsets(); + SmallVector mixedSizes = subviewOp.getMixedSizes(); + SmallVector mixedStrides = subviewOp.getMixedStrides(); + Location loc = subviewOp.getLoc(); + + // Get source buffer. + FailureOr srcMemref = + getBuffer(rewriter, subviewOp.getSource(), options); + if (failed(srcMemref)) + return failure(); + + // Take a subview of the source buffer. + auto resultMemrefType = + bufferization::getBufferType(subviewOp.getResult(), options); + if (failed(resultMemrefType)) + return failure(); + Value subView = rewriter.create( + loc, llvm::cast(*resultMemrefType), *srcMemref, + mixedOffsets, mixedSizes, mixedStrides); + + replaceOpWithBufferizedValues(rewriter, op, subView); + return success(); + } + + FailureOr + getBufferType(Operation *op, Value value, const BufferizationOptions &options, + SmallVector &invocationStack) const { + auto subviewOp = cast(op); + assert(value == subviewOp.getResult() && "invalid value"); + auto srcMemrefType = bufferization::getBufferType(subviewOp.getSource(), + options, invocationStack); + if (failed(srcMemrefType)) + return failure(); + SmallVector mixedOffsets = subviewOp.getMixedOffsets(); + SmallVector mixedSizes = subviewOp.getMixedSizes(); + SmallVector mixedStrides = subviewOp.getMixedStrides(); + return cast(memref::SubViewOp::inferRankReducedResultType( + subviewOp.getType().getShape(), llvm::cast(*srcMemrefType), + mixedOffsets, mixedSizes, mixedStrides)); + } +}; + +/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under +/// certain circumstances, this op can also be a no-op. +/// +/// Note: DstBufferizableOpInterfaceExternalModel provides many default method +/// implementations for DestinationStyle ops. +struct InsertSliceOpInterface + : public DstBufferizableOpInterfaceExternalModel< + InsertSliceOpInterface, imex::ndarray::InsertSliceOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return opOperand == + cast(op).getSourceMutable(); + } + + bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + bool isNotConflicting(Operation *op, OpOperand *uRead, OpOperand *uWrite, + const AnalysisState &state) const { + return true; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + // insert_slice ops arise from tiling and bufferizing them out-of-place is + // generally a deal breaker. When used with loops, this ends up cloning the + // whole tensor on every single iteration and is a symptom of a + // catastrophically bad scheduling decision. + // TODO: be very loud about it or even consider failing the pass. + auto insertSliceOp = cast(op); + SmallVector mixedOffsets = insertSliceOp.getMixedOffsets(); + SmallVector mixedSizes = insertSliceOp.getMixedSizes(); + SmallVector mixedStrides = insertSliceOp.getMixedStrides(); + Location loc = insertSliceOp.getLoc(); + + // Get destination buffer. + FailureOr dstMemref = + getBuffer(rewriter, insertSliceOp.getDestination(), options); + if (failed(dstMemref)) + return failure(); + + // Take a subview of the destination buffer. + auto dstMemrefType = cast(dstMemref->getType()); + auto subviewMemRefType = + cast(memref::SubViewOp::inferRankReducedResultType( + insertSliceOp.getSourceType().getShape(), dstMemrefType, + mixedOffsets, mixedSizes, mixedStrides)); + Value subView = rewriter.create( + loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, + mixedStrides); + + // Copy tensor. If this tensor.insert_slice has a matching + // tensor.extract_slice, the copy operation will eventually fold away. + FailureOr srcMemref = + getBuffer(rewriter, insertSliceOp.getSource(), options); + if (failed(srcMemref)) + return failure(); + if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView))) + return failure(); + + replaceOpWithBufferizedValues(rewriter, op, *dstMemref); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Interface registration +//===----------------------------------------------------------------------===// + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { + registry.addExtension( + +[](MLIRContext *ctx, imex::ndarray::NDArrayDialect *dialect) { + InsertSliceOp::attachInterface(*ctx); + SubviewOp::attachInterface(*ctx); + }); +} + +} // namespace ndarray +} // namespace imex diff --git a/lib/Dialect/NDArray/Extensions/CMakeLists.txt b/lib/Dialect/NDArray/Extensions/CMakeLists.txt index 3e6299405..75af15809 100644 --- a/lib/Dialect/NDArray/Extensions/CMakeLists.txt +++ b/lib/Dialect/NDArray/Extensions/CMakeLists.txt @@ -1,10 +1,12 @@ set(LLVM_OPTIONAL_SOURCES AllExtensions.cpp + BufferizableOpInterfaceImpl.cpp MeshShardingExtensions.cpp ) add_imex_extension_library(IMEXNDArrayMeshShardingExtensions MeshShardingExtensions.cpp + BufferizableOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/NDArray/Extensions @@ -23,4 +25,4 @@ add_imex_extension_library(IMEXNDArrayAllExtensions LINK_LIBS PUBLIC IMEXNDArrayMeshShardingExtensions - ) \ No newline at end of file + ) diff --git a/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp b/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp index 401318605..41dd3c1f8 100644 --- a/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp +++ b/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp @@ -654,6 +654,21 @@ struct LinspaceShardingInterface } }; +//===----------------------------------------------------------------------===// +// ReshapeShardingInterface +//===----------------------------------------------------------------------===// + +struct ReshapeShardingInterface + : public BaseShardingInterface { + + SmallVector + getLoopIteratorTypes(::mlir::Operation *op) const { + auto rsop = cast(op); + size_t rank = std::max(rsop.getSource().getType().getRank(), + rsop.getResult().getType().getRank()); + return {rank, utils::IteratorType::parallel}; + } +}; } // namespace //===----------------------------------------------------------------------===// @@ -672,6 +687,7 @@ void registerShardingInterfaceExternalModels(mlir::DialectRegistry ®istry) { SubviewOp::attachInterface(*ctx); InsertSliceOp::attachInterface(*ctx); LinSpaceOp::attachInterface(*ctx); + ReshapeOp::attachInterface(*ctx); registerTrivial(ctx); }); }