From 7cfaa08f265f72a294bea52b5a2254484b03440e Mon Sep 17 00:00:00 2001 From: Frank Schlimbach Date: Fri, 20 Oct 2023 11:39:49 +0200 Subject: [PATCH] insert gpu2host copies when returning views of gpuAlloced memrefs --- lib/Transforms/InsertGPUAllocs.cpp | 42 +++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/lib/Transforms/InsertGPUAllocs.cpp b/lib/Transforms/InsertGPUAllocs.cpp index 26a85eca1..b932c362a 100644 --- a/lib/Transforms/InsertGPUAllocs.cpp +++ b/lib/Transforms/InsertGPUAllocs.cpp @@ -263,16 +263,40 @@ class InsertGPUAllocsPass final alloc.getSymbolOperands(), hostShared); auto allocResult = gpuAlloc.getResult(0); builder.setInsertionPoint(term); - for (mlir::OpOperand &use : alloc.getResult().getUses()) { - if (use.getOwner() == term) { - auto newAlloc = builder.create( - loc, alloc.getType(), alloc.getDynamicSizes(), - alloc.getSymbolOperands()); - builder.create(loc, allocResult, - newAlloc.getResult()); - use.set(newAlloc.getResult()); - } + + // follow the users of alloc if they are view-like + // insert copy if they are terminator + auto insertCopyIfViewInTerminal = [&](auto &use) -> bool { + auto _insertCopyIfViewInTerminal = [&](auto &use, auto _insertCopyIfViewInTerminal_) -> bool { + auto user = use.getOwner(); + if (user == term) { + auto newAlloc = builder.create( + loc, alloc.getType(), alloc.getDynamicSizes(), + alloc.getSymbolOperands()); + builder.create(loc, allocResult, + newAlloc.getResult()); + auto castop = builder.create( + loc, use.get().getType(), newAlloc); + use.set(castop.getResult()); + return true; + } + if (::mlir::isa<::mlir::ViewLikeOpInterface>(user)) { + assert(user->getNumResults() == 1); + for (auto &_use : user->getResult(0).getUses()) { + if (_insertCopyIfViewInTerminal_(_use, _insertCopyIfViewInTerminal_)) + return true; + } + } + // on all other cases we do nothing + return true; + }; + return _insertCopyIfViewInTerminal(use, _insertCopyIfViewInTerminal); + }; + + for (auto &use : alloc.getResult().getUses()) { + insertCopyIfViewInTerminal(use); } + alloc.replaceAllUsesWith(allocResult); builder.create(loc, std::nullopt, allocResult); alloc.erase();