Skip to content

Commit

Permalink
adding ndarray.from_memref
Browse files Browse the repository at this point in the history
optionally try to find defining tomemref when converting tensor to memref
lowering ndarray.from_memref, avoiding dynamic strides in to_memref
test for finding defining tomemref when converting tensor to memref
simplified
  • Loading branch information
fschlimb authored and silee2 committed Feb 14, 2024
1 parent 0367528 commit 5a4a097
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 17 deletions.
17 changes: 16 additions & 1 deletion include/imex/Dialect/NDArray/IR/NDArrayOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def NDArray_NDArray : NDArray_Type<"NDArray", "ndarray", [ShapedTypeInterface],
];

let extraClassDeclaration = [{
::mlir::MemRefType getMemRefType() const;
::mlir::MemRefType getMemRefType(::mlir::Value = {}) const;
::mlir::RankedTensorType getTensorType() const;
::imex::ndarray::NDArrayType cloneWithDynDims() const;
bool hasUnitSize() const;
Expand Down Expand Up @@ -175,6 +175,21 @@ def DeleteOp : NDArray_Op<"delete"> {
}


def FromMemRefOp : NDArray_Op<"from_memref", [Pure]> {
let summary = "Convert a builtin memref value to a value of type NDArray";
let description = [{
Result type possibly adds NDArray annotations.
}];

let arguments = (ins AnyMemRef:$input);
let results = (outs NDArray_NDArray);

let assemblyFormat = [{
$input attr-dict `:` qualified(type($input)) `->` qualified(type(results))
}];
}


def ToTensorOp : NDArray_Op<"to_tensor", [Pure]> {
let summary = "Convert a NDArray value to a value of MLIR's builtin tensor type";
let description = [{
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/DistToStandard/DistToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1693,7 +1693,7 @@ struct ConvertDistToStandardPass
if (rank) {
lOffs = createValuesFromMemRef(builder, loc, inputs[nParts]);
for (auto i = 0u; i < sOffs.size(); ++i) {
if (sOffs[i] != ::mlir::ShapedType::kDynamic) {
if (!::mlir::ShapedType::isDynamic(sOffs[i])) {
lOffs[i] = createIndex(loc, builder, sOffs[i]);
}
}
Expand Down
47 changes: 33 additions & 14 deletions lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,23 @@ struct CastLowering
}
};

/// Convert FromMemRefOp to bufferize.to_tensor
struct FromMemRefLowering
: public ::mlir::OpConversionPattern<::imex::ndarray::FromMemRefOp> {
using OpConversionPattern::OpConversionPattern;

::mlir::LogicalResult
matchAndRewrite(::imex::ndarray::FromMemRefOp op,
::imex::ndarray::FromMemRefOp::Adaptor adaptor,
::mlir::ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<::mlir::bufferization::ToTensorOp>(
op, adaptor.getInput());

return ::mlir::success();
}
};

/// Lower to the input operand of the defining op.
struct ToTensorLowering
: public ::mlir::OpConversionPattern<::imex::ndarray::ToTensorOp> {
Expand Down Expand Up @@ -163,7 +180,7 @@ struct SubviewLowering
.dyn_cast_or_null<::imex::ndarray::NDArrayType>();
if (!srcArType)
return mlir::failure();
auto srcMRType = srcArType.getMemRefType();
auto srcMRType = srcArType.getMemRefType(srcTnsr);
auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType);

auto *converter = getTypeConverter();
Expand Down Expand Up @@ -298,8 +315,8 @@ struct InsertSliceLowering
auto dstArType =
op.getDestination().getType().cast<imex::ndarray::NDArrayType>();

auto srcMRTyp = srcArType.getMemRefType();
auto dstMRTyp = dstArType.getMemRefType();
auto srcMRTyp = srcArType.getMemRefType(src);
auto dstMRTyp = dstArType.getMemRefType(dst);
mlir::Value srcMR = createToMemRef(loc, rewriter, src, srcMRTyp);
auto dstMR = createToMemRef(loc, rewriter, dst, dstMRTyp);

Expand Down Expand Up @@ -509,8 +526,8 @@ struct CopyLowering
loc, mrTyp, dynDims, rewriter.getI64IntegerAttr(8));
// and copy if non-0
if (!retArTyp.hasZeroSize()) {
auto srcMR = rewriter.create<::mlir::bufferization::ToMemrefOp>(
loc, srcArTyp.getMemRefType(), src);
auto srcMR =
createToMemRef(loc, rewriter, src, srcArTyp.getMemRefType(src));
// create a region with given env, add copy op within it
auto env = rewriter.getStringAttr("protect_copy_op");
rewriter.create<::imex::region::EnvironmentRegionOp>(
Expand Down Expand Up @@ -545,8 +562,9 @@ struct DeleteLowering
return ::mlir::failure();
}

auto inpMR = rewriter.create<::mlir::bufferization::ToMemrefOp>(
op.getLoc(), inpArType.getMemRefType(), adaptor.getInput());
auto inp = adaptor.getInput();
auto inpMR = createToMemRef(op.getLoc(), rewriter, inp,
inpArType.getMemRefType(inp));
rewriter.replaceOpWithNewOp<::mlir::memref::DeallocOp>(op, inpMR);

return ::mlir::success();
Expand Down Expand Up @@ -650,7 +668,8 @@ struct ReshapeLowering
auto cpyMR =
createToMemRef(loc, rewriter, cpy,
getMemRefType(op.getContext(), rank, elTyp, false));
auto srcMR = createToMemRef(loc, rewriter, src, srcArTyp.getMemRefType());
auto srcMR =
createToMemRef(loc, rewriter, src, srcArTyp.getMemRefType(src));
rewriter.create<::mlir::memref::CopyOp>(loc, srcMR, cpyMR);
src = cpy;
}
Expand Down Expand Up @@ -1304,12 +1323,12 @@ struct ConvertNDArrayToLinalgPass
[&](mlir::Operation *op) { return typeConverter.isLegal(op); });

::mlir::RewritePatternSet patterns(&ctxt);
patterns.insert<ToTensorLowering, SubviewLowering, ExtractSliceLowering,
InsertSliceLowering, ImmutableInsertSliceLowering,
LinSpaceLowering, LoadOpLowering, CreateLowering,
EWBinOpLowering, DimOpLowering, EWUnyOpLowering,
ReductionOpLowering, ReshapeLowering, CastLowering,
CopyLowering, DeleteLowering, CastElemTypeLowering>(
patterns.insert<
ToTensorLowering, SubviewLowering, ExtractSliceLowering,
InsertSliceLowering, ImmutableInsertSliceLowering, LinSpaceLowering,
LoadOpLowering, CreateLowering, EWBinOpLowering, DimOpLowering,
EWUnyOpLowering, ReductionOpLowering, ReshapeLowering, CastLowering,
CopyLowering, DeleteLowering, CastElemTypeLowering, FromMemRefLowering>(
typeConverter, &ctxt);
::imex::populateRegionTypeConversionPatterns(patterns, typeConverter);

Expand Down
12 changes: 11 additions & 1 deletion lib/Dialect/NDArray/IR/NDArrayOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
#include <imex/Utils/PassUtils.h>
#include <llvm/ADT/TypeSwitch.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Utils/StaticValueUtils.h>
#include <mlir/IR/DialectImplementation.h>

Expand Down Expand Up @@ -72,7 +73,16 @@ NDArrayType NDArrayType::get(::llvm::ArrayRef<int64_t> shape,
return get(ctx, shape, elementType, environments, layout);
}

::mlir::MemRefType NDArrayType::getMemRefType() const {
::mlir::MemRefType NDArrayType::getMemRefType(::mlir::Value val) const {
if (val) {
auto defOp = val.getDefiningOp<::mlir::bufferization::ToTensorOp>();
if (defOp) {
return defOp.getMemref()
.getType()
.cloneWith(getShape(), getElementType())
.cast<::mlir::MemRefType>();
}
}
return ::imex::getMemRefType(getContext(), getShape(), getElementType());
}

Expand Down
29 changes: 29 additions & 0 deletions test/Conversion/NDArrayToLinalg/NDArrayToLinalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,25 @@ func.func @test_subview(%arg0: !ndarray.ndarray<?xi64>) -> !ndarray.ndarray<?xi6
// CHECK-NEXT: [[V2:%.*]] = bufferization.to_memref
// CHECK-NEXT: return [[V2]] : memref<?xi64, strided<[?], offset: ?>>

// -----
func.func @test_static_mr_2_tnsr_2_static_mr(%arg0: memref<55xi32, strided<[1], offset: 2>>) -> !ndarray.ndarray<3xi32> {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%nda = ndarray.from_memref %arg0 : memref<55xi32, strided<[1], offset: 2>> -> !ndarray.ndarray<55xi32>
%0 = ndarray.subview %nda[%c0][%c3][%c3] : !ndarray.ndarray<55xi32> to !ndarray.ndarray<3xi32>
return %0 : !ndarray.ndarray<3xi32>
}
// CHECK-LABEL: @test_static_mr_2_tnsr_2_static_mr
// CHECK-NEXT: [[C0:%.*]] = arith.constant
// CHECK-NEXT: [[C1:%.*]] = arith.constant
// CHECK-NEXT: [[V:%.*]] = bufferization.to_tensor %arg0 : memref<55xi32, strided<[1], offset: 2>>
// CHECK-NEXT: [[V0:%.*]] = bufferization.to_memref [[V]] : memref<55xi32, strided<[1], offset: 2>>
// CHECK-NEXT: [[S0:%.*]] = memref.subview [[V0]][[[C0]]] [[[C1]]] [[[C1]]] : memref<55xi32, strided<[1], offset: 2>> to memref<?xi32, strided<[?], offset: ?>>
// CHECK-NEXT: [[V1:%.*]] = bufferization.to_tensor [[S0]] writable : memref<?xi32, strided<[?], offset: ?>>
// CHECK-NEXT: [[V2:%.*]] = bufferization.to_memref
// CHECK-NEXT: [[V3:%.*]] = memref.cast [[V2]]
// CHECK-NEXT: return [[V3]] : memref<3xi32, strided<[?], offset: ?>>

// -----
func.func @test_extract_slice(%arg0: !ndarray.ndarray<?xi64>) -> !ndarray.ndarray<?xi64> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -439,3 +458,13 @@ func.func @test_cast_elemtype_copy(%arg0: !ndarray.ndarray<16xi32>) -> !ndarray.
// CHECK: region.env_region "protect_copy_op"
// CHECK-NEXT: memref.copy
// CHECK-NEXT: }

// -----
func.func @test_from_memref(%arg0: memref<5xi32, strided<[?], offset: ?>>) -> !ndarray.ndarray<5xi32> {
%0 = ndarray.from_memref %arg0 : memref<5xi32, strided<[?], offset: ?>> -> !ndarray.ndarray<5xi32>
return %0 : !ndarray.ndarray<5xi32>
}
// CHECK-LABEL: @test_from_memref
// CHECK: [[V0:%.*]] = bufferization.to_tensor
// CHECK: [[V1:%.*]] = bufferization.to_memref [[V0]]
// CHECK-NEXT: return [[V1]] : memref<5xi32, strided<[?], offset: ?>>
9 changes: 9 additions & 0 deletions test/Dialect/NDArray/IR/NDArrayOps.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,12 @@ func.func @test_castelem(%arg0: !ndarray.ndarray<5xi64>) -> !ndarray.ndarray<5xi
// CHECK-LABEL: func.func @test_castelem
// CHECK: [[V0:%.*]] = ndarray.cast_elemtype
// CHECK-NEXT: return [[V0]] : !ndarray.ndarray<5xi32>

// -----
func.func @test_from_memref(%arg0: memref<?xi32, strided<[?], offset: ?>>) -> !ndarray.ndarray<?xi32> {
%0 = ndarray.from_memref %arg0 : memref<?xi32, strided<[?], offset: ?>> -> !ndarray.ndarray<?xi32>
return %0 : !ndarray.ndarray<?xi32>
}
// CHECK-LABEL: func.func @test_from_memref
// CHECK: [[V0:%.*]] = ndarray.from_memref
// CHECK-NEXT: return [[V0]] : !ndarray.ndarray<?xi32>

0 comments on commit 5a4a097

Please sign in to comment.