diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFoldDmaWaits.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFoldDmaWaits.cpp index 21b0af3f5..973d30449 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFoldDmaWaits.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFoldDmaWaits.cpp @@ -18,63 +18,17 @@ namespace { using DmaBdIdKey = std::pair; -/// Utility function to erase the DMA wait operations in the queue, except for -/// the last one. -LogicalResult eraseQueueOperations(IRRewriter &rewriter, - SmallVector &waitOps) { - // Skip if there are less than two DMA wait operations in the queue. - if (waitOps.size() < 2) return success(); - - Operation *parentOp = waitOps.back()->getParentOp(); - // Do not modify the last wait op, it will be kept. - waitOps.pop_back(); - - for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) { - if (waitOp->getParentOp() != parentOp) { - return waitOp.emitError( - "DMA operations to be queued must belong to the same scope"); - } - // Erase the wait op. - SmallVector asyncTokens(waitOp.getAsyncTokens()); - rewriter.eraseOp(waitOp); - for (Value token : asyncTokens) { - auto dmaOp = - dyn_cast_if_present(token.getDefiningOp()); - if (!dmaOp) - waitOp.emitError("expected to operate on an `amdaie.half_dma_cpy_nd`"); - if (dmaOp.use_empty()) { - rewriter.setInsertionPoint(dmaOp); - TypeRange resultTypeRange = TypeRange{}; - // Nullify the result to avoid issuing a token. - rewriter.create( - dmaOp.getLoc(), resultTypeRange, dmaOp.getConnection(), - dmaOp.getInput(), dmaOp.getMixedOffsets(), dmaOp.getMixedSizes(), - dmaOp.getMixedStrides(), dmaOp.getBdId(), dmaOp.getChannel(), - dmaOp.getNextBd(), dmaOp.getStartBd()); - rewriter.eraseOp(dmaOp); - } - } - } - return success(); -} - -/// Utility function to determine whether a DMA wait op can be folded into a -/// queue based on its half DMA copy operation. +/// Utility function to determine whether a DMA wait op can be folded based on +/// its half DMA copy operation. FailureOr canFoldByQueue( const AMDAIE::AMDAIEDeviceModel &deviceModel, - const Operation *queueParentOp, - const DenseMap> &dmaBdIdsMap, - DmaBdIdKey &currBdIdKey, uint32_t &currBdIdVal, - AMDAIE::NpuHalfDmaCpyNdOp &currHalfDmaCpyNdOp) { - // Check if the current operation is in the same scope as the rest of the - // queue. - bool isSameScope = currHalfDmaCpyNdOp->getParentOp() == queueParentOp; - + AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp, + DenseMap> &tileConnectToBdIdQueue) { // Retrieve the connection op. std::optional maybeConnectionOp = - currHalfDmaCpyNdOp.getConnectionOp(); + npuHalfDmaCpyNdOp.getConnectionOp(); if (!maybeConnectionOp) { - return currHalfDmaCpyNdOp.emitOpError() + return npuHalfDmaCpyNdOp.emitOpError() << "expected to operate on an `amdaie.connection`"; } AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value(); @@ -82,21 +36,20 @@ FailureOr canFoldByQueue( // Retrieve the flow op. std::optional maybeFlowOp = connectionOp.getFlowOp(); if (!maybeFlowOp) { - return connectionOp.emitOpError() + return connectionOp->emitOpError() << "expected to operate on an `amdaie.flow`"; } AMDAIE::FlowOp flowOp = maybeFlowOp.value(); bool isPacketFlow = flowOp.getIsPacketFlow(); // Retrieve the BD ID op. - std::optional maybeBdIdOp = currHalfDmaCpyNdOp.getBdIdOp(); + std::optional maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp(); if (!maybeBdIdOp) { - return currHalfDmaCpyNdOp.emitOpError() + return npuHalfDmaCpyNdOp.emitOpError() << "must have a BD ID op to lower to " "`amdaie.npu.write_bd`"; } AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value(); - currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue()); // Retrieve the tile op. AMDAIE::TileOp tileOp = @@ -104,39 +57,44 @@ FailureOr canFoldByQueue( if (!tileOp) { return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`"; } - currBdIdKey = {tileOp, connectionOp}; // Get the maximum queue size. uint32_t col = getConstantIndexOrAssert(tileOp.getCol()); uint32_t row = getConstantIndexOrAssert(tileOp.getRow()); uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row); - bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) { - return entry.first.first == tileOp && entry.second.contains(currBdIdVal); - }); - const DenseSet &bdIds = dmaBdIdsMap.lookup(currBdIdKey); - - // Can't fold wait op if: - // (1) the current BD ID on the same tile already occurs in the queue, or - // (2) the current operation is a packet flow, or - // (3) reaches the maximum queue size, or - // (4) the queue is empty, or - // (5) the current operation is not in the same scope as the queue. - return !(isDuplicateBdId || isPacketFlow || bdIds.size() >= maxQueueSize || - bdIds.empty() || !isSameScope); + // Keep wait op if, either reaches the maximum queue size, or a + // duplicate BD ID in the same tile, or packet flow, or the queue is + // empty + uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue()); + bool isDuplicateBdId = + llvm::any_of(tileConnectToBdIdQueue, [&](const auto &entry) { + return entry.first.first == tileOp && + llvm::is_contained(entry.second, bdId); + }); + SmallVector &bdIdQueue = + tileConnectToBdIdQueue[{tileOp, connectionOp}]; + bool canFold = true; + if (isDuplicateBdId || isPacketFlow || bdIdQueue.size() >= maxQueueSize || + bdIdQueue.empty()) { + bdIdQueue.clear(); + canFold = false; + } + bdIdQueue.push_back(bdId); + return canFold; } /// Traverses the control code in reverse, ensuring that for each connection, /// only one DMA wait op is retained for every maximum queue size. /// /// Example Output: assuming a maximum queue size of 4. -/// dma_cpy_nd(connection=0, bd_id=0) -/// %0 = dma_cpy_nd(connection=0, bd_id=1) +/// dma_cpy_nd +/// %0 = dma_cpy_nd /// dma_wait(%0) -/// dma_cpy_nd(connection=0, bd_id=2) -/// dma_cpy_nd(connection=0, bd_id=3) -/// dma_cpy_nd(connection=0, bd_id=4) -/// %1 = dma_cpy_nd(connection=0, bd_id=5) +/// dma_cpy_nd +/// dma_cpy_nd +/// dma_cpy_nd +/// %1 = dma_cpy_nd /// dma_wait(%1) /// From the bottom up, for every four DMA copy operations, only one DMA wait /// operation is retained. @@ -147,57 +105,49 @@ FailureOr canFoldByQueue( LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel, AMDAIE::ControlCodeOp controlCodeOp) { IRRewriter rewriter(controlCodeOp->getContext()); - SmallVector> waitOpQueues; - DenseMap> dmaBdIdsMap; - - auto updateWithCurrBdId = - [&](bool canFold, DenseMap> &dmaBdIdsMap, - DmaBdIdKey &currBdIdKey, uint32_t currBdIdVal) { - assert(currBdIdKey.first && "TileOp must not be null"); - assert(currBdIdKey.second && "ConnectionOp must not be null"); - if (!canFold) dmaBdIdsMap[currBdIdKey].clear(); - dmaBdIdsMap[currBdIdKey].insert(currBdIdVal); - }; - + std::vector waitOpsToErase; + DenseMap> tileConnectToBdIdQueue; // Traverse the control code in reverse. WalkResult res = controlCodeOp->walk( [&](AMDAIE::NpuDmaWaitOp waitOp) { - bool toFold = true; - Operation *queueParentOp = - waitOpQueues.empty() ? waitOp->getParentOp() - : waitOpQueues.back().front()->getParentOp(); + bool toErase = true; for (Value token : waitOp.getAsyncTokens()) { if (auto npuHalfDmaCpyNdOp = dyn_cast_if_present( token.getDefiningOp())) { - DmaBdIdKey currBdIdKey = {nullptr, nullptr}; - uint32_t currBdIdVal = 0; - FailureOr result = - canFoldByQueue(deviceModel, queueParentOp, dmaBdIdsMap, - currBdIdKey, currBdIdVal, npuHalfDmaCpyNdOp); + FailureOr result = canFoldByQueue( + deviceModel, npuHalfDmaCpyNdOp, tileConnectToBdIdQueue); if (failed(result)) return WalkResult::interrupt(); - toFold &= *result; - updateWithCurrBdId(*result, dmaBdIdsMap, currBdIdKey, currBdIdVal); + toErase &= *result; } } - // Store all the queues, and modify later to avoid invalidating the - // iterator. - if (toFold) { - // Append the wait op to the last queue if it can be folded. - waitOpQueues.back().push_back(waitOp); - } else { - // Create a new queue if the wait op cannot be folded. - waitOpQueues.push_back(SmallVector{waitOp}); - } + // Erase later to avoid invalidating the iterator. + if (toErase) waitOpsToErase.push_back(waitOp); return WalkResult::advance(); }); if (res.wasInterrupted()) return failure(); - for (SmallVector &waitOps : waitOpQueues) { - // Since the controlcode is traversed in reverse order, we need to - // restore the original order of the DMA operations. - std::reverse(waitOps.begin(), waitOps.end()); - if (failed(eraseQueueOperations(rewriter, waitOps))) return failure(); + + for (AMDAIE::NpuDmaWaitOp waitOp : waitOpsToErase) { + SmallVector asyncTokens(waitOp.getAsyncTokens()); + // Erase the wait op. + rewriter.eraseOp(waitOp); + for (Value token : asyncTokens) { + if (auto op = dyn_cast_if_present( + token.getDefiningOp())) { + if (op.use_empty()) { + rewriter.setInsertionPoint(op); + TypeRange resultTypeRange = TypeRange{}; + // Nullify the result to avoid issuing a token. + rewriter.create( + op.getLoc(), resultTypeRange, op.getConnection(), op.getInput(), + op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides(), + op.getBdId(), op.getChannel(), op.getNextBd(), op.getStartBd()); + rewriter.eraseOp(op); + } + } + } } + return success(); } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertDmaBdChain.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertDmaBdChain.cpp index 352c8e500..b21ceb025 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertDmaBdChain.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertDmaBdChain.cpp @@ -17,7 +17,7 @@ namespace mlir::iree_compiler::AMDAIE { namespace { -using DmaChainKey = std::pair; +using DmaChain = std::pair; /// Utility function to update `next_bd` and `start_bd` operands. LogicalResult updateChainOperands( @@ -83,9 +83,9 @@ LogicalResult updateChainOperands( /// - Chain X: [0] (the newly added BD ID). /// - Chain Y: [] (emptied after breaking). void checkForChainsToBeBroken( - uint32_t currBdId, const DmaChainKey &currDmaChain, - const DenseMap> &dmaChainToBdIds, - SmallVector &chainsToBreak) { + uint32_t currBdId, const DmaChain &currDmaChain, + const DenseMap> &dmaChainToBdIds, + SmallVector &chainsToBreak) { for (auto &[entry, bdIds] : dmaChainToBdIds) { if (entry.first == currDmaChain.first && bdIds.contains(currBdId)) { // Break the chain that contains the duplicate BD ID. @@ -120,10 +120,9 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel, } // BD IDs that have been assigned in each tile. - DenseMap> dmaChainToBdIds; + DenseMap> dmaChainToBdIds; // Buffers the DMA ops that will be chained. - DenseMap> - dmaChainToDmaOps; + DenseMap> dmaChainToDmaOps; res = controlCodeOp->walk([&](Operation *op) { @@ -186,8 +185,8 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel, // Any duplicate BD ID from the same tile indicates that the chain // cannot grow further and requires breaking to release the // conflicting BD ID. - SmallVector chainsToBreak; - DmaChainKey currDmaChain = {tileOp, connectionOp}; + SmallVector chainsToBreak; + DmaChain currDmaChain = {tileOp, connectionOp}; checkForChainsToBeBroken(bdId, currDmaChain, dmaChainToBdIds, chainsToBreak);