[AMDAIEFoldDmaWaits] Fold DMA wait operations across multiple columns (#986)

This is an enhancement of #962.

In the previous PR, DMA waits on the same `connection` (and the same
tile) could be folded, exploiting the fact that each DMA channel has a
queue size of 4.

In this PR, DMA waits across multiple `columns` can also be folded,
provided their corresponding `row`, `channel`, and `direction` are the
same. This optimization leverages the ability to specify `colNum` in
`TCTSync`, where the range `[col, col + colNum)` can be addressed.
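
For illustration, the core batching step can be sketched as a small standalone program. This is a simplified model, not the pass itself: the `Sync` struct and `foldByColumn` name are hypothetical stand-ins for the `amdaie.npu.push_to_queue` ops the pass actually inspects before calling `builder.appendTCTSync(col, row, direction, 1, colNum, channel)`.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <tuple>
#include <vector>

// Hypothetical stand-in for one TCT sync request.
struct Sync {
  uint32_t col, row, channel, direction;
  uint32_t colNum = 1;  // Columns [col, col + colNum) are addressed.
};

// Merge waits on consecutive columns that share row/channel/direction.
std::vector<Sync> foldByColumn(std::vector<Sync> syncs) {
  // Sort so that foldable requests become adjacent.
  std::sort(syncs.begin(), syncs.end(), [](const Sync &a, const Sync &b) {
    return std::tie(a.channel, a.direction, a.row, a.col) <
           std::tie(b.channel, b.direction, b.row, b.col);
  });
  std::vector<Sync> batches;
  for (const Sync &s : syncs) {
    if (!batches.empty()) {
      Sync &last = batches.back();
      if (last.row == s.row && last.channel == s.channel &&
          last.direction == s.direction && last.col + last.colNum == s.col) {
        ++last.colNum;  // Extend the batch to cover this column too.
        continue;
      }
    }
    batches.push_back(s);  // Start a new batch.
  }
  return batches;
}

int main() {
  // Four waits on columns 0-3, same row/channel/direction.
  std::vector<Sync> syncs = {{0, 0, 0, 0}, {1, 0, 0, 0}, {2, 0, 0, 0}, {3, 0, 0, 0}};
  for (const Sync &b : foldByColumn(syncs))
    std::printf("col=%u colNum=%u\n", b.col, b.colNum);  // col=0 colNum=4
  return 0;
}
```

Here the four waits collapse into a single sync request covering `[0, 4)`, which mirrors the instruction-count savings measured in the table below.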

The numbers in the following table show the instruction size in words.

| Test (MxKxN) | No Folding | Only Fold by Connection | Only Fold by Column | Fold Both |
|--------------|------------|-------------------------|---------------------|-----------|
| 512x4096x512 | 1228 | 1132 | 1120 | 1096 |
| 512x512x4096 | 820 | 772 | 748 | 736 |
| 4096x512x512 | 4628 | 4244 | 4220 | 4124 |
Yu-Zhewen authored Dec 18, 2024
1 parent d24568f commit 2080473
Showing 4 changed files with 618 additions and 23 deletions.
@@ -200,16 +200,46 @@ LogicalResult convertOp(AMDAIE::NpuAddressPatchOp op,
}

LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
  // Collect all half DMA ops from the async tokens.
  SmallVector<AMDAIE::NpuPushToQueueOp> pushToQueueOps;
  for (Value asyncToken : op.getAsyncTokens()) {
    auto pushToQueueOp = dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(
        asyncToken.getDefiningOp());
    if (!pushToQueueOp) {
      return op.emitOpError()
             << "should operate on an `amdaie.push_to_queue` op async token";
    }
    pushToQueueOps.push_back(pushToQueueOp);
  }
  // Sort the half DMA ops by channel, direction, row, and column.
  std::sort(pushToQueueOps.begin(), pushToQueueOps.end(),
            [](AMDAIE::NpuPushToQueueOp a, AMDAIE::NpuPushToQueueOp b) {
              return std::make_tuple(a.getChannel(), a.getDirection(),
                                     a.getRow(), a.getCol()) <
                     std::make_tuple(b.getChannel(), b.getDirection(),
                                     b.getRow(), b.getCol());
            });
  // Batch DMA operations with the same row, channel, and direction into a
  // single TCT sync operation, as long as they have consecutive columns.
  llvm::MapVector<AMDAIE::NpuPushToQueueOp, uint32_t> columnBatches;
  for (auto pushToQueueOp : pushToQueueOps) {
    if (!columnBatches.empty()) {
      auto &[lastPushOp, lastColNum] = columnBatches.back();
      if (lastPushOp.getRow() == pushToQueueOp.getRow() &&
          lastPushOp.getCol() + lastColNum == pushToQueueOp.getCol() &&
          lastPushOp.getDirection() == pushToQueueOp.getDirection() &&
          lastPushOp.getChannel() == pushToQueueOp.getChannel()) {
        ++lastColNum;
        continue;
      }
    }
    columnBatches.insert({pushToQueueOp, 1});
  }
  // Convert to TCT sync ops.
  for (auto &[pushToQueueOp, colNum] : columnBatches) {
    if (failed(builder.appendTCTSync(
            pushToQueueOp.getCol(), pushToQueueOp.getRow(),
            static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, colNum,
            pushToQueueOp.getChannel()))) {
      return failure();
    }
@@ -16,13 +16,49 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

using DmaBdIdKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;
using DmaBdIdPair = std::pair<DmaBdIdKey, uint32_t>;

/// Utility function to retrieve TileOp, ConnectionOp, and BD ID from a given
/// half DMA copy operation.
FailureOr<DmaBdIdPair> retrieveDmaBdIdPair(
    AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp) {
  // Retrieve the connection op.
  std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
      npuHalfDmaCpyNdOp.getConnectionOp();
  if (!maybeConnectionOp) {
    return npuHalfDmaCpyNdOp.emitOpError()
           << "expected to operate on an `amdaie.connection`";
  }
  AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();

  // Retrieve the BD ID op.
  std::optional<AMDAIE::BdIdOp> maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp();
  if (!maybeBdIdOp) {
    return npuHalfDmaCpyNdOp.emitOpError()
           << "must have a BD ID op to lower to `amdaie.npu.write_bd`";
  }
  AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
  uint32_t currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue());

  // Retrieve the tile op.
  AMDAIE::TileOp tileOp =
      dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
  if (!tileOp) {
    return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
  }

  DmaBdIdKey currBdIdKey = {tileOp, connectionOp};
  return DmaBdIdPair{currBdIdKey, currBdIdVal};
}

/// Utility function to determine whether a DMA wait op can be folded based on
/// its half DMA copy operation.
FailureOr<bool> canFoldByQueue(
    const AMDAIE::AMDAIEDeviceModel &deviceModel,
    AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
    DenseMap<DmaBdIdKey, SmallVector<uint32_t>> &tileConnectToBdIdQueue) {
  // Retrieve the connection op.
  std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
      npuHalfDmaCpyNdOp.getConnectionOp();
@@ -101,13 +137,11 @@ FailureOr<bool> canFoldBasedOnHalfDmaCpy(
/// Reverse traversal simplifies handling duplicate BD IDs, preventing
/// the need to revisit and modify earlier operations after processing later
/// ones.
LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
                                  AMDAIE::ControlCodeOp controlCodeOp) {
  IRRewriter rewriter(controlCodeOp->getContext());
  std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
  DenseMap<DmaBdIdKey, SmallVector<uint32_t>> tileConnectToBdIdQueue;
  // Traverse the control code in reverse.
  WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
      [&](AMDAIE::NpuDmaWaitOp waitOp) {
@@ -116,7 +150,7 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
        if (auto npuHalfDmaCpyNdOp =
                dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
                    token.getDefiningOp())) {
          FailureOr<bool> result = canFoldByQueue(
              deviceModel, npuHalfDmaCpyNdOp, tileConnectToBdIdQueue);
          if (failed(result)) return WalkResult::interrupt();
          toErase &= *result;
@@ -152,6 +186,162 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
  return success();
}

/// For each batch, combine the async tokens into a single NpuDmaWaitOp.
LogicalResult eraseBatchOperations(IRRewriter &rewriter,
                                   SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps) {
  // Skip if there are less than two DMA wait operations.
  if (waitOps.size() < 2) return success();

  SmallVector<Value> asyncTokens;
  Operation *parentOp = waitOps[0]->getParentOp();
  for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) {
    if (waitOp->getParentOp() != parentOp) {
      return waitOp.emitError(
          "DMA operations to be batched must belong to the same scope");
    }
    asyncTokens.append(waitOp.getAsyncTokens().begin(),
                       waitOp.getAsyncTokens().end());
  }

  rewriter.setInsertionPointAfter(waitOps.back());
  rewriter.create<AMDAIE::NpuDmaWaitOp>(waitOps.back().getLoc(), asyncTokens);
  for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) rewriter.eraseOp(waitOp);
  return success();
}

/// Utility function to determine if a DMA wait operation can be folded into
/// a batch based on its half DMA copy operation.
/// Can't fold the wait op if:
/// (1) the current operation is not in the same scope as the batch, or
/// (2) the current connection op already occurs in the batch, or
/// (3) the batch is empty, or
/// (4) the current operation is a packet flow, or
/// (5) the current BD ID on the same tile already occurs in the batch.
FailureOr<bool> canFoldByBatch(
    const Operation *batchParentOp,
    const DenseSet<AMDAIE::ConnectionOp> &connectionOps,
    const DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
    AMDAIE::NpuHalfDmaCpyNdOp currHalfDmaCpyNdOp, DmaBdIdPair currBdIdPair) {
  // Not in the same scope? Can't fold.
  if (currHalfDmaCpyNdOp->getParentOp() != batchParentOp) return false;

  // Connection op already in the batch, or an empty batch? Can't fold.
  AMDAIE::ConnectionOp connectionOp = currBdIdPair.first.second;
  if (connectionOps.contains(connectionOp) || connectionOps.empty())
    return false;

  // Packet flow? Can't fold.
  std::optional<AMDAIE::FlowOp> maybeFlowOp = connectionOp.getFlowOp();
  if (!maybeFlowOp) {
    return connectionOp.emitOpError()
           << "expected to operate on an `amdaie.flow`";
  }
  AMDAIE::FlowOp flowOp = maybeFlowOp.value();
  if (flowOp.getIsPacketFlow()) return false;

  // Duplicate BD ID on the same tile? Can't fold.
  AMDAIE::TileOp tileOp = currBdIdPair.first.first;
  uint32_t currBdIdVal = currBdIdPair.second;
  bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
    return entry.first.first == tileOp && entry.second.contains(currBdIdVal);
  });
  if (isDuplicateBdId) return false;

  // Can fold.
  return true;
}

/// Traverses the control code in reverse, ensuring that only one DMA wait op is
/// retained for every batch of DMA copy operations.
///
/// Example Input:
/// %0 = dma_cpy_nd(connection0)
/// dma_wait(%0)
/// %1 = dma_cpy_nd(connection1)
/// %2 = dma_cpy_nd(connection2)
/// %3 = dma_cpy_nd(connection3)
/// dma_wait(%1)
/// dma_wait(%2)
/// dma_wait(%3)
/// Example Output:
/// %0 = dma_cpy_nd(connection0)
/// %1 = dma_cpy_nd(connection1)
/// %2 = dma_cpy_nd(connection2)
/// %3 = dma_cpy_nd(connection3)
/// dma_wait(%0, %1, %2, %3)
/// Reverse traversal simplifies handling duplicate connections, preventing
/// the need to revisit and modify earlier operations after processing later
/// ones.
LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
  IRRewriter rewriter(controlCodeOp->getContext());
  SmallVector<AMDAIE::NpuDmaWaitOp> waitOps;
  DenseSet<AMDAIE::ConnectionOp> connectionOps;
  DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;

  auto updateWithCurrBdId =
      [&](bool canFold, DmaBdIdPair currBdIdPair,
          DenseSet<AMDAIE::ConnectionOp> &connectionOps,
          DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap) {
        DmaBdIdKey currBdIdKey = currBdIdPair.first;
        uint32_t currBdIdVal = currBdIdPair.second;
        if (!canFold) {
          // Clear the BD IDs for all the connections in the batch.
          for (auto &entry : dmaBdIdsMap) {
            ConnectionOp connectionOp = entry.first.second;
            DenseSet<uint32_t> &bdIds = entry.second;
            if (connectionOps.contains(connectionOp)) bdIds.clear();
          }
          connectionOps.clear();
        }
        connectionOps.insert(currBdIdKey.second);
        dmaBdIdsMap[currBdIdKey].insert(currBdIdVal);
      };

  // Traverse the control code in reverse.
  WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
      [&](AMDAIE::NpuDmaWaitOp waitOp) {
        bool toBatch = true;
        Operation *batchParentOp =
            waitOps.empty() ? waitOp->getParentOp() : waitOps[0]->getParentOp();
        for (Value token : waitOp.getAsyncTokens()) {
          if (auto npuHalfDmaCpyNdOp =
                  dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
                      token.getDefiningOp())) {
            // Retrieve the TileOp, ConnectionOp, and BD ID.
            FailureOr<DmaBdIdPair> currBdIdPair =
                retrieveDmaBdIdPair(npuHalfDmaCpyNdOp);
            if (failed(currBdIdPair)) return WalkResult::interrupt();
            // Check if the current DMA wait op can be folded into the batch.
            FailureOr<bool> canFold =
                canFoldByBatch(batchParentOp, connectionOps, dmaBdIdsMap,
                               npuHalfDmaCpyNdOp, *currBdIdPair);
            if (failed(canFold)) return WalkResult::interrupt();
            // Update the `connectionOps` and `dmaBdIdsMap`.
            updateWithCurrBdId(*canFold, *currBdIdPair, connectionOps,
                               dmaBdIdsMap);
            toBatch &= *canFold;
          }
        }
        // Process the previous batch of wait ops, and start a new batch.
        if (!toBatch) {
          // Since the control code is traversed in reverse order, we need to
          // restore the original order of the DMA operations.
          std::reverse(waitOps.begin(), waitOps.end());
          if (failed(eraseBatchOperations(rewriter, waitOps)))
            return WalkResult::interrupt();
          waitOps.clear();
        }
        waitOps.push_back(waitOp);
        return WalkResult::advance();
      });

  if (res.wasInterrupted()) return failure();
  // Process the remaining wait ops.
  std::reverse(waitOps.begin(), waitOps.end());
  if (failed(eraseBatchOperations(rewriter, waitOps))) return failure();
  return success();
}

class AMDAIEFoldDmaWaitsPass
    : public impl::AMDAIEFoldDmaWaitsBase<AMDAIEFoldDmaWaitsPass> {
 public:
@@ -181,7 +371,10 @@ void AMDAIEFoldDmaWaitsPass::runOnOperation() {

  WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
    AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
    if (failed(foldDmaWaitsByQueue(deviceModel, controlCodeOp))) {
      return WalkResult::interrupt();
    }
    if (failed(foldDmaWaitsByBatch(controlCodeOp))) {
      return WalkResult::interrupt();
    }
    return WalkResult::advance();