Skip to content

Commit

Permalink
moving bufferization of insertslice and subview from ndarraytolinalg …
Browse files Browse the repository at this point in the history
…to bufferizableinterface; add shardinginterface to reshape (incomplete)
  • Loading branch information
fschlimb committed Dec 11, 2024
1 parent bf4ba36 commit f474abb
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 13 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ecf1694333c05fc7180a2ad8fa80bbd709f35006
0eeb79d76a8284fae3e5e3b4ebbbe98d02249235
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===- BufferizableOpInterfaceImpl.h - --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_NDARRAY_EXTENSIONS_BUFFERIZABLEOPINTERFACEIMPL_H_
#define MLIR_DIALECT_NDARRAY_EXTENSIONS_BUFFERIZABLEOPINTERFACEIMPL_H_

namespace mlir {
class DialectRegistry;
}

namespace imex {
namespace ndarray {

void registerBufferizableOpInterfaceExternalModels(
mlir::DialectRegistry &registry);

} // namespace ndarray
} // namespace imex

#endif // MLIR_DIALECT_NDARRAY_EXTENSIONS_BUFFERIZABLEOPINTERFACEIMPL_H_
11 changes: 9 additions & 2 deletions include/imex/Dialect/NDArray/IR/NDArrayOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
// include "mlir/Interfaces/ShapedOpInterfaces.td"
// include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/IR/OpAsmInterface.td"
Expand Down Expand Up @@ -293,6 +294,10 @@ def SubviewOp : NDArray_OpWithOffsetSizesAndStrides<"subview", [
def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DestinationStyleOpInterface,
Pure,
TypesMatchWith<"expected result type to match dest type",
"destination", "result", "$_self">
]> {
let summary = "Copy values from a array into a slice of another.";
let description = [{
Expand All @@ -312,6 +317,7 @@ def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);

let assemblyFormat = [{
$source `into` $destination ``
Expand Down Expand Up @@ -377,6 +383,7 @@ def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }

mlir::MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
}];

let hasCanonicalizer = 1;
Expand Down Expand Up @@ -408,8 +415,8 @@ def ReshapeOp : NDArray_Op<"reshape", []> {
See Array API.
}];

let arguments = (ins AnyType:$source, Variadic<Index>:$shape, OptionalAttr<I1Attr>:$copy);
let results = (outs AnyType);
let arguments = (ins AnyRankedTensor:$source, Variadic<Index>:$shape, OptionalAttr<I1Attr>:$copy);
let results = (outs AnyRankedTensor);

let assemblyFormat = [{
$source $shape attr-dict `:` qualified(type($source)) `->` qualified(type(results))
Expand Down
18 changes: 11 additions & 7 deletions lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,20 +430,24 @@ struct ConvertNDArrayToLinalgPass
});

::mlir::ConversionTarget target(ctxt);
// We convert all NDArray stuff...
target.addIllegalDialect<::imex::ndarray::NDArrayDialect>();
// ...into Linalg, Affine, Tensor, Arith
target.addLegalDialect<
::mlir::linalg::LinalgDialect, ::mlir::arith::ArithDialect,
::mlir::memref::MemRefDialect, ::mlir::tensor::TensorDialect,
::mlir::bufferization::BufferizationDialect, ::mlir::func::FuncDialect,
::imex::region::RegionDialect>();
target.addLegalOp<mlir::UnrealizedConversionCastOp>();

target.addLegalOp<imex::ndarray::SubviewOp, imex::ndarray::InsertSliceOp,
mlir::UnrealizedConversionCastOp>();

// We convert almost all NDArray stuff...
target.addDynamicallyLegalDialect<::imex::ndarray::NDArrayDialect>(
[&](mlir::Operation *op) {
return mlir::isa<imex::ndarray::SubviewOp,
imex::ndarray::InsertSliceOp>(op);
});
::mlir::RewritePatternSet patterns(&ctxt);
patterns.insert<SubviewLowering, InsertSliceLowering, LinSpaceLowering,
ReshapeLowering, CopyLowering, DeleteLowering,
CastElemTypeLowering>(&ctxt);
patterns.insert<LinSpaceLowering, ReshapeLowering, CopyLowering,
DeleteLowering, CastElemTypeLowering>(&ctxt);

if (::mlir::failed(::mlir::applyPartialConversion(getOperation(), target,
::std::move(patterns)))) {
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/NDArray/Extensions/AllExtensions.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- AllExtensions.cpp - All NDArray Dialect Extensions ------------------===//
//===- AllExtensions.cpp - All NDArray Dialect Extensions -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -7,10 +7,12 @@
//===----------------------------------------------------------------------===//

#include "imex/Dialect/NDArray/Extensions/AllExtensions.h"
#include "imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h"
#include "imex/Dialect/NDArray/Extensions/MeshShardingExtensions.h"

using namespace mlir;

void imex::ndarray::registerAllExtensions(DialectRegistry &registry) {
registerShardingInterfaceExternalModels(registry);
}
registerBufferizableOpInterfaceExternalModels(registry);
}
168 changes: 168 additions & 0 deletions lib/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h"
#include "imex/Dialect/NDArray/IR/NDArrayOps.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"

using namespace mlir;
using namespace bufferization;

namespace imex {
namespace ndarray {
namespace {

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct SubviewOpInterface
: public BufferizableOpInterface::ExternalModel<SubviewOpInterface,
SubviewOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}

AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {{op->getOpResult(0), BufferRelation::Unknown}};
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto subviewOp = cast<SubviewOp>(op);
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = subviewOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = subviewOp.getMixedStrides();
Location loc = subviewOp.getLoc();

// Get source buffer.
FailureOr<Value> srcMemref =
getBuffer(rewriter, subviewOp.getSource(), options);
if (failed(srcMemref))
return failure();

// Take a subview of the source buffer.
auto resultMemrefType =
bufferization::getBufferType(subviewOp.getResult(), options);
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
mixedOffsets, mixedSizes, mixedStrides);

replaceOpWithBufferizedValues(rewriter, op, subView);
return success();
}

FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto subviewOp = cast<SubviewOp>(op);
assert(value == subviewOp.getResult() && "invalid value");
auto srcMemrefType = bufferization::getBufferType(subviewOp.getSource(),
options, invocationStack);
if (failed(srcMemrefType))
return failure();
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = subviewOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = subviewOp.getMixedStrides();
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
subviewOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
mixedOffsets, mixedSizes, mixedStrides));
}
};

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct InsertSliceOpInterface
: public DstBufferizableOpInterfaceExternalModel<
InsertSliceOpInterface, imex::ndarray::InsertSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return opOperand ==
cast<imex::ndarray::InsertSliceOp>(op).getSourceMutable();
}

bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}

bool isNotConflicting(Operation *op, OpOperand *uRead, OpOperand *uWrite,
const AnalysisState &state) const {
return true;
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
// whole tensor on every single iteration and is a symptom of a
// catastrophically bad scheduling decision.
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<imex::ndarray::InsertSliceOp>(op);
SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
Location loc = insertSliceOp.getLoc();

// Get destination buffer.
FailureOr<Value> dstMemref =
getBuffer(rewriter, insertSliceOp.getDestination(), options);
if (failed(dstMemref))
return failure();

// Take a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
auto subviewMemRefType =
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getShape(), dstMemrefType,
mixedOffsets, mixedSizes, mixedStrides));
Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
mixedStrides);

// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
FailureOr<Value> srcMemref =
getBuffer(rewriter, insertSliceOp.getSource(), options);
if (failed(srcMemref))
return failure();
if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
return failure();

replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
return success();
}
};
} // namespace

//===----------------------------------------------------------------------===//
// Interface registration
//===----------------------------------------------------------------------===//

void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(
+[](MLIRContext *ctx, imex::ndarray::NDArrayDialect *dialect) {
InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
SubviewOp::attachInterface<SubviewOpInterface>(*ctx);
});
}

} // namespace ndarray
} // namespace imex
4 changes: 3 additions & 1 deletion lib/Dialect/NDArray/Extensions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
BufferizableOpInterfaceImpl.cpp
MeshShardingExtensions.cpp
)

add_imex_extension_library(IMEXNDArrayMeshShardingExtensions
MeshShardingExtensions.cpp
BufferizableOpInterfaceImpl.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/mlir/Dialect/NDArray/Extensions
Expand All @@ -23,4 +25,4 @@ add_imex_extension_library(IMEXNDArrayAllExtensions

LINK_LIBS PUBLIC
IMEXNDArrayMeshShardingExtensions
)
)
16 changes: 16 additions & 0 deletions lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,21 @@ struct LinspaceShardingInterface
}
};

//===----------------------------------------------------------------------===//
// ReshapeShardingInterface
//===----------------------------------------------------------------------===//

struct ReshapeShardingInterface
: public BaseShardingInterface<ReshapeShardingInterface, ReshapeOp> {

SmallVector<mlir::utils::IteratorType>
getLoopIteratorTypes(::mlir::Operation *op) const {
auto rsop = cast<ReshapeOp>(op);
size_t rank = std::max(rsop.getSource().getType().getRank(),
rsop.getResult().getType().getRank());
return {rank, utils::IteratorType::parallel};
}
};
} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -672,6 +687,7 @@ void registerShardingInterfaceExternalModels(mlir::DialectRegistry &registry) {
SubviewOp::attachInterface<SubviewShardingInterface>(*ctx);
InsertSliceOp::attachInterface<InsertSliceShardingInterface>(*ctx);
LinSpaceOp::attachInterface<LinspaceShardingInterface>(*ctx);
ReshapeOp::attachInterface<ReshapeShardingInterface>(*ctx);
registerTrivial<CopyOp, DeleteOp, CastElemTypeOp>(ctx);
});
}
Expand Down

0 comments on commit f474abb

Please sign in to comment.