Skip to content

Commit

Permalink
Fixup issue where AIR only optimizes SHIM DMA BD for wrap < 1024 (Xilinx#606)
Browse files Browse the repository at this point in the history

* Enable constraint where highest wrap is no greater than 64

* Fix an issue where the outermost wrap-and-stride dimension was lost when an NPU DMA op was tiled at its outermost wrap dimension

* Fix the step size of the generated `for` loop when stride = 0

* Fix an issue where a dimension with stride = 0 was converted into a `for` loop with step = 1, whose induction variable was then passed into the offset

* Add test; change existing tests to reflect optimized bd allocation
  • Loading branch information
erwei-xilinx authored Jun 15, 2024
1 parent b2df4d7 commit f384223
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 28 deletions.
87 changes: 69 additions & 18 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,21 +513,21 @@ void isolateAIRRtDmaLoopNests(ModuleOp module) {
}

// AIE2 hardware constraints.
const int AIE2_WRAP_UPPER_BOUND = 1024;
const std::vector<int> AIE2_WRAP_UPPER_BOUNDS = {64, 1024, 1024, 1024};
const int AIE2_STRIDE_UPPER_BOUND = 1048576;
const int AIE2_DIM_COUNT = 4;

bool violatesAIE2WrapLimit(airrt::DmaMemcpyNdOp dma) {
SmallVector<Value> wrap_list;
wrap_list.push_back(dma.getLength0());
wrap_list.push_back(dma.getLength1());
wrap_list.push_back(dma.getLength2());
wrap_list.push_back(dma.getLength3());
for (auto wrap : wrap_list) {
if (auto const_val = getConstantIntValue(wrap)) {
wrap_list.push_back(dma.getLength2());
wrap_list.push_back(dma.getLength1());
wrap_list.push_back(dma.getLength0());
for (unsigned i = 0; i < wrap_list.size(); i++) {
if (auto const_val = getConstantIntValue(wrap_list[i])) {
// Detected wrap that goes beyond the AIE2 hardware limit.
if (*const_val >= AIE2_WRAP_UPPER_BOUND) {
if (*const_val >= AIE2_WRAP_UPPER_BOUNDS[i])
return true;
}
} else
assert(false && "has non-static wrap");
}
Expand Down Expand Up @@ -567,6 +567,7 @@ int findLargestFactor(int num, int max) {

void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
auto loc = memcpy_op->getLoc();
auto ctx = memcpy_op->getContext();
auto oper_begin = memcpy_op.getOperands().begin();
SmallVector<Value> offsets(oper_begin + 4, oper_begin + 8);
SmallVector<Value> wraps(oper_begin + 8, oper_begin + 12);
Expand All @@ -579,10 +580,20 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
for (int i = wraps.size() - 1; i >= 0; i--) {
auto const_wrap = *getConstantIntValue(wraps[i]);
auto const_stride = *getConstantIntValue(strides[i]);
if (const_wrap >= AIE2_WRAP_UPPER_BOUND) {
// Found dimension with illegal wrap. Tiling.
int outer_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1);
int inner_wrap = mlir::ceilDiv(const_wrap, outer_wrap);
if (const_wrap >= AIE2_WRAP_UPPER_BOUNDS[i]) {
// Found dimension with illegal wrap. Tiling. (Prefers smaller outer wrap
// values, as long as stride fits)
int a_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUNDS[i] - 1);
int b_wrap = mlir::ceilDiv(const_wrap, a_wrap);
int new_a_stride =
(const_stride * a_wrap) % air::getTensorVolume(llvm::cast<MemRefType>(
memcpy_op.getMemref().getType()));
int inner_wrap = (new_a_stride > AIE2_STRIDE_UPPER_BOUND && i != 0)
? (b_wrap)
: (a_wrap);
int outer_wrap = (new_a_stride > AIE2_STRIDE_UPPER_BOUND && i != 0)
? (a_wrap)
: (b_wrap);
wraps[i] = builder.create<arith::ConstantOp>(
loc, builder.getI64Type(),
IntegerAttr::get(builder.getI64Type(), inner_wrap));
Expand All @@ -609,29 +620,52 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
// Unroll highest dimensions of wrap and stride, if the new dimension count
// goes beyond 4.
SmallVector<affine::AffineForOp> for_loop_nest;
Value inner_affine_for_iv = nullptr;
if (wraps.size() > AIE2_DIM_COUNT) {
affine::AffineForOp inner_affine_for = nullptr;
while (wraps.size() > AIE2_DIM_COUNT) {
auto const_offset = *getConstantIntValue(offsets[0]);
auto const_lowest_offset = *getConstantIntValue(offsets.back());
auto const_wrap = *getConstantIntValue(wraps[0]);
auto const_stride = *getConstantIntValue(strides[0]);

// Convert the outer dimension into an affine.for loop.
auto const_upper_bound = const_offset + const_wrap * const_stride;
int const_lower_bound =
const_stride ? (const_offset * const_stride + const_lowest_offset)
: 0;
auto const_upper_bound =
const_stride ? (const_offset * const_stride +
const_wrap * const_stride + const_lowest_offset)
: const_wrap;
int const_step = const_stride ? const_stride : 1;
auto new_for_op =
(const_stride)
(inner_affine_for_iv)
? (builder.create<affine::AffineForOp>(
loc, const_offset, const_upper_bound, const_stride))
: (builder.create<affine::AffineForOp>(loc, 0, const_wrap));
loc,
SmallVector<Value>{builder.create<arith::AddIOp>(
loc, inner_affine_for_iv,
builder.create<arith::ConstantIndexOp>(
loc, const_lower_bound))},
AffineMap::get(ctx),
SmallVector<Value>{builder.create<arith::AddIOp>(
loc, inner_affine_for_iv,
builder.create<arith::ConstantIndexOp>(
loc, const_upper_bound))},
AffineMap::get(ctx), const_step))
: (builder.create<affine::AffineForOp>(
loc, const_lower_bound, const_upper_bound, const_step));
for_loop_nest.push_back(new_for_op);
inner_affine_for = new_for_op;

// Pop front.
offsets.erase(offsets.begin());
wraps.erase(wraps.begin());
strides.erase(strides.begin());

builder.setInsertionPointToStart(inner_affine_for.getBody());
if (const_stride)
inner_affine_for_iv = inner_affine_for.getInductionVar();
}
builder.setInsertionPointToStart(inner_affine_for.getBody());
}

// Stride field implicit last element one, pop.
Expand All @@ -641,8 +675,20 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
SmallVector<Value> new_opers;
SmallVector<Type> tys;
auto old_opers = memcpy_op.getOperands();
// Insert
new_opers.insert(new_opers.end(), old_opers.begin(), old_opers.begin() + 4);
new_opers.insert(new_opers.end(), offsets.begin(), offsets.end());
if (inner_affine_for_iv) {
// Innermost tiled affine.for loop induction variable as lowest offset, if
// original rank exceeds hw limit.
new_opers.insert(new_opers.end(), offsets.begin(), offsets.end() - 1);
auto new_inner_offset = builder.create<arith::AddIOp>(
loc,
builder.create<arith::IndexCastOp>(loc, IntegerType::get(ctx, 64),
inner_affine_for_iv),
offsets.back());
new_opers.push_back(new_inner_offset);
} else
new_opers.insert(new_opers.end(), offsets.begin(), offsets.end());
new_opers.insert(new_opers.end(), wraps.begin(), wraps.end());
new_opers.insert(new_opers.end(), strides.begin(), strides.end());
builder.create<airrt::DmaMemcpyNdOp>(loc, tys, new_opers,
Expand Down Expand Up @@ -909,6 +955,11 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
// Enforce AIE2 hardware constraints: every wrap size must stay below its
// per-dimension upper bound (see AIE2_WRAP_UPPER_BOUNDS).
enforceAIE2WrapLimit(module);

// Simplify arith ops (from airrt)
RewritePatternSet canoPatterns_3(ctx);
arith::IndexCastOp::getCanonicalizationPatterns(canoPatterns_3, ctx);
(void)applyPatternsAndFoldGreedily(module, std::move(canoPatterns_3));

ConversionTarget target(getContext());
target.addIllegalDialect<AIRRtDialect>();
target.addLegalDialect<arith::ArithDialect, AIEX::AIEXDialect>();
Expand Down
52 changes: 46 additions & 6 deletions mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,10 @@ module {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 64, 0][4, 8, 64, 256][0, 256, 2048]) {id = 1 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 128, 0][4, 8, 64, 256][0, 256, 2048]) {id = 2 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 192, 0][4, 8, 64, 256][0, 256, 2048]) {id = 3 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG2]][0, 0, 0, 0][4, 4, 64, 64][131072, 64, 2048]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2048x2048xi32>

#map = affine_map<()[s0] -> (s0 * 64)>
Expand Down Expand Up @@ -701,8 +701,8 @@ module {
// CHECK-SAME: %[[VAL_0:.*]]: memref<262144xi32>, %[[VAL_1:.*]]: memref<262144xi32>, %[[VAL_2:.*]]: memref<131072xi32>) {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][2, 4, 256, 128][0, 128, 512]) {id = 0 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][2, 4, 256, 128][0, 128, 512]) {id = 1 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][2, 2, 64, 128][65536, 128, 256]) {id = 4 : i64, metadata = @airMemcpyId45} : memref<131072xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 16384][2, 2, 64, 128][65536, 128, 256]) {id = 5 : i64, metadata = @airMemcpyId46} : memref<131072xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 32768][2, 2, 64, 128][65536, 128, 256]) {id = 0 : i64, metadata = @airMemcpyId47} : memref<131072xi32>
Expand Down Expand Up @@ -930,3 +930,43 @@ module {
return
}
}

// -----

// Outermost wrap must be in range [1:64] for AIE2.

// CHECK-LABEL: func21
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][38, 2, 64, 32][77824, 32, 1216]) {id = 0 : i64, metadata = @airMemcpyId10} : memref<11829248xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 2957312][38, 2, 64, 32][77824, 32, 1216]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<11829248xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 5914624][38, 2, 64, 32][77824, 32, 1216]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<11829248xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 8871936][38, 2, 64, 32][77824, 32, 1216]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<11829248xi32>
// CHECK: return

#map = affine_map<()[s0] -> (s0 * 128)>
module {
aie.device(npu1_4col) {
aie.shim_dma_allocation @airMemcpyId10(MM2S, 1, 0)
memref.global "public" @airMemcpyId10 : memref<1x2x64x64xbf16, 1 : i32>
} {sym_name = "matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0"}
airrt.module_metadata{
}
func.func @func21(%arg0: memref<9728x2432xbf16>) {
%c2_i64 = arith.constant 2 : i64
%c2432_i64 = arith.constant 2432 : i64
%c155648_i64 = arith.constant 155648 : i64
%c152_i64 = arith.constant 152 : i64
%c64_i64 = arith.constant 64 : i64
%c10_i32 = arith.constant 10 : i32
%c0_i64 = arith.constant 0 : i64
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 1 {
%0 = affine.apply #map()[%arg4]
%1 = arith.index_cast %arg3 : index to i64
%2 = arith.index_cast %arg4 : index to i64
%3 = arith.index_cast %0 : index to i64
%4 = airrt.dma_memcpy_nd(%c10_i32, %1, %2, %arg0[%c0_i64, %c0_i64, %c0_i64, %3], [%c152_i64, %c2_i64, %c64_i64, %c64_i64], [%c155648_i64, %c64_i64, %c2432_i64]) {metadata = @airMemcpyId10} : (i32, i64, i64, memref<9728x2432xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
}
}
return
}
}
8 changes: 4 additions & 4 deletions mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ module {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][4, 8, 128, 128][0, 128, 1024]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 262144][4, 8, 128, 128][0, 128, 1024]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 393216][4, 8, 128, 128][0, 128, 1024]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][4, 4, 128, 64][131072, 64, 1024]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2097152xi32>

module {
Expand Down

0 comments on commit f384223

Please sign in to comment.