Skip to content

Commit 5a4a097

Browse files
fschlimbsilee2
authored andcommitted
adding ndarray.from_memref
optionally try to find defining tomemref when converting tensor to memref lowering ndarray.from_memref, avoiding dynamic strides in to_memref test for finding defining tomemref when converting tensor to memref simplified
1 parent 0367528 commit 5a4a097

File tree

6 files changed

+99
-17
lines changed

6 files changed

+99
-17
lines changed

include/imex/Dialect/NDArray/IR/NDArrayOps.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def NDArray_NDArray : NDArray_Type<"NDArray", "ndarray", [ShapedTypeInterface],
130130
];
131131

132132
let extraClassDeclaration = [{
133-
::mlir::MemRefType getMemRefType() const;
133+
::mlir::MemRefType getMemRefType(::mlir::Value = {}) const;
134134
::mlir::RankedTensorType getTensorType() const;
135135
::imex::ndarray::NDArrayType cloneWithDynDims() const;
136136
bool hasUnitSize() const;
@@ -175,6 +175,21 @@ def DeleteOp : NDArray_Op<"delete"> {
175175
}
176176

177177

178+
def FromMemRefOp : NDArray_Op<"from_memref", [Pure]> {
179+
let summary = "Convert a builtin memref value to a value of type NDArray";
180+
let description = [{
181+
Result type possibly adds NDArray annotations.
182+
}];
183+
184+
let arguments = (ins AnyMemRef:$input);
185+
let results = (outs NDArray_NDArray);
186+
187+
let assemblyFormat = [{
188+
$input attr-dict `:` qualified(type($input)) `->` qualified(type(results))
189+
}];
190+
}
191+
192+
178193
def ToTensorOp : NDArray_Op<"to_tensor", [Pure]> {
179194
let summary = "Convert a NDArray value to a value of MLIR's builtin tensor type";
180195
let description = [{

lib/Conversion/DistToStandard/DistToStandard.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1693,7 +1693,7 @@ struct ConvertDistToStandardPass
16931693
if (rank) {
16941694
lOffs = createValuesFromMemRef(builder, loc, inputs[nParts]);
16951695
for (auto i = 0u; i < sOffs.size(); ++i) {
1696-
if (sOffs[i] != ::mlir::ShapedType::kDynamic) {
1696+
if (!::mlir::ShapedType::isDynamic(sOffs[i])) {
16971697
lOffs[i] = createIndex(loc, builder, sOffs[i]);
16981698
}
16991699
}

lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,23 @@ struct CastLowering
127127
}
128128
};
129129

130+
/// Convert FromMemRefOp to bufferize.to_tensor
131+
struct FromMemRefLowering
132+
: public ::mlir::OpConversionPattern<::imex::ndarray::FromMemRefOp> {
133+
using OpConversionPattern::OpConversionPattern;
134+
135+
::mlir::LogicalResult
136+
matchAndRewrite(::imex::ndarray::FromMemRefOp op,
137+
::imex::ndarray::FromMemRefOp::Adaptor adaptor,
138+
::mlir::ConversionPatternRewriter &rewriter) const override {
139+
140+
rewriter.replaceOpWithNewOp<::mlir::bufferization::ToTensorOp>(
141+
op, adaptor.getInput());
142+
143+
return ::mlir::success();
144+
}
145+
};
146+
130147
/// Lower to the input operand of the defining op.
131148
struct ToTensorLowering
132149
: public ::mlir::OpConversionPattern<::imex::ndarray::ToTensorOp> {
@@ -163,7 +180,7 @@ struct SubviewLowering
163180
.dyn_cast_or_null<::imex::ndarray::NDArrayType>();
164181
if (!srcArType)
165182
return mlir::failure();
166-
auto srcMRType = srcArType.getMemRefType();
183+
auto srcMRType = srcArType.getMemRefType(srcTnsr);
167184
auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType);
168185

169186
auto *converter = getTypeConverter();
@@ -298,8 +315,8 @@ struct InsertSliceLowering
298315
auto dstArType =
299316
op.getDestination().getType().cast<imex::ndarray::NDArrayType>();
300317

301-
auto srcMRTyp = srcArType.getMemRefType();
302-
auto dstMRTyp = dstArType.getMemRefType();
318+
auto srcMRTyp = srcArType.getMemRefType(src);
319+
auto dstMRTyp = dstArType.getMemRefType(dst);
303320
mlir::Value srcMR = createToMemRef(loc, rewriter, src, srcMRTyp);
304321
auto dstMR = createToMemRef(loc, rewriter, dst, dstMRTyp);
305322

@@ -509,8 +526,8 @@ struct CopyLowering
509526
loc, mrTyp, dynDims, rewriter.getI64IntegerAttr(8));
510527
// and copy if non-0
511528
if (!retArTyp.hasZeroSize()) {
512-
auto srcMR = rewriter.create<::mlir::bufferization::ToMemrefOp>(
513-
loc, srcArTyp.getMemRefType(), src);
529+
auto srcMR =
530+
createToMemRef(loc, rewriter, src, srcArTyp.getMemRefType(src));
514531
// create a region with given env, add copy op within it
515532
auto env = rewriter.getStringAttr("protect_copy_op");
516533
rewriter.create<::imex::region::EnvironmentRegionOp>(
@@ -545,8 +562,9 @@ struct DeleteLowering
545562
return ::mlir::failure();
546563
}
547564

548-
auto inpMR = rewriter.create<::mlir::bufferization::ToMemrefOp>(
549-
op.getLoc(), inpArType.getMemRefType(), adaptor.getInput());
565+
auto inp = adaptor.getInput();
566+
auto inpMR = createToMemRef(op.getLoc(), rewriter, inp,
567+
inpArType.getMemRefType(inp));
550568
rewriter.replaceOpWithNewOp<::mlir::memref::DeallocOp>(op, inpMR);
551569

552570
return ::mlir::success();
@@ -650,7 +668,8 @@ struct ReshapeLowering
650668
auto cpyMR =
651669
createToMemRef(loc, rewriter, cpy,
652670
getMemRefType(op.getContext(), rank, elTyp, false));
653-
auto srcMR = createToMemRef(loc, rewriter, src, srcArTyp.getMemRefType());
671+
auto srcMR =
672+
createToMemRef(loc, rewriter, src, srcArTyp.getMemRefType(src));
654673
rewriter.create<::mlir::memref::CopyOp>(loc, srcMR, cpyMR);
655674
src = cpy;
656675
}
@@ -1304,12 +1323,12 @@ struct ConvertNDArrayToLinalgPass
13041323
[&](mlir::Operation *op) { return typeConverter.isLegal(op); });
13051324

13061325
::mlir::RewritePatternSet patterns(&ctxt);
1307-
patterns.insert<ToTensorLowering, SubviewLowering, ExtractSliceLowering,
1308-
InsertSliceLowering, ImmutableInsertSliceLowering,
1309-
LinSpaceLowering, LoadOpLowering, CreateLowering,
1310-
EWBinOpLowering, DimOpLowering, EWUnyOpLowering,
1311-
ReductionOpLowering, ReshapeLowering, CastLowering,
1312-
CopyLowering, DeleteLowering, CastElemTypeLowering>(
1326+
patterns.insert<
1327+
ToTensorLowering, SubviewLowering, ExtractSliceLowering,
1328+
InsertSliceLowering, ImmutableInsertSliceLowering, LinSpaceLowering,
1329+
LoadOpLowering, CreateLowering, EWBinOpLowering, DimOpLowering,
1330+
EWUnyOpLowering, ReductionOpLowering, ReshapeLowering, CastLowering,
1331+
CopyLowering, DeleteLowering, CastElemTypeLowering, FromMemRefLowering>(
13131332
typeConverter, &ctxt);
13141333
::imex::populateRegionTypeConversionPatterns(patterns, typeConverter);
13151334

lib/Dialect/NDArray/IR/NDArrayOps.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
1616
#include <imex/Utils/PassUtils.h>
1717
#include <llvm/ADT/TypeSwitch.h>
18+
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
1819
#include <mlir/Dialect/Utils/StaticValueUtils.h>
1920
#include <mlir/IR/DialectImplementation.h>
2021

@@ -72,7 +73,16 @@ NDArrayType NDArrayType::get(::llvm::ArrayRef<int64_t> shape,
7273
return get(ctx, shape, elementType, environments, layout);
7374
}
7475

75-
::mlir::MemRefType NDArrayType::getMemRefType() const {
76+
::mlir::MemRefType NDArrayType::getMemRefType(::mlir::Value val) const {
77+
if (val) {
78+
auto defOp = val.getDefiningOp<::mlir::bufferization::ToTensorOp>();
79+
if (defOp) {
80+
return defOp.getMemref()
81+
.getType()
82+
.cloneWith(getShape(), getElementType())
83+
.cast<::mlir::MemRefType>();
84+
}
85+
}
7686
return ::imex::getMemRefType(getContext(), getShape(), getElementType());
7787
}
7888

test/Conversion/NDArrayToLinalg/NDArrayToLinalg.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,25 @@ func.func @test_subview(%arg0: !ndarray.ndarray<?xi64>) -> !ndarray.ndarray<?xi6
1717
// CHECK-NEXT: [[V2:%.*]] = bufferization.to_memref
1818
// CHECK-NEXT: return [[V2]] : memref<?xi64, strided<[?], offset: ?>>
1919

20+
// -----
21+
func.func @test_static_mr_2_tnsr_2_static_mr(%arg0: memref<55xi32, strided<[1], offset: 2>>) -> !ndarray.ndarray<3xi32> {
22+
%c0 = arith.constant 0 : index
23+
%c3 = arith.constant 3 : index
24+
%nda = ndarray.from_memref %arg0 : memref<55xi32, strided<[1], offset: 2>> -> !ndarray.ndarray<55xi32>
25+
%0 = ndarray.subview %nda[%c0][%c3][%c3] : !ndarray.ndarray<55xi32> to !ndarray.ndarray<3xi32>
26+
return %0 : !ndarray.ndarray<3xi32>
27+
}
28+
// CHECK-LABEL: @test_static_mr_2_tnsr_2_static_mr
29+
// CHECK-NEXT: [[C0:%.*]] = arith.constant
30+
// CHECK-NEXT: [[C1:%.*]] = arith.constant
31+
// CHECK-NEXT: [[V:%.*]] = bufferization.to_tensor %arg0 : memref<55xi32, strided<[1], offset: 2>>
32+
// CHECK-NEXT: [[V0:%.*]] = bufferization.to_memref [[V]] : memref<55xi32, strided<[1], offset: 2>>
33+
// CHECK-NEXT: [[S0:%.*]] = memref.subview [[V0]][[[C0]]] [[[C1]]] [[[C1]]] : memref<55xi32, strided<[1], offset: 2>> to memref<?xi32, strided<[?], offset: ?>>
34+
// CHECK-NEXT: [[V1:%.*]] = bufferization.to_tensor [[S0]] writable : memref<?xi32, strided<[?], offset: ?>>
35+
// CHECK-NEXT: [[V2:%.*]] = bufferization.to_memref
36+
// CHECK-NEXT: [[V3:%.*]] = memref.cast [[V2]]
37+
// CHECK-NEXT: return [[V3]] : memref<3xi32, strided<[?], offset: ?>>
38+
2039
// -----
2140
func.func @test_extract_slice(%arg0: !ndarray.ndarray<?xi64>) -> !ndarray.ndarray<?xi64> {
2241
%c0 = arith.constant 0 : index
@@ -439,3 +458,13 @@ func.func @test_cast_elemtype_copy(%arg0: !ndarray.ndarray<16xi32>) -> !ndarray.
439458
// CHECK: region.env_region "protect_copy_op"
440459
// CHECK-NEXT: memref.copy
441460
// CHECK-NEXT: }
461+
462+
// -----
463+
func.func @test_from_memref(%arg0: memref<5xi32, strided<[?], offset: ?>>) -> !ndarray.ndarray<5xi32> {
464+
%0 = ndarray.from_memref %arg0 : memref<5xi32, strided<[?], offset: ?>> -> !ndarray.ndarray<5xi32>
465+
return %0 : !ndarray.ndarray<5xi32>
466+
}
467+
// CHECK-LABEL: @test_from_memref
468+
// CHECK: [[V0:%.*]] = bufferization.to_tensor
469+
// CHECK: [[V1:%.*]] = bufferization.to_memref [[V0]]
470+
// CHECK-NEXT: return [[V1]] : memref<5xi32, strided<[?], offset: ?>>

test/Dialect/NDArray/IR/NDArrayOps.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,12 @@ func.func @test_castelem(%arg0: !ndarray.ndarray<5xi64>) -> !ndarray.ndarray<5xi
193193
// CHECK-LABEL: func.func @test_castelem
194194
// CHECK: [[V0:%.*]] = ndarray.cast_elemtype
195195
// CHECK-NEXT: return [[V0]] : !ndarray.ndarray<5xi32>
196+
197+
// -----
198+
func.func @test_from_memref(%arg0: memref<?xi32, strided<[?], offset: ?>>) -> !ndarray.ndarray<?xi32> {
199+
%0 = ndarray.from_memref %arg0 : memref<?xi32, strided<[?], offset: ?>> -> !ndarray.ndarray<?xi32>
200+
return %0 : !ndarray.ndarray<?xi32>
201+
}
202+
// CHECK-LABEL: func.func @test_from_memref
203+
// CHECK: [[V0:%.*]] = ndarray.from_memref
204+
// CHECK-NEXT: return [[V0]] : !ndarray.ndarray<?xi32>

0 commit comments

Comments
 (0)