From 225563d4cbc757bfd80aedaff00f2e8585e006ea Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Thu, 31 Oct 2024 08:06:09 +0000 Subject: [PATCH 1/2] Lower reinterpret_cast to EmitC --- .../MemRefToEmitC/MemRefToEmitC.cpp | 31 ++++++++++++- .../MemRefToEmitC/memref-to-emitc-failed.mlir | 45 +++++++++++++++++++ .../MemRefToEmitC/memref-to-emitc.mlir | 9 ++++ 3 files changed, 83 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index f6ce553dd899a..da896d03cd961 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -230,6 +230,33 @@ struct ConvertExpandShape final } }; +struct ConvertReinterpretCast final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + auto arrayValue = + dyn_cast>(operands.getSource()); + if (!arrayValue) { + return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + + auto newCastOp = rewriter.create(op->getLoc(), resultTy, + operands.getSource()); + newCastOp.setReference(true); + rewriter.replaceOp(op, newCastOp); + return success(); + } +}; + } // namespace void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { @@ -251,6 +278,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter) { patterns.add( - converter, patterns.getContext()); + ConvertStore, ConvertCollapseShape, ConvertExpandShape, + ConvertReinterpretCast>(converter, patterns.getContext()); } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index 4df7bac0b5580..87ed7a63b9b1c 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -61,3 +61,48 @@ func.func @memref_collapse_dyn_shape(%arg: memref) -> memref { %0 = memref.collapse_shape %arg [[0, 1]] : memref into memref return %0 : memref } + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_dyn_shape +func.func @memref_reinterpret_cast_dyn_shape(%arg: memref<2x5xi32>, %size: index) -> memref { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<2x5xi32> to memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_dyn_offset +func.func @memref_reinterpret_cast_dyn_offset(%arg: memref<2x5xi32>, %offset: index) -> memref<10xi32, strided<[1], offset: ?>> { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [%offset], sizes: [10], strides: [1] : memref<2x5xi32> to memref<10xi32, strided<[1], offset: ?>> + return %0 : memref<10xi32, strided<[1], offset:? >> +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_static_offset +func.func @memref_reinterpret_cast_static_offset(%arg: memref<2x5xi32>) -> memref<10xi32, strided<[1], offset: 10>> { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [10], sizes: [10], strides: [1] : memref<2x5xi32> to memref<10xi32, strided<[1], offset: 10>> + return %0 : memref<10xi32, strided<[1], offset: 10>> +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_static_strides +func.func @memref_reinterpret_cast_offset(%arg: memref<2x5xi32>) -> memref<10xi32, strided<[2], offset: 0>> { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [2] : memref<2x5xi32> to memref<10xi32, strided<[2], offset: 0>> + return %0 : memref<10xi32, strided<[2], offset: 0>> +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_dyn_strides +func.func @memref_reinterpret_cast_offset(%arg: memref<2x5xi32>, %stride: index) -> memref<10xi32, strided<[?], offset: 0>> { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [%stride] : memref<2x5xi32> to memref<10xi32, strided<[?], offset: 0>> + return %0 : memref<10xi32, strided<[?], offset: 0>> +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 1effcb66cd62b..c620616526659 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -94,3 +94,12 @@ func.func @memref_collapse_shape(%arg: memref<2x5xi32>) -> memref<10xi32> { %0 = memref.collapse_shape %arg [[0, 1]] : memref<2x5xi32> into memref<10xi32> return %0 : memref<10xi32> } + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast +func.func @memref_reinterpret_cast(%arg: memref<2x5xi32>) -> memref<10xi32> { + // CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<2x5xi32> to memref<10xi32> + return %0 : memref<10xi32> +} From fbcdf95787b9677207229807aaee2e134db24baf Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 4 Nov 2024 09:46:28 +0000 Subject: [PATCH 2/2] Address review comments --- .../Conversion/MemRefToEmitC/memref-to-emitc.mlir | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index c620616526659..b735ac8975b2e 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -97,8 +97,17 @@ func.func @memref_collapse_shape(%arg: memref<2x5xi32>) -> memref<10xi32> { // ----- -// CHECK-LABEL: memref_reinterpret_cast -func.func @memref_reinterpret_cast(%arg: memref<2x5xi32>) -> memref<10xi32> { +// CHECK-LABEL: memref_reinterpret_cast_subset +func.func @memref_reinterpret_cast_subset(%arg: memref<2x5xi32>) -> memref<8xi32> { + // CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<2x5xi32> to !emitc.array<8xi32> ref + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [8], strides: [1] : memref<2x5xi32> to memref<8xi32> + return %0 : memref<8xi32> +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_reshape +func.func @memref_reinterpret_cast_reshape(%arg: memref<2x5xi32>) -> memref<10xi32> { // CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<2x5xi32> to memref<10xi32> return %0 : memref<10xi32>