Skip to content

Commit

Permalink
[NDArray] Bugfix for "ndarray.permute_dims" by adding "bufferization.clone" (#926)
Browse files Browse the repository at this point in the history

* add "clone" argument to createToMemRef
* add "convert-bufferization-to-memref" to imex-runner pipeline
  • Loading branch information
AllanZyne authored Oct 16, 2024
1 parent ff13507 commit 8ae485b
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 6 deletions.
3 changes: 2 additions & 1 deletion include/imex/Utils/PassUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ extern void printValsAsMemRef(::mlir::Location loc, ::mlir::OpBuilder &builder,
// If this memref has a different shape than mrTyp, also creates a memref.cast
extern ::mlir::Value createToMemRef(::mlir::Location loc,
::mlir::OpBuilder &builder,
::mlir::Value input, ::mlir::Type toTyp);
::mlir::Value input, ::mlir::Type toTyp,
bool clone = false);

// broadcast 2 shapes into one according to the array-API
template <typename V1, typename V2>
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,7 @@ struct PermuteDimsOpLowering
if (!srcArType)
return mlir::failure();
auto srcMRType = srcArType.getMemRefType(srcTnsr);
auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType);
auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType, true);

auto perm = ::mlir::AffineMapAttr::get(::mlir::AffineMap::getPermutationMap(
adaptor.getAxes(), rewriter.getContext()));
Expand Down
8 changes: 6 additions & 2 deletions lib/Utils/PassUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,13 +473,17 @@ void printValsAsMemRef(::mlir::Location loc, ::mlir::OpBuilder &builder,
// First creates a toMemrefOp with the same shape as tensor.
// If this memref has a different shape than mrTyp, also creates a memref.cast
::mlir::Value createToMemRef(::mlir::Location loc, ::mlir::OpBuilder &builder,
::mlir::Value input, ::mlir::Type toTyp) {
::mlir::Value input, ::mlir::Type toTyp,
bool clone) {
auto iTyp = mlir::cast<::mlir::RankedTensorType>(input.getType());
auto mrTyp = mlir::cast<::mlir::MemRefType>(toTyp);
auto shapedMrTyp =
mlir::cast<::mlir::ShapedType>(mrTyp).clone(iTyp.getShape());
auto shapedMr = builder.create<::mlir::bufferization::ToMemrefOp>(
::mlir::Value shapedMr = builder.create<::mlir::bufferization::ToMemrefOp>(
loc, shapedMrTyp, input);
if (clone) {
shapedMr = builder.create<::mlir::bufferization::CloneOp>(loc, shapedMr);
}
return shapedMrTyp == toTyp
? shapedMr
: builder.create<::mlir::memref::CastOp>(loc, toTyp, shapedMr)
Expand Down
5 changes: 3 additions & 2 deletions test/Conversion/NDArrayToLinalg/NDArrayToLinalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -477,5 +477,6 @@ func.func @test_permute_dims(%arg0: !ndarray.ndarray<5x3x2xi32>) -> !ndarray.nda
// CHECK-LABEL: @test_permute_dims
// CHECK: [[V0:%.*]] = bufferization.to_tensor
// CHECK: [[V1:%.*]] = bufferization.to_memref [[V0]]
// CHECK: [[V2:%.*]] = memref.transpose [[V1]] (d0, d1, d2) -> (d2, d1, d0)
// CHECK-NEXT: return [[V2]] : memref<2x3x5xi32, strided<[?, ?, ?], offset: ?>>
// CHECK: [[V2:%.*]] = bufferization.clone [[V1]]
// CHECK: [[V3:%.*]] = memref.transpose [[V2]] (d0, d1, d2) -> (d2, d1, d0)
// CHECK-NEXT: return [[V3]] : memref<2x3x5xi32, strided<[?, ?, ?], offset: ?>>
1 change: 1 addition & 0 deletions test/imex-runner/ndarray-gpu.pp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
func.func(empty-tensor-to-alloc-tensor)
one-shot-bufferize{bufferize-function-boundaries}
imex-remove-temporaries
convert-bufferization-to-memref
func.func(convert-linalg-to-parallel-loops)
func.func(scf-parallel-loop-fusion)
// GPU
Expand Down
1 change: 1 addition & 0 deletions test/imex-runner/ndarray.pp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
func.func(empty-tensor-to-alloc-tensor)
one-shot-bufferize{bufferize-function-boundaries}
imex-remove-temporaries
convert-bufferization-to-memref
func.func(convert-linalg-to-parallel-loops)
func.func(scf-parallel-loop-fusion)
drop-regions
Expand Down

0 comments on commit 8ae485b

Please sign in to comment.