Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ChannelGenerator] Verify the physical feasibility of DMA channel allocation #1005

Merged
merged 3 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,27 @@ namespace {
/// Assign channels to `amdaie.connection` ops.
LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
IRRewriter rewriter(workgroupOp->getContext());
ChannelGenerator generator;

// Get the device model.
std::optional<AMDAIEDevice> device = getConfigAMDAIEDevice(workgroupOp);
if (!device) {
return workgroupOp->emitOpError()
<< "could not find an AMDAIEDevice attribute";
}
AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value());

// Get the number of producer and consumer channels for each tile.
DenseMap<Value, ChannelGenerator> tileToGeneratorMap;
workgroupOp.walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
AMDAIETileType tileType = deviceModel.getTileType(col, row);
uint8_t numDmaChannels =
deviceModel.getDmaProp<uint8_t>(tileType, AMDAIEDmaProp::NumChannels);
tileToGeneratorMap[tileOp.getResult()] =
ChannelGenerator(numDmaChannels, numDmaChannels);
});

SmallVector<AMDAIE::ConnectionOp> connectionOps;
workgroupOp->walk([&](AMDAIE::ConnectionOp connectionOp) {
connectionOps.push_back(connectionOp);
Expand All @@ -43,18 +63,32 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
rewriter.setInsertionPoint(connectionOp);
SmallVector<Value> sourceChannels;
for (Value tile : sourceLogicalObjFifo.getTiles()) {
uint8_t channel = generator.getProducerDMAChannel(tile);
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getAndAssignProducerDMAChannel();
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no producer DMA channel available";
}
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, channel, StrmSwPortType::DMA,
AMDAIE::DMAChannelDir::MM2S);
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S);
sourceChannels.push_back(channelOp.getResult());
}
SmallVector<Value> targetChannels;
for (Value tile : targetLogicalObjFifo.getTiles()) {
uint8_t channel = generator.getConsumerDMAChannel(tile);
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getAndAssignConsumerDMAChannel();
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no consumer DMA channel available";
}
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, channel, StrmSwPortType::DMA,
AMDAIE::DMAChannelDir::S2MM);
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::S2MM);
targetChannels.push_back(channelOp.getResult());
}
rewriter.replaceOpWithNewOp<AMDAIE::ConnectionOp>(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
// RUN: iree-opt --pass-pipeline="builtin.module(iree-amdaie-assign-channels)" --split-input-file --verify-diagnostics %s | FileCheck %s

module {
// expected-error @+1 {{could not find an AMDAIEDevice attribute}}
amdaie.workgroup {
amdaie.controlcode {
amdaie.end
}
}
}

// -----

// CHECK-LABEL: @assign_channels
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
Expand All @@ -12,10 +23,8 @@
// CHECK: %[[CHANNEL_2:.+]] = amdaie.channel(%[[tile_0_0]], 1, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_3:.+]] = amdaie.channel(%[[tile_0_1]], 1, port_type = DMA, direction = S2MM)
// CHECK: amdaie.connection(%{{.+}} {%[[CHANNEL_3]]}, %{{.+}} {%[[CHANNEL_2]]})
// CHECK: %[[CHANNEL_4:.+]] = amdaie.channel(%[[tile_0_0]], 2, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_5:.+]] = amdaie.channel(%[[tile_0_1]], 2, port_type = DMA, direction = S2MM)
// CHECK: amdaie.connection(%{{.+}} {%[[CHANNEL_5]]}, %{{.+}} {%[[CHANNEL_4]]})
module {
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @assign_channels(%arg0: memref<1x1x8x16xi32, 1>, %arg1: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -26,6 +35,30 @@ module {
%1 = amdaie.logicalobjectfifo.from_memref %arg1, {%tile_0_0} : memref<8x16xi32> -> !amdaie.logicalobjectfifo<memref<8x16xi32>>
%2 = amdaie.connection(%0, %1) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32>>)
%3 = amdaie.connection(%0, %1) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32>>)
amdaie.controlcode {
amdaie.end
}
}
return
}
}

// -----

// Shim tile (0, 0) has only two producer (MM2S) channels.
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @run_out_of_channel(%arg0: memref<1x1x8x16xi32, 1>, %arg1: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
amdaie.workgroup {
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_1} : memref<1x1x8x16xi32, 1> -> !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>
%1 = amdaie.logicalobjectfifo.from_memref %arg1, {%tile_0_0} : memref<8x16xi32> -> !amdaie.logicalobjectfifo<memref<8x16xi32>>
%2 = amdaie.connection(%0, %1) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32>>)
%3 = amdaie.connection(%0, %1) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32>>)
// expected-error @+1 {{no producer DMA channel available}}
%4 = amdaie.connection(%0, %1) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32>>)
amdaie.controlcode {
amdaie.end
Expand All @@ -37,7 +70,8 @@ module {

// -----

module {
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @no_source(%arg0: memref<1x1x8x16xi32, 1>, %arg1: !amdaie.logicalobjectfifo<memref<8x16xi32>>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand Down
44 changes: 35 additions & 9 deletions runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,50 @@ using namespace llvm;
namespace mlir::iree_compiler::AMDAIE {

/// Utility to generate valid channels.
/// TODO(jornt): add physical feasibility checks on channels.
class ChannelGenerator {
public:
/// Creates a generator with zero producer and zero consumer channels; every
/// `getAndAssign*` request on such a generator returns `std::nullopt`.
ChannelGenerator() {}
/// Creates a generator that hands out producer channels in the range
/// [0, numProducerChannels) and consumer channels in the range
/// [0, numConsumerChannels).
ChannelGenerator(uint8_t numProducerChannels, uint8_t numConsumerChannels)
: numProducerChannels(numProducerChannels),
numConsumerChannels(numConsumerChannels) {}

// NOTE(review): the next three lines look like the pre-change side of the
// diff (the old per-tile counter API) - confirm they are not part of the
// merged header.
/// Given a tile, returns its next usable producer channel.
uint8_t getProducerDMAChannel(Value tile) {
return producerChannelsPerTile[tile]++;
/// Returns the lowest-numbered producer channel that has not been assigned
/// yet and records it as assigned. Returns `std::nullopt` when every channel
/// in [0, numProducerChannels) is already assigned.
std::optional<uint8_t> getAndAssignProducerDMAChannel() {
for (uint8_t i = 0; i < numProducerChannels; i++) {
if (!assignedProducerChannels.count(i)) {
assignedProducerChannels.insert(i);
return i;
}
}
return std::nullopt;
}

// NOTE(review): the next three lines look like the pre-change side of the
// diff (the old per-tile counter API) - confirm they are not part of the
// merged header.
/// Given a tile, returns its next usable consumer channel.
uint8_t getConsumerDMAChannel(Value tile) {
return consumerChannelsPerTile[tile]++;
/// Returns the lowest-numbered consumer channel that has not been assigned
/// yet and records it as assigned. Returns `std::nullopt` when every channel
/// in [0, numConsumerChannels) is already assigned.
std::optional<uint8_t> getAndAssignConsumerDMAChannel() {
for (uint8_t i = 0; i < numConsumerChannels; i++) {
if (!assignedConsumerChannels.count(i)) {
assignedConsumerChannels.insert(i);
return i;
}
}
return std::nullopt;
}

/// Records the provided producer channel as assigned, so that
/// `getAndAssignProducerDMAChannel` will skip it. The value is not checked
/// against `numProducerChannels`.
void assignProducerDMAChannel(uint8_t channel) {
assignedProducerChannels.insert(channel);
}

/// Records the provided consumer channel as assigned, so that
/// `getAndAssignConsumerDMAChannel` will skip it. The value is not checked
/// against `numConsumerChannels`.
void assignConsumerDMAChannel(uint8_t channel) {
assignedConsumerChannels.insert(channel);
}

private:
// NOTE(review): these two per-tile maps are only referenced by the
// `get*DMAChannel(Value tile)` lines flagged above - presumably also
// pre-change diff residue; confirm.
DenseMap<Value, uint8_t> producerChannelsPerTile;
DenseMap<Value, uint8_t> consumerChannelsPerTile;
// Exclusive upper bounds on the channel indices handed out per direction.
uint8_t numProducerChannels = 0;
uint8_t numConsumerChannels = 0;
// Channel indices that have already been assigned, per direction.
DenseSet<uint8_t> assignedProducerChannels;
DenseSet<uint8_t> assignedConsumerChannels;
};

} // namespace mlir::iree_compiler::AMDAIE
Expand Down
10 changes: 10 additions & 0 deletions runtime/src/iree-amd-aie/aie_runtime/Utils/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ iree_cc_test(
iree-amd-aie::aie_runtime::Utils::Utils
)

# Declares the ChannelGeneratorTest target, built from ChannelGeneratorTest.cpp
# with gtest and the aie_runtime Utils library listed as dependencies.
iree_cc_test(
NAME
ChannelGeneratorTest
SRCS
"ChannelGeneratorTest.cpp"
DEPS
gtest
iree-amd-aie::aie_runtime::Utils::Utils
)

iree_cc_test(
NAME
LockIdGeneratorTest
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <numeric>

#include "gtest/gtest.h"
#include "iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h"

namespace {

using namespace mlir::iree_compiler::AMDAIE;

// A generator with two channels per direction hands out channels 0 and 1 in
// ascending order on each side, then reports exhaustion via std::nullopt.
TEST(ChannelGeneratorTest, GetAssign) {
  ChannelGenerator channels(2, 2);
  // Alternate producer/consumer requests; each side should count up together.
  for (int expected = 0; expected < 2; ++expected) {
    EXPECT_EQ(channels.getAndAssignProducerDMAChannel().value(), expected);
    EXPECT_EQ(channels.getAndAssignConsumerDMAChannel().value(), expected);
  }
  // Both directions are now fully assigned.
  EXPECT_EQ(channels.getAndAssignProducerDMAChannel(), std::nullopt);
  EXPECT_EQ(channels.getAndAssignConsumerDMAChannel(), std::nullopt);
}

// Channels that were explicitly assigned up front are skipped: with 0 and 2
// taken out of four, only 1 and 3 are handed out before exhaustion.
TEST(ChannelGeneratorTest, Occupied) {
  ChannelGenerator channels(4, 4);
  // Pre-occupy the even channels on both the producer and consumer side.
  for (uint8_t taken = 0; taken <= 2; taken += 2) {
    channels.assignProducerDMAChannel(taken);
    channels.assignConsumerDMAChannel(taken);
  }
  // The remaining odd channels come back in ascending order.
  for (int expected = 1; expected <= 3; expected += 2) {
    EXPECT_EQ(channels.getAndAssignProducerDMAChannel().value(), expected);
    EXPECT_EQ(channels.getAndAssignConsumerDMAChannel().value(), expected);
  }
  // Nothing is left on either side.
  EXPECT_EQ(channels.getAndAssignProducerDMAChannel(), std::nullopt);
  EXPECT_EQ(channels.getAndAssignConsumerDMAChannel(), std::nullopt);
}

} // namespace

// Test binary entry point: hands the command line to GoogleTest, runs every
// registered test, and returns the RUN_ALL_TESTS() result as the exit status.
int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  const int status = RUN_ALL_TESTS();
  return status;
}
Loading