Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NpuTctSync operation #990

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,31 @@ def AMDAIE_NpuWriteBdOp: AMDAIE_Op<"npu.write_bd"> {
let assemblyFormat = [{ attr-dict }];
}

def AMDAIE_NpuTctSyncOp: AMDAIE_Op<"npu.tct_sync"> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a roundtrip test for this op as well?

let summary = "Wait for the TCTs to be emitted.";
let description = [{
This NPU controller operation to synchronize the Task Completion Tokens (TCTs)
on the specified `channel` and `direction`. The ranges of tiles to synchronize
are defined by [col, col+col_num) and [row, row+row_num).

Example:

```mlir
amdaie.npu.tct_sync {col = 0 : ui32, row = 0 : ui32, channel = 0 : ui32,
direction = 1 : i32, col_num = 1 : ui32, row_num = 1 : ui32}
```
}];
let arguments = (
ins UI32Attr:$col,
UI32Attr:$row,
DMAChannelDir:$direction,
UI32Attr:$channel,
UI32Attr:$col_num,
UI32Attr:$row_num
);
let assemblyFormat = [{ attr-dict }];
}

//===----------------------------------------------------------------------===//
// IREE AMDAIE LogicalObjectFifo Ops
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,15 @@ func.func @npu_write_bd() {

// -----

// CHECK-LABEL: func.func @npu_tct_sync
// CHECK: amdaie.npu.tct_sync {channel = 0 : ui32, col = 0 : ui32, col_num = 2 : ui32, direction = 1 : i32, row = 0 : ui32, row_num = 1 : ui32}
func.func @npu_tct_sync() {
amdaie.npu.tct_sync {channel = 0 : ui32, col = 0 : ui32, col_num = 2 : ui32, direction = 1 : i32, row = 0 : ui32, row_num = 1 : ui32}
return
}

// -----

// CHECK-LABEL: func.func @workgroup
// CHECK: amdaie.workgroup
// CHECK: amdaie.core
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,62 @@ struct HalfDmaCpyNdToNpuConverter final
uint8_t minStrideBitWidth;
};

struct DmaWaitToTctSyncConverter final
: OpConversionPattern<AMDAIE::NpuDmaWaitOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
AMDAIE::NpuDmaWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "matchAndRewrite[AMDAIE::NpuDmaWaitOp]\n");
// Collect all half DMA ops from the async tokens.
SmallVector<AMDAIE::NpuPushToQueueOp> pushToQueueOps;
for (Value asyncToken : op.getAsyncTokens()) {
auto pushToQueueOp = dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(
asyncToken.getDefiningOp());
if (!pushToQueueOp) {
return op.emitOpError()
<< "should operate on an `amdaie.push_to_queue` op async token";
}
pushToQueueOps.push_back(pushToQueueOp);
}
// Sort the half DMA ops by direction, channel, row, and column.
std::sort(pushToQueueOps.begin(), pushToQueueOps.end(),
[](AMDAIE::NpuPushToQueueOp a, AMDAIE::NpuPushToQueueOp b) {
return std::make_tuple(a.getDirection(), a.getChannel(),
a.getRow(), a.getCol()) <
std::make_tuple(b.getDirection(), b.getChannel(),
b.getRow(), b.getCol());
});
// Batch DMA operations with the same row, channel, and direction into a
// single TCT sync operation, as long as they have consecutive columns.
llvm::MapVector<AMDAIE::NpuPushToQueueOp, uint32_t> columnBatches;
for (auto pushToQueueOp : pushToQueueOps) {
if (!columnBatches.empty()) {
auto &[lastPushOp, lastColNum] = columnBatches.back();
if (lastPushOp.getRow() == pushToQueueOp.getRow() &&
lastPushOp.getCol() + lastColNum == pushToQueueOp.getCol() &&
lastPushOp.getDirection() == pushToQueueOp.getDirection() &&
lastPushOp.getChannel() == pushToQueueOp.getChannel()) {
++lastColNum;
continue;
}
}
columnBatches.insert({pushToQueueOp, 1});
}
// Convert to TCT sync ops.
for (auto &[pushToQueueOp, colNum] : columnBatches) {
uint32_t rowNum = 1;
rewriter.create<AMDAIE::NpuTctSyncOp>(
op.getLoc(), pushToQueueOp.getCol(), pushToQueueOp.getRow(),
pushToQueueOp.getDirection(), pushToQueueOp.getChannel(), colNum,
rowNum);
}
rewriter.eraseOp(op);
return success();
}
};

namespace {
class AMDAIEControlCodeLoweringPass
: public impl::AMDAIEControlCodeLoweringBase<
Expand All @@ -260,17 +316,37 @@ void AMDAIEControlCodeLoweringPass::runOnOperation() {
"ops.";
return signalPassFailure();
}
AMDAIE::AMDAIEDeviceModel deviceModel =
AMDAIE::getDeviceModel(maybeDevice.value());

RewritePatternSet patterns(context);
ConversionTarget conversionTarget(*context);
conversionTarget.addLegalDialect<AMDAIEDialect>();
conversionTarget.addIllegalOp<AMDAIE::NpuHalfDmaCpyNdOp>();
patterns.insert<HalfDmaCpyNdToNpuConverter>(context, deviceModel);
if (failed(applyPartialConversion(parentOp, conversionTarget,
std::move(patterns)))) {
return signalPassFailure();
// First conversion: HalfDmaCpyNdOp to WriteBdOp, AddressPatchOp and
// PushToQueueOp.
{
AMDAIE::AMDAIEDeviceModel deviceModel =
AMDAIE::getDeviceModel(maybeDevice.value());
RewritePatternSet patterns(context);
ConversionTarget conversionTarget(*context);
conversionTarget.addLegalDialect<AMDAIEDialect>();
conversionTarget.addIllegalOp<AMDAIE::NpuHalfDmaCpyNdOp>();
patterns.insert<HalfDmaCpyNdToNpuConverter>(context, deviceModel);

if (failed(applyPartialConversion(parentOp, conversionTarget,
std::move(patterns)))) {
return signalPassFailure();
}
}

// Second conversion: DmaWaitOp to TctSyncOp.
// The two conversions are separate to simplify the attribute handling, such
// as col, row, direction, channel, etc.
{
RewritePatternSet patterns(context);
ConversionTarget conversionTarget(*context);
conversionTarget.addLegalDialect<AMDAIEDialect>();
conversionTarget.addIllegalOp<AMDAIE::NpuDmaWaitOp>();
patterns.insert<DmaWaitToTctSyncConverter>(context);
if (failed(applyPartialConversion(parentOp, conversionTarget,
std::move(patterns)))) {
return signalPassFailure();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,50 +199,11 @@ LogicalResult convertOp(AMDAIE::NpuAddressPatchOp op,
return success();
}

LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
// Collect all half DMA ops from the async tokens.
SmallVector<AMDAIE::NpuPushToQueueOp> pushToQueueOps;
for (Value asyncToken : op.getAsyncTokens()) {
auto pushToQueueOp = dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(
asyncToken.getDefiningOp());
if (!pushToQueueOp) {
return op.emitOpError()
<< "should operate on an `amdaie.push_to_queue` op async token";
}
pushToQueueOps.push_back(pushToQueueOp);
}
// Sort the half DMA ops by channel, direction, row, and column.
std::sort(pushToQueueOps.begin(), pushToQueueOps.end(),
[](AMDAIE::NpuPushToQueueOp a, AMDAIE::NpuPushToQueueOp b) {
return std::make_tuple(a.getChannel(), a.getDirection(),
a.getRow(), a.getCol()) <
std::make_tuple(b.getChannel(), b.getDirection(),
b.getRow(), b.getCol());
});
// Batch DMA operations with the same row, channel, and direction into a
// single TCT sync operation, as long as they have consecutive columns.
llvm::MapVector<AMDAIE::NpuPushToQueueOp, uint32_t> columnBatches;
for (auto pushToQueueOp : pushToQueueOps) {
if (!columnBatches.empty()) {
auto &[lastPushOp, lastColNum] = columnBatches.back();
if (lastPushOp.getRow() == pushToQueueOp.getRow() &&
lastPushOp.getCol() + lastColNum == pushToQueueOp.getCol() &&
lastPushOp.getDirection() == pushToQueueOp.getDirection() &&
lastPushOp.getChannel() == pushToQueueOp.getChannel()) {
++lastColNum;
continue;
}
}
columnBatches.insert({pushToQueueOp, 1});
}
// Convert to TCT sync ops.
for (auto &[pushToQueueOp, colNum] : columnBatches) {
if (failed(builder.appendTCTSync(
pushToQueueOp.getCol(), pushToQueueOp.getRow(),
static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, colNum,
pushToQueueOp.getChannel()))) {
return failure();
}
LogicalResult convertOp(AMDAIE::NpuTctSyncOp op, TransactionBuilder &builder) {
if (failed(builder.appendTCTSync(
op.getCol(), op.getRow(), static_cast<uint32_t>(op.getDirection()),
op.getRowNum(), op.getColNum(), op.getChannel()))) {
return failure();
}
return success();
}
Expand Down Expand Up @@ -304,7 +265,7 @@ LogicalResult controlCodeToTransaction(IRRewriter &rewriter,
WalkResult res = controlCodeOp->walk([&](Operation *op) {
LogicalResult switchResult =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<AMDAIE::NpuAddressPatchOp, AMDAIE::NpuDmaWaitOp,
.Case<AMDAIE::NpuAddressPatchOp, AMDAIE::NpuTctSyncOp,
AMDAIE::NpuPushToQueueOp, AMDAIE::NpuWriteBdOp>(
[&](auto npuOp) {
if (failed(convertOp(npuOp, builder))) return failure();
Expand Down
Loading
Loading