Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Nov 29, 2024
1 parent 770fd34 commit 3104c84
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 51 deletions.
10 changes: 9 additions & 1 deletion compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,14 @@ def AMDAIE_NpuHalfDmaCpyNdOp
ShapedType::kDynamic encodes that the corresponding entry has a dynamic
value.

It also supports the representation of DMA BD chaining using the `use_next_bd`,
`next_bd`, and `start_bd` operands. The `use_next_bd` operand indicates
whether another DMA operation is chained to follow this one.
If `use_next_bd` is `true`, the `next_bd` operand specifies the BD ID of
the next DMA operation in the chain. Within a chain, the `start_bd` operand
identifies the BD ID of the first DMA operation in the sequence.
When `use_next_bd` is `false`, the `start_bd` is set to the same value as `bd_id`.

Example:

```mlir
Expand All @@ -604,7 +612,7 @@ def AMDAIE_NpuHalfDmaCpyNdOp
%5 = amdaie.logicalobjectfifo.from_memref %0, {%tile_0_0}
: memref<32x1024xi32> -> !amdaie.logicalobjectfifo<memref<32768xi32>>
%4 = amdaie.npu.half_dma_cpy_nd async %2(%0[0, 0] [32, 64] [1024, 1]
bd_id = %bd_id channel = %channel)
bd_id = %bd_id channel = %channel use_next_bd = false start_bd = %bd_id)
...
}
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,20 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
AMDAIE::WorkgroupOp workgroupOp) {
IRRewriter rewriter(workgroupOp->getContext());

// TODO(Zhewen): assign BD IDs here, to get rid of tileArgIdxToAssignedBdIdOp
// BD ID currently assigned to each DMA operation, used to track the lifetime
// TODO(Zhewen): to get rid of tileArgIdxToAssignedBdIdOps and
// tileArgIdxToDmaCount, integrate BD ID assignment and (partial) control code
// loop unrolling into this pass.

// BD ID that are currenly assigned to DMA operations
DenseMap<std::pair<AMDAIE::TileOp, uint32_t>, SmallVector<AMDAIE::BdIdOp>>
tileArgIdxToAssignedBdIdOps;
// TODO(Zhewen): unroll loops here, to get rid of tileArgIdxToDmaCount
// Counter for the number of DMA operations, helping determine the dependency
DenseMap<std::pair<AMDAIE::TileOp, uint32_t>, uint32_t> tileArgIdxToDmaCount;

// Last DMA operation encountered for each tile argument index pair
// no matter if it is chained or not
// Last DMA operation encountered, no matter if it is chained or not
DenseMap<std::pair<AMDAIE::TileOp, uint32_t>, AMDAIE::NpuHalfDmaCpyNdOp>
tileArgIdxToLastDmaOp;
// Last DMA operation that has been chained for each tile argument index pair
// Last DMA operation that has been chained
DenseMap<std::pair<AMDAIE::TileOp, uint32_t>, AMDAIE::NpuHalfDmaCpyNdOp>
tileArgIdxToLastChainedDmaOp;
// Black list of tile argument index pairs that should not be chained
Expand All @@ -41,7 +43,8 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
WalkResult res = controlCodeOp->walk([&](Operation *op) {
if (auto npuHalfDmaCpyNdOp = dyn_cast<AMDAIE::NpuHalfDmaCpyNdOp>(op)) {
// not shim, do not chain BDs
// not shim, no need to chain, since it will be earsed when lowering to
// NPU instructions
if (npuHalfDmaCpyNdOp.getMemorySpaceAsUInt() != 0) {
return WalkResult::advance();
}
Expand Down Expand Up @@ -118,6 +121,10 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
return WalkResult::interrupt();
}
uint32_t argIdx = subspanOp.getBinding().getZExtValue();

// If the current DMA operation was previously part of the outer loop in
// the control code, force all DMA operations in the inner loop to be
// synchronized, by adding them to the black list.
tileArgIdxToDmaCount[{tileOp, argIdx}]++;
for (auto &[pair, count] : tileArgIdxToDmaCount) {
if (pair.first == tileOp &&
Expand All @@ -127,8 +134,9 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
}
}
}
// if the BD ID is currently used by another DMA op, stop the chain
// for that DMA op from further growing, so that BD ID can be released

// If the BD ID is currently used by another DMA op, stop the chain
// for that DMA op from further growing, by adding it to the black list
for (auto &[pair, bdIdOps] : tileArgIdxToAssignedBdIdOps) {
if (pair.first == tileOp && llvm::is_contained(bdIdOps, bdIdOp)) {
if (!llvm::is_contained(tileArgIdxsBlackList, pair)) {
Expand All @@ -138,12 +146,26 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
}
}

// if not blacklisted and there is a previous DMA op, chain the BD IDs
// If the black list is not empty, there will be a synchronization.
// Make sure all other DMA chains also break at this point to avoid
// dependency issues.
if (tileArgIdxsBlackList.size() > 0) {
for (auto &[pair, bdIdOps] : tileArgIdxToAssignedBdIdOps) {
if (pair.first == tileOp && bdIdOps.size() > 1) {
if (!llvm::is_contained(tileArgIdxsBlackList, pair)) {
tileArgIdxsBlackList.push_back(pair);
}
}
}
}

// When current DMA has not been blacklisted and a previous DMA with same
// argIdx exists, chain them together
chaining &= !llvm::is_contained(tileArgIdxsBlackList,
std::make_pair(tileOp, argIdx)) &&
tileArgIdxToLastDmaOp.contains({tileOp, argIdx});
if (chaining) {
// update previous NpuHalfDmaCpyNdOp by changing its useNextBd and
// update the previous DMA op by changing its useNextBd and
// nextBd
AMDAIE::NpuHalfDmaCpyNdOp lastDmaOp =
tileArgIdxToLastDmaOp[{tileOp, argIdx}];
Expand All @@ -156,7 +178,7 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
lastDmaOp.getChannel(), true, bdIdOp, lastDmaOp.getStartBd());
rewriter.replaceOp(lastDmaOp, chainedDmaOp.getResults());
tileArgIdxToLastChainedDmaOp[{tileOp, argIdx}] = chainedDmaOp;
// update current NpuHalfDmaCpyNdOp by changing its startBd
// update the current DMA op by changing its startBd
rewriter.setInsertionPoint(npuHalfDmaCpyNdOp);
auto npuHalfDmaCpyNdOpNew = rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
npuHalfDmaCpyNdOp.getLoc(), npuHalfDmaCpyNdOp.getResultTypes(),
Expand All @@ -171,25 +193,17 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
npuHalfDmaCpyNdOp = npuHalfDmaCpyNdOpNew;
}

// update the BD ID assignment
// Update BD ID assignment, if it is chaining, safely release the BD IDs
// since a synchronization will happen
if (chaining && tileArgIdxToAssignedBdIdOps.contains({tileOp, argIdx})) {
tileArgIdxToAssignedBdIdOps[{tileOp, argIdx}].push_back(bdIdOp);
} else {
tileArgIdxToAssignedBdIdOps[{tileOp, argIdx}] = {bdIdOp};
}

// not chaining, update the black list

if (tileArgIdxsBlackList.size() > 0) {
for (auto &[pair, bdIdOps] : tileArgIdxToAssignedBdIdOps) {
if (pair.first == tileOp && bdIdOps.size() > 1) {
if (!llvm::is_contained(tileArgIdxsBlackList, pair)) {
tileArgIdxsBlackList.push_back(pair);
}
}
}
}

// The current DMA op is not chained with the previous DMA op (i.e.
// synchroizaiton will happen between these two ops), removing from the
// black list
if (!chaining) {
auto it =
std::find(tileArgIdxsBlackList.begin(), tileArgIdxsBlackList.end(),
Expand All @@ -198,19 +212,24 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
tileArgIdxsBlackList.erase(it);
}
}
// update the last DMA op
// Update the last encountered DMA op
tileArgIdxToLastDmaOp[{tileOp, argIdx}] = npuHalfDmaCpyNdOp;

} else if (auto npuDmaWaitOp = dyn_cast<AMDAIE::NpuDmaWaitOp>(op)) {
// Handle the special case where the blacklist is not empty. This could
// happen when there are multiple DMA operations associated with the same
// tile but different argIdx before a DMA wait. In such cases, one DMA
// operation might initially get chained, but a subsequent DMA operation
// may later report that the chain must be broken to release the BD IDs.
// Handle the special case where there are multiple DMA ops preceding any
// Wait op. In such a case, some DMA ops may be chained first, before they
// are put onto the black list. Therefore, go over the black list and
// unchain the DMA ops when required.

for (auto &[tileOp, argIdx] : tileArgIdxsBlackList) {
if (tileArgIdxToLastChainedDmaOp.contains({tileOp, argIdx})) {
if (tileArgIdxToLastChainedDmaOp.contains({tileOp, argIdx}) &&
tileArgIdxToLastDmaOp.contains({tileOp, argIdx})) {
// break the chain lastChainedDmaOp -> lastDmaOp
AMDAIE::NpuHalfDmaCpyNdOp lastChainedDmaOp =
tileArgIdxToLastChainedDmaOp[{tileOp, argIdx}];
AMDAIE::NpuHalfDmaCpyNdOp lastDmaOp =
tileArgIdxToLastDmaOp[{tileOp, argIdx}];
// revert useNextBd and nextBd in lastChainedDmaOp
bool useNextBd{false};
Value nextBd{nullptr};
rewriter.setInsertionPointAfter(lastChainedDmaOp);
Expand All @@ -224,23 +243,20 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
lastChainedDmaOp.getStartBd());
rewriter.replaceOp(lastChainedDmaOp, unchainedDmaOp.getResults());
tileArgIdxToLastChainedDmaOp.erase({tileOp, argIdx});

if (tileArgIdxToLastDmaOp.contains({tileOp, argIdx})) {
AMDAIE::NpuHalfDmaCpyNdOp lastDmaOp =
tileArgIdxToLastDmaOp[{tileOp, argIdx}];
rewriter.setInsertionPoint(lastDmaOp);
auto lastDmaOpNew = rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
lastDmaOp.getLoc(), lastDmaOp.getResultTypes(),
lastDmaOp.getConnection(), lastDmaOp.getInput(),
lastDmaOp.getMixedOffsets(), lastDmaOp.getMixedSizes(),
lastDmaOp.getMixedStrides(), lastDmaOp.getBdId(),
lastDmaOp.getChannel(), lastDmaOp.getUseNextBd(),
lastDmaOp.getNextBd(), lastDmaOp.getBdId());
tileArgIdxToAssignedBdIdOps[{tileOp, argIdx}] = {
lastDmaOp.getBdIdOp().value()};
rewriter.replaceOp(lastDmaOp, lastDmaOpNew.getResults());
tileArgIdxToLastDmaOp[{tileOp, argIdx}] = lastDmaOpNew;
}
// revert startBd in lastDmaOp
auto startBd = lastDmaOp.getBdId();
rewriter.setInsertionPoint(lastDmaOp);
unchainedDmaOp = rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
lastDmaOp.getLoc(), lastDmaOp.getResultTypes(),
lastDmaOp.getConnection(), lastDmaOp.getInput(),
lastDmaOp.getMixedOffsets(), lastDmaOp.getMixedSizes(),
lastDmaOp.getMixedStrides(), lastDmaOp.getBdId(),
lastDmaOp.getChannel(), lastDmaOp.getUseNextBd(),
lastDmaOp.getNextBd(), startBd);
tileArgIdxToAssignedBdIdOps[{tileOp, argIdx}] = {
lastDmaOp.getBdIdOp().value()};
rewriter.replaceOp(lastDmaOp, unchainedDmaOp.getResults());
tileArgIdxToLastDmaOp[{tileOp, argIdx}] = unchainedDmaOp;
} else {
npuDmaWaitOp.emitError() << "unhandled situation in DMA BD chaining, "
"please try to disable this pass";
Expand All @@ -253,7 +269,7 @@ LogicalResult dmaBdChain(AMDAIE::AMDAIEDeviceModel deviceModel,
return WalkResult::advance();
});

// erase wait op unless it is at the end of a chain
// Only keep DMA Wait Ops if at the end of a chain, erase others
res = controlCodeOp->walk([&](Operation *op) {
if (auto npuDmaWaitOp = dyn_cast<AMDAIE::NpuDmaWaitOp>(op)) {
bool toErase = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,24 @@ LogicalResult moveNpuDmaSyncUsersAfterAncestorInSameBlock(
return success();
}

// Move NPU DMA wait operations with async_source tokens as late as possible
// (after the target DMA wait operation which has async_target token) This is to
// help later optimizations such as DMA BD chaining. Example:
//
// %0 = dma_cpy_nd async_source
// dma_wait(%0 : !amdaie.async_source_token)
// %1 = dma_cpy_nd async_source
// dma_wait(%1 : !amdaie.async_source_token)
// %2 = dma_cpy_nd async_target
// dma_wait(%2 : !amdaie.async_target_token)
// ------------------------------->>>>>>>>>>
// %0 = dma_cpy_nd async_source
// %1 = dma_cpy_nd async_source
// %2 = dma_cpy_nd async_target
// dma_wait(%2 : !amdaie.async_target_token)
// dma_wait(%0 : !amdaie.async_source_token)
// dma_wait(%1 : !amdaie.async_source_token)

LogicalResult moveNpuSourceDmaSyncAfterTargetDmaCpy(RewriterBase &rewriter,
Operation *parentOp) {
// Stores NPU source DMA wait operations to be moved later.
Expand Down

0 comments on commit 3104c84

Please sign in to comment.