Skip to content

Commit

Permalink
retrive current BD ID key value in a separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Dec 18, 2024
1 parent 62cb077 commit 672e29a
Showing 1 changed file with 71 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,39 @@ namespace mlir::iree_compiler::AMDAIE {
namespace {

using DmaBdIdKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;
using DmaBdIdPair = std::pair<DmaBdIdKey, uint32_t>;

FailureOr<DmaBdIdPair> retriveDmaBdIdPair(
AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
return npuHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();

// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
return npuHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
uint32_t currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue());

// Retrieve the tile op.
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
}

DmaBdIdKey currBdIdKey = {tileOp, connectionOp};
return DmaBdIdPair{currBdIdKey, currBdIdVal};
}

/// Utility function to determine whether a DMA wait op can be folded based on
/// its half DMA copy operation.
Expand Down Expand Up @@ -176,65 +209,44 @@ LogicalResult eraseBatchOperations(IRRewriter &rewriter,

/// Utility function to determine if a DMA wait operation can be folded into a
/// a batch based on its half DMA copy operation.
/// Can't fold wait op if:
/// (1) the current operation is not in the same scope as the batch, or
/// (2) the current connection op already occurs in the batch, or
/// (3) the batch is empty, or
/// (4) the current operation is a packet flow, or
/// (5) the current BD ID on the same tile already occurs in the batch.
FailureOr<bool> canFoldByBatch(
const Operation *batchParentOp,
const DenseSet<AMDAIE::ConnectionOp> &connectionOps,
const DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
DmaBdIdKey &currBdIdKey, uint32_t &currBdIdVal,
AMDAIE::NpuHalfDmaCpyNdOp currHalfDmaCpyNdOp) {
// Check if the current operation is in the same scope as the rest of the
// batch.
bool isSameScope = currHalfDmaCpyNdOp->getParentOp() == batchParentOp;
AMDAIE::NpuHalfDmaCpyNdOp currHalfDmaCpyNdOp, DmaBdIdPair &currBdIdPair) {
// Not in the same scope? Can't fold.
if (currHalfDmaCpyNdOp->getParentOp() != batchParentOp) return false;

// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
currHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
return currHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();
bool isDuplicateConnection = connectionOps.contains(connectionOp);
// Connection op already in the batch, or an empty batch? Can't fold.
AMDAIE::ConnectionOp connectionOp = currBdIdPair.first.second;
if (connectionOps.contains(connectionOp) || connectionOps.empty())
return false;

// Retrieve the flow op.
// Packet flow? Can't fold.
std::optional<AMDAIE::FlowOp> maybeFlowOp = connectionOp.getFlowOp();
if (!maybeFlowOp) {
return connectionOp.emitOpError()
<< "expected to operate on an `amdaie.flow`";
}
AMDAIE::FlowOp flowOp = maybeFlowOp.value();
bool isPacketFlow = flowOp.getIsPacketFlow();

// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = currHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
return currHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue());

// Retrieve the tile op.
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
}
currBdIdKey = {tileOp, connectionOp};
if (flowOp.getIsPacketFlow()) return false;

// Duplicate BD ID on the same tile? Can't fold.
AMDAIE::TileOp tileOp = currBdIdPair.first.first;
uint32_t currBdIdVal = currBdIdPair.second;
bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
return entry.first.first == tileOp && entry.second.contains(currBdIdVal);
});
if (isDuplicateBdId) return false;

// Can't fold wait op if:
// (1) the current connection op already occurs in the batch, or
// (2) the current BD ID on the same tile already occurs in the batch, or
// (3) the current operation is a packet flow, or
// (4) the batch is empty, or
// (5) the current operation is not in the same scope as the batch.
return !(isDuplicateConnection || isDuplicateBdId || isPacketFlow ||
connectionOps.empty() || !isSameScope);
// Can fold.
return true;
}

/// Traverses the control code in reverse, ensuring that only one DMA wait op is
Expand Down Expand Up @@ -265,11 +277,11 @@ LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;

auto updateWithCurrBdId =
[&](bool canFold, DenseSet<AMDAIE::ConnectionOp> &connectionOps,
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
DmaBdIdKey &currBdIdKey, uint32_t currBdIdVal) {
assert(currBdIdKey.first && "TileOp must not be null");
assert(currBdIdKey.second && "ConnectionOp must not be null");
[&](bool canFold, DmaBdIdPair &currBdIdPair,
DenseSet<AMDAIE::ConnectionOp> &connectionOps,
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap) {
DmaBdIdKey currBdIdKey = currBdIdPair.first;
uint32_t currBdIdVal = currBdIdPair.second;
if (!canFold) {
// Clear the BD IDs for all the connections in the batch.
for (auto &entry : dmaBdIdsMap) {
Expand All @@ -293,15 +305,19 @@ LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
DmaBdIdKey currBdIdKey = {nullptr, nullptr};
uint32_t currBdIdVal = 0;
FailureOr<bool> result =
// Retrieve the TileOp, ConnectionOp, and BD ID.
FailureOr<DmaBdIdPair> currBdIdPair =
retriveDmaBdIdPair(npuHalfDmaCpyNdOp);
if (failed(currBdIdPair)) return WalkResult::interrupt();
// Check if the current DMA wait op can be folded into the batch.
FailureOr<bool> canFold =
canFoldByBatch(batchParentOp, connectionOps, dmaBdIdsMap,
currBdIdKey, currBdIdVal, npuHalfDmaCpyNdOp);
if (failed(result)) return WalkResult::interrupt();
toBatch &= *result;
updateWithCurrBdId(*result, connectionOps, dmaBdIdsMap, currBdIdKey,
currBdIdVal);
npuHalfDmaCpyNdOp, *currBdIdPair);
if (failed(canFold)) return WalkResult::interrupt();
// Update the `connectionOps` and `dmaBdIdsMap`,
updateWithCurrBdId(*canFold, *currBdIdPair, connectionOps,
dmaBdIdsMap);
toBatch &= *canFold;
}
}
// Process the previous batch of wait ops, and start a new batch.
Expand Down

0 comments on commit 672e29a

Please sign in to comment.