diff --git a/include/imex/Utils/PassUtils.h b/include/imex/Utils/PassUtils.h index 0b443084e..45ed12d01 100644 --- a/include/imex/Utils/PassUtils.h +++ b/include/imex/Utils/PassUtils.h @@ -280,7 +280,8 @@ extern void printValsAsMemRef(::mlir::Location loc, ::mlir::OpBuilder &builder, // If this memref has a different shape than mrTyp, also creates a memref.cast extern ::mlir::Value createToMemRef(::mlir::Location loc, ::mlir::OpBuilder &builder, - ::mlir::Value input, ::mlir::Type toTyp); + ::mlir::Value input, ::mlir::Type toTyp, + bool clone = false); // broadcast 2 shapes into one according to the array-API template diff --git a/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp b/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp index ff08b2421..85c1a9b6f 100644 --- a/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp +++ b/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp @@ -1242,7 +1242,7 @@ struct PermuteDimsOpLowering if (!srcArType) return mlir::failure(); auto srcMRType = srcArType.getMemRefType(srcTnsr); - auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType); + auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType, true); auto perm = ::mlir::AffineMapAttr::get(::mlir::AffineMap::getPermutationMap( adaptor.getAxes(), rewriter.getContext())); diff --git a/lib/Utils/PassUtils.cpp b/lib/Utils/PassUtils.cpp index b294920de..f0256325d 100644 --- a/lib/Utils/PassUtils.cpp +++ b/lib/Utils/PassUtils.cpp @@ -473,13 +473,17 @@ void printValsAsMemRef(::mlir::Location loc, ::mlir::OpBuilder &builder, // First creates a toMemrefOp with the same shape as tensor. // If this memref has a different shape than mrTyp, also creates a memref.cast ::mlir::Value createToMemRef(::mlir::Location loc, ::mlir::OpBuilder &builder, - ::mlir::Value input, ::mlir::Type toTyp) { + ::mlir::Value input, ::mlir::Type toTyp, + bool clone) { auto iTyp = mlir::cast<::mlir::RankedTensorType>(input.getType()); auto mrTyp = mlir::cast<::mlir::MemRefType>(toTyp); auto shapedMrTyp = mlir::cast<::mlir::ShapedType>(mrTyp).clone(iTyp.getShape()); - auto shapedMr = builder.create<::mlir::bufferization::ToMemrefOp>( + ::mlir::Value shapedMr = builder.create<::mlir::bufferization::ToMemrefOp>( loc, shapedMrTyp, input); + if (clone) { + shapedMr = builder.create<::mlir::bufferization::CloneOp>(loc, shapedMr); + } return shapedMrTyp == toTyp ? shapedMr : builder.create<::mlir::memref::CastOp>(loc, toTyp, shapedMr) diff --git a/test/Conversion/NDArrayToLinalg/NDArrayToLinalg.mlir b/test/Conversion/NDArrayToLinalg/NDArrayToLinalg.mlir index 65afed9e9..e5aa868be 100644 --- a/test/Conversion/NDArrayToLinalg/NDArrayToLinalg.mlir +++ b/test/Conversion/NDArrayToLinalg/NDArrayToLinalg.mlir @@ -477,5 +477,6 @@ func.func @test_permute_dims(%arg0: !ndarray.ndarray<5x3x2xi32>) -> !ndarray.nda // CHECK-LABEL: @test_permute_dims // CHECK: [[V0:%.*]] = bufferization.to_tensor // CHECK: [[V1:%.*]] = bufferization.to_memref [[V0]] -// CHECK: [[V2:%.*]] = memref.transpose [[V1]] (d0, d1, d2) -> (d2, d1, d0) -// CHECK-NEXT: return [[V2]] : memref<2x3x5xi32, strided<[?, ?, ?], offset: ?>> +// CHECK: [[V2:%.*]] = bufferization.clone [[V1]] +// CHECK: [[V3:%.*]] = memref.transpose [[V2]] (d0, d1, d2) -> (d2, d1, d0) +// CHECK-NEXT: return [[V3]] : memref<2x3x5xi32, strided<[?, ?, ?], offset: ?>> diff --git a/test/imex-runner/ndarray-gpu.pp b/test/imex-runner/ndarray-gpu.pp index 63c627d5b..013822eee 100644 --- a/test/imex-runner/ndarray-gpu.pp +++ b/test/imex-runner/ndarray-gpu.pp @@ -12,6 +12,7 @@ func.func(empty-tensor-to-alloc-tensor) one-shot-bufferize{bufferize-function-boundaries} imex-remove-temporaries + convert-bufferization-to-memref func.func(convert-linalg-to-parallel-loops) func.func(scf-parallel-loop-fusion) // GPU diff --git a/test/imex-runner/ndarray.pp b/test/imex-runner/ndarray.pp index 3ee5ea2bc..2736f7a4d 100644 --- a/test/imex-runner/ndarray.pp +++ b/test/imex-runner/ndarray.pp @@ -11,6 +11,7 @@ func.func(empty-tensor-to-alloc-tensor) one-shot-bufferize{bufferize-function-boundaries} imex-remove-temporaries + convert-bufferization-to-memref func.func(convert-linalg-to-parallel-loops) func.func(scf-parallel-loop-fusion) drop-regions