Skip to content

Commit

Permalink
[AssignTiles] Generalize pass to use LogicalObjFifoInterface and Copy…
Browse files Browse the repository at this point in the history
…OpInterface
  • Loading branch information
jtuyls committed Nov 28, 2024
1 parent 4f00ff9 commit ad2f257
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define IREE_COMPILER_AMDAIE_LOGICALOBJFIFOOPINTERFACE_H_

#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CopyOpInterface.h"

namespace mlir::iree_compiler::AMDAIE {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,20 @@ def LogicalObjFifoOpInterface : OpInterface<"LogicalObjFifoOpInterface"> {
/*defaultImplementation=*/[{
return $_op.getTiles();
}]
>
>,
InterfaceMethod<
/*desc=*/[{
A utility to replace this logical objectFifo operation with a new one with new tiles.
}],
/*retTy=*/"::llvm::FailureOr<::mlir::iree_compiler::AMDAIE::LogicalObjFifoOpInterface>",
/*methodName=*/"replaceWithNewTiles",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"::mlir::ValueRange":$tiles),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.replaceWithNewTiles(rewriter, tiles);
}]
>,
];
}

Expand Down
34 changes: 34 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,14 @@ LogicalResult LogicalObjectFifoFromBuffersOp::verify() {
return success();
}

FailureOr<LogicalObjFifoOpInterface>
LogicalObjectFifoFromBuffersOp::replaceWithNewTiles(
::mlir::RewriterBase &rewriter, ::mlir::ValueRange tiles) {
// NOTE(jornt): This can potentially be implemented by updating the buffer's
// tiles.
return (*this).emitOpError() << "doesn't support tile replacement";
}

//===----------------------------------------------------------------------===//
// AMDAIE_LogicalObjectFifoFromMemrefOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -665,6 +673,17 @@ LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize(
return success();
}

FailureOr<LogicalObjFifoOpInterface>
LogicalObjectFifoFromMemrefOp::replaceWithNewTiles(
::mlir::RewriterBase &rewriter, ::mlir::ValueRange tiles) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(getOperation());
auto newOp =
rewriter.replaceOpWithNewOp<AMDAIE::LogicalObjectFifoFromMemrefOp>(
*this, getType(), getMemref(), tiles);
return cast<LogicalObjFifoOpInterface>(newOp.getOperation());
}

LogicalResult LogicalObjectFifoFromMemrefOp::verify() {
// Check whether the tile arguments are all of type AMDAIE::TileOp
if (llvm::all_of(getTiles(), [](Value result) {
Expand All @@ -675,6 +694,21 @@ LogicalResult LogicalObjectFifoFromMemrefOp::verify() {
return failure();
}

//===----------------------------------------------------------------------===//
// AMDAIE_LogicalObjectFifoPlaceholderOp
//===----------------------------------------------------------------------===//

FailureOr<LogicalObjFifoOpInterface>
LogicalObjectFifoPlaceholderOp::replaceWithNewTiles(
::mlir::RewriterBase &rewriter, ::mlir::ValueRange tiles) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(getOperation());
auto newOp =
rewriter.replaceOpWithNewOp<AMDAIE::LogicalObjectFifoPlaceholderOp>(
*this, getType(), tiles);
return cast<LogicalObjFifoOpInterface>(newOp.getOperation());
}

//===----------------------------------------------------------------------===//
// AMDAIE_LogicalObjectFifoRelease
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 6 additions & 3 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,8 @@ def AMDAIE_LogicalObjectFifoAcquire:

def AMDAIE_LogicalObjectFifoFromBuffersOp
: AMDAIE_Op<"logicalobjectfifo.from_buffers",
[LogicalObjFifoOpInterface, Pure, AttrSizedOperandSegments]> {
[DeclareOpInterfaceMethods<LogicalObjFifoOpInterface, ["replaceWithNewTiles"]>,
Pure, AttrSizedOperandSegments]> {
let summary = "Create a logical objectFifo from a set of buffers";
let description = [{
Creates a logical objectFifo which encapsulates a set of memref `buffers`.
Expand Down Expand Up @@ -1220,7 +1221,8 @@ def AMDAIE_LogicalObjectFifoFromBuffersOp

def AMDAIE_LogicalObjectFifoFromMemrefOp
: AMDAIE_Op<"logicalobjectfifo.from_memref",
[LogicalObjFifoOpInterface, Pure]> {
[DeclareOpInterfaceMethods<LogicalObjFifoOpInterface, ["replaceWithNewTiles"]>,
Pure]> {
let summary = "Create a logical objectFifo from a memref";
let description = [{
Creates a logical objectFifo which encapsulates a memref. The logical objectFifo
Expand Down Expand Up @@ -1294,7 +1296,8 @@ def AMDAIE_LogicalObjectFifoFromMemrefOp

def AMDAIE_LogicalObjectFifoPlaceholderOp:
AMDAIE_Op<"logicalobjectfifo.placeholder", [
LogicalObjFifoOpInterface, Pure]> {
DeclareOpInterfaceMethods<LogicalObjFifoOpInterface, ["replaceWithNewTiles"]>,
Pure]> {
let summary = "A placeholder for a logical objectFifo.";
let description = [{
Represents a placeholder for a logical objectFifo. The actual logical
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,23 @@ namespace mlir::iree_compiler::AMDAIE {
/// logical objectfifo, depending on whether the OperateOn template parameter is
/// set to `OperateOn::Source` respectively `OperateOn::Target`.
template <CopyOpOperateOn OperateOn>
LogicalResult getUserTiles(
AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjectFifo,
SmallVectorImpl<AMDAIE::TileOp> &tiles) {
LogicalResult getUserTiles(AMDAIE::LogicalObjFifoOpInterface logicalObjectFifo,
SmallVectorImpl<AMDAIE::TileOp> &tiles) {
llvm::SmallSetVector<AMDAIE::TileOp, 16> tileSet;
for (Operation *user : logicalObjectFifo->getUsers()) {
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(user)) {
if (auto copyOp = dyn_cast<CopyOpInterface>(user)) {
auto source = dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
copyOp.getSource().getDefiningOp());
auto target = dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
copyOp.getTarget().getDefiningOp());
if (!source || !target) continue;
ValueRange tileIndices;
if constexpr (OperateOn == CopyOpOperateOn::Source) {
if (dmaOp.getTargetObjectFifo() != logicalObjectFifo) continue;
tileIndices = dmaOp.getSourceObjectFifo().getTiles();
if (target != logicalObjectFifo) continue;
tileIndices = source.getTiles();
} else if constexpr (OperateOn == CopyOpOperateOn::Target) {
if (dmaOp.getSourceObjectFifo() != logicalObjectFifo) continue;
tileIndices = dmaOp.getTargetObjectFifo().getTiles();
if (source != logicalObjectFifo) continue;
tileIndices = target.getTiles();
}
// Only fill in tiles when all sources have tiles.
if (tileIndices.empty()) return failure();
Expand All @@ -49,7 +53,7 @@ LogicalResult getUserTiles(
/// Utility to recursively find users of the provided logical objectFifo inside
/// `amdaie.core` operations and return the tile coordinates.
LogicalResult findUsersInCoreAndAddTiles(
Operation *op, AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjectFifo,
Operation *op, AMDAIE::LogicalObjFifoOpInterface logicalObjectFifo,
llvm::SmallSetVector<std::pair<int64_t, int64_t>, 16> &tiles) {
for (Operation *userOp : op->getUsers()) {
if (auto coreOp = userOp->getParentOfType<AMDAIE::CoreOp>()) {
Expand All @@ -63,7 +67,7 @@ LogicalResult findUsersInCoreAndAddTiles(
if (auto subviewOp = dyn_cast<memref::SubViewOp>(userOp)) {
return findUsersInCoreAndAddTiles(subviewOp, logicalObjectFifo, tiles);
} else if (auto userLogicalObjectFifo =
dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(userOp)) {
dyn_cast<AMDAIE::LogicalObjFifoOpInterface>(userOp)) {
return findUsersInCoreAndAddTiles(userLogicalObjectFifo,
logicalObjectFifo, tiles);
}
Expand All @@ -73,15 +77,18 @@ LogicalResult findUsersInCoreAndAddTiles(

/// Utility to clear non-local tile assignments.
LogicalResult clearNonLocalTiles(RewriterBase &rewriter, Operation *op) {
op->walk([&](AMDAIE::LogicalObjectFifoFromMemrefOp objFifo) {
WalkResult res = op->walk([&](AMDAIE::LogicalObjFifoOpInterface objFifo) {
if (objFifo.getMemorySpaceAsUInt() != 2) {
rewriter.setInsertionPoint(objFifo);
SmallVector<Value> tiles;
rewriter.replaceOpWithNewOp<AMDAIE::LogicalObjectFifoFromMemrefOp>(
objFifo, cast<LogicalObjectFifoType>(objFifo.getOutput().getType()),
objFifo.getMemref(), tiles);
if (failed(objFifo.replaceWithNewTiles(rewriter, tiles))) {
objFifo.emitOpError() << "could not replace its tiles";
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
return success();
}

Expand All @@ -90,20 +97,17 @@ LogicalResult clearNonLocalTiles(RewriterBase &rewriter, Operation *op) {
/// different tile locations.
LogicalResult duplicateGlobalObjFifos(RewriterBase &rewriter, Operation *op) {
op->walk([&](AMDAIE::DoublyStridedCopyOpInterface copyOp) {
auto source = dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
auto source = dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
copyOp.getSource().getDefiningOp());
auto target = dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
auto target = dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
copyOp.getTarget().getDefiningOp());
auto createNewObjFifoAndReplaceUsesFrom =
[&](AMDAIE::LogicalObjectFifoFromMemrefOp oldObjFifo) {
[&](AMDAIE::LogicalObjFifoOpInterface oldObjFifo) {
rewriter.setInsertionPoint(copyOp);
auto newObjFifo =
rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(),
cast<LogicalObjectFifoType>(oldObjFifo.getOutput().getType()),
oldObjFifo.getMemref());
auto newObjFifo = cast<AMDAIE::LogicalObjFifoOpInterface>(
rewriter.clone(*oldObjFifo.getOperation()));
rewriter.replaceUsesWithIf(
oldObjFifo.getOutput(), newObjFifo.getOutput(),
oldObjFifo->getResult(0), newObjFifo->getResult(0),
[&](OpOperand &use) {
return use.getOwner() == copyOp.getOperation();
});
Expand All @@ -130,15 +134,21 @@ LogicalResult assignLocalTiles(RewriterBase &rewriter, Operation *op) {

llvm::SmallSetVector<std::pair<int64_t, int64_t>, 16> tileLocations;
if (failed(findUsersInCoreAndAddTiles(
logicalObjectFifo, logicalObjectFifo, tileLocations))) {
logicalObjectFifo,
cast<AMDAIE::LogicalObjFifoOpInterface>(
logicalObjectFifo.getOperation()),
tileLocations))) {
return WalkResult::interrupt();
}
// Handle subviews.
for (Operation *userOp :
logicalObjectFifo.getMemref().getDefiningOp()->getUsers()) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(userOp)) {
if (failed(findUsersInCoreAndAddTiles(subviewOp, logicalObjectFifo,
tileLocations))) {
if (failed(findUsersInCoreAndAddTiles(
subviewOp,
cast<AMDAIE::LogicalObjFifoOpInterface>(
logicalObjectFifo.getOperation()),
tileLocations))) {
return WalkResult::interrupt();
}
}
Expand Down Expand Up @@ -177,16 +187,16 @@ LogicalResult assignLocalTiles(RewriterBase &rewriter, Operation *op) {
/// have tiles assigned yet, we will return a failure and give the linked
/// logical objectfifos a chance to assign tiles before returning to this one.
class FillTiles
: public OpRewritePattern<AMDAIE::LogicalObjectFifoFromMemrefOp> {
using OpRewritePattern<
AMDAIE::LogicalObjectFifoFromMemrefOp>::OpRewritePattern;
: public OpInterfaceRewritePattern<AMDAIE::LogicalObjFifoOpInterface> {
using OpInterfaceRewritePattern<
AMDAIE::LogicalObjFifoOpInterface>::OpInterfaceRewritePattern;

public:
FillTiles(MLIRContext *context, const AMDAIE::AMDAIEDeviceModel &deviceModel)
: OpRewritePattern(context), deviceModel(deviceModel) {}
: OpInterfaceRewritePattern(context), deviceModel(deviceModel) {}

LogicalResult matchAndRewrite(
AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjectFifo,
AMDAIE::LogicalObjFifoOpInterface logicalObjectFifo,
PatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "FillTiles: " << logicalObjectFifo << "\n");
if (!logicalObjectFifo.getTiles().empty()) {
Expand Down Expand Up @@ -276,9 +286,22 @@ class FillTiles
"iteration, with more information, a tile location can be found.");
}
rewriter.setInsertionPoint(logicalObjectFifo);
rewriter.replaceOpWithNewOp<AMDAIE::LogicalObjectFifoFromMemrefOp>(
logicalObjectFifo, logicalObjectFifo.getMemref(),
tileLocations.takeVector());
SmallVector<Value> tiles;
tiles.reserve(tileLocations.size());
for (auto [column, row] : tileLocations) {
auto getCol = rewriter.create<arith::ConstantIndexOp>(
rewriter.getUnknownLoc(), column);
auto getRow = rewriter.create<arith::ConstantIndexOp>(
rewriter.getUnknownLoc(), row);
auto tileOp = rewriter.create<AMDAIE::TileOp>(rewriter.getUnknownLoc(),
getCol, getRow);
tiles.push_back(tileOp.getResult());
}
if (failed(logicalObjectFifo.replaceWithNewTiles(rewriter, tiles))) {
return rewriter.notifyMatchFailure(
logicalObjectFifo,
"Could not replace the tiles in the provided logical objectFifo.");
}
return success();
}

Expand All @@ -294,7 +317,7 @@ LogicalResult assignNonLocalTiles(RewriterBase &rewriter, Operation *op,
const AMDAIEDeviceModel &deviceModel) {
MLIRContext *context = rewriter.getContext();
if (failed(clearNonLocalTiles(rewriter, op)))
return op->emitOpError() << "failed to clear non-local tile assignemts";
return op->emitOpError() << "failed to clear non-local tile assigments";

// Find and fill the tile candidates.
RewritePatternSet fillTilePatterns(context);
Expand Down Expand Up @@ -330,7 +353,7 @@ LogicalResult assignNonLocalTiles(RewriterBase &rewriter, Operation *op,
// After filling tile candidates, find and assign a specific one.
DenseMap<MemRefType, int64_t> logicalObjFifoToTileId;
WalkResult res =
op->walk([&](AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjectFifo) {
op->walk([&](AMDAIE::LogicalObjFifoOpInterface logicalObjectFifo) {
uint8_t memSpace = logicalObjectFifo.getMemorySpaceAsUInt();
if (memSpace != 0 && memSpace != 1) return WalkResult::advance();
if (logicalObjectFifo.getTiles().size() == 0) {
Expand All @@ -356,11 +379,11 @@ LogicalResult assignNonLocalTiles(RewriterBase &rewriter, Operation *op,
rewriter.setInsertionPoint(logicalObjectFifo);
SmallVector<Value> tileResults = {
cast<Value>(assignedTileOp.getResult())};
rewriter.replaceOpWithNewOp<AMDAIE::LogicalObjectFifoFromMemrefOp>(
logicalObjectFifo,
cast<LogicalObjectFifoType>(
logicalObjectFifo.getOutput().getType()),
logicalObjectFifo.getMemref(), tileResults);
if (failed(
logicalObjectFifo.replaceWithNewTiles(rewriter, tileResults))) {
logicalObjectFifo.emitOpError() << "could not replace its tiles.";
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
Expand Down Expand Up @@ -413,7 +436,7 @@ void AMDAIEAssignTilesPass::runOnOperation() {

// Assign tile locations to logical objectFifos on non-local (not L1) memory.
if (failed(assignNonLocalTiles(rewriter, parentOp, deviceModel))) {
parentOp->emitOpError() << "local tile assignment failed";
parentOp->emitOpError() << "non-local tile assignment failed";
return signalPassFailure();
}
LLVM_DEBUG(llvm::dbgs() << "After assignNonLocalTiles: \n"
Expand Down
Loading

0 comments on commit ad2f257

Please sign in to comment.