diff --git a/mlir/lib/Conversion/AIRRtToNpuPass.cpp b/mlir/lib/Conversion/AIRRtToNpuPass.cpp index 888e1e44b..aaf16f829 100644 --- a/mlir/lib/Conversion/AIRRtToNpuPass.cpp +++ b/mlir/lib/Conversion/AIRRtToNpuPass.cpp @@ -513,21 +513,21 @@ void isolateAIRRtDmaLoopNests(ModuleOp module) { } // AIE2 hardware constraints. -const int AIE2_WRAP_UPPER_BOUND = 1024; +const std::vector AIE2_WRAP_UPPER_BOUNDS = {64, 1024, 1024, 1024}; +const int AIE2_STRIDE_UPPER_BOUND = 1048576; const int AIE2_DIM_COUNT = 4; bool violatesAIE2WrapLimit(airrt::DmaMemcpyNdOp dma) { SmallVector wrap_list; - wrap_list.push_back(dma.getLength0()); - wrap_list.push_back(dma.getLength1()); - wrap_list.push_back(dma.getLength2()); wrap_list.push_back(dma.getLength3()); - for (auto wrap : wrap_list) { - if (auto const_val = getConstantIntValue(wrap)) { + wrap_list.push_back(dma.getLength2()); + wrap_list.push_back(dma.getLength1()); + wrap_list.push_back(dma.getLength0()); + for (unsigned i = 0; i < wrap_list.size(); i++) { + if (auto const_val = getConstantIntValue(wrap_list[i])) { // Detected wrap that goes beyond the AIE2 hardware limit. - if (*const_val >= AIE2_WRAP_UPPER_BOUND) { + if (*const_val >= AIE2_WRAP_UPPER_BOUNDS[i]) return true; - } } else assert(false && "has non-static wrap"); } @@ -567,6 +567,7 @@ int findLargestFactor(int num, int max) { void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) { auto loc = memcpy_op->getLoc(); + auto ctx = memcpy_op->getContext(); auto oper_begin = memcpy_op.getOperands().begin(); SmallVector offsets(oper_begin + 4, oper_begin + 8); SmallVector wraps(oper_begin + 8, oper_begin + 12); @@ -579,10 +580,20 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) { for (int i = wraps.size() - 1; i >= 0; i--) { auto const_wrap = *getConstantIntValue(wraps[i]); auto const_stride = *getConstantIntValue(strides[i]); - if (const_wrap >= AIE2_WRAP_UPPER_BOUND) { - // Found dimension with illegal wrap. Tiling. 
- int outer_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1); - int inner_wrap = mlir::ceilDiv(const_wrap, outer_wrap); + if (const_wrap >= AIE2_WRAP_UPPER_BOUNDS[i]) { + // Found dimension with illegal wrap. Tiling. (Prefers smaller outer wrap + // values, as long as stride fits) + int a_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUNDS[i] - 1); + int b_wrap = mlir::ceilDiv(const_wrap, a_wrap); + int new_a_stride = + (const_stride * a_wrap) % air::getTensorVolume(llvm::cast( + memcpy_op.getMemref().getType())); + int inner_wrap = (new_a_stride > AIE2_STRIDE_UPPER_BOUND && i != 0) + ? (b_wrap) + : (a_wrap); + int outer_wrap = (new_a_stride > AIE2_STRIDE_UPPER_BOUND && i != 0) + ? (a_wrap) + : (b_wrap); wraps[i] = builder.create( loc, builder.getI64Type(), IntegerAttr::get(builder.getI64Type(), inner_wrap)); @@ -609,20 +620,40 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) { // Unroll highest dimensions of wrap and stride, if the new dimension count // goes beyond 4. SmallVector for_loop_nest; + Value inner_affine_for_iv = nullptr; if (wraps.size() > AIE2_DIM_COUNT) { affine::AffineForOp inner_affine_for = nullptr; while (wraps.size() > AIE2_DIM_COUNT) { auto const_offset = *getConstantIntValue(offsets[0]); + auto const_lowest_offset = *getConstantIntValue(offsets.back()); auto const_wrap = *getConstantIntValue(wraps[0]); auto const_stride = *getConstantIntValue(strides[0]); // Convert the outer dimension into an affine.for loop. - auto const_upper_bound = const_offset + const_wrap * const_stride; + int const_lower_bound = + const_stride ? (const_offset * const_stride + const_lowest_offset) + : 0; + auto const_upper_bound = + const_stride ? (const_offset * const_stride + + const_wrap * const_stride + const_lowest_offset) + : const_wrap; + int const_step = const_stride ? const_stride : 1; auto new_for_op = - (const_stride) + (inner_affine_for_iv) ? 
(builder.create( - loc, const_offset, const_upper_bound, const_stride)) - : (builder.create(loc, 0, const_wrap)); + loc, + SmallVector{builder.create( + loc, inner_affine_for_iv, + builder.create( + loc, const_lower_bound))}, + AffineMap::get(ctx), + SmallVector{builder.create( + loc, inner_affine_for_iv, + builder.create( + loc, const_upper_bound))}, + AffineMap::get(ctx), const_step)) + : (builder.create( + loc, const_lower_bound, const_upper_bound, const_step)); for_loop_nest.push_back(new_for_op); inner_affine_for = new_for_op; @@ -630,8 +661,11 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) { offsets.erase(offsets.begin()); wraps.erase(wraps.begin()); strides.erase(strides.begin()); + + builder.setInsertionPointToStart(inner_affine_for.getBody()); + if (const_stride) + inner_affine_for_iv = inner_affine_for.getInductionVar(); } - builder.setInsertionPointToStart(inner_affine_for.getBody()); } // Stride field implicit last element one, pop. @@ -641,8 +675,20 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) { SmallVector new_opers; SmallVector tys; auto old_opers = memcpy_op.getOperands(); + // Insert the first four operands of the original op unchanged. new_opers.insert(new_opers.end(), old_opers.begin(), old_opers.begin() + 4); - new_opers.insert(new_opers.end(), offsets.begin(), offsets.end()); + if (inner_affine_for_iv) { + // Innermost tiled affine.for loop induction variable as lowest offset, if + // original rank exceeds hw limit. 
+ new_opers.insert(new_opers.end(), offsets.begin(), offsets.end() - 1); + auto new_inner_offset = builder.create( + loc, + builder.create(loc, IntegerType::get(ctx, 64), + inner_affine_for_iv), + offsets.back()); + new_opers.push_back(new_inner_offset); + } else + new_opers.insert(new_opers.end(), offsets.begin(), offsets.end()); new_opers.insert(new_opers.end(), wraps.begin(), wraps.end()); new_opers.insert(new_opers.end(), strides.begin(), strides.end()); builder.create(loc, tys, new_opers, @@ -909,6 +955,11 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase { // Enforce AIE2 hardware constraint: wrap size limit within [0, 1023]. enforceAIE2WrapLimit(module); + // Simplify arith ops (from airrt) + RewritePatternSet canoPatterns_3(ctx); + arith::IndexCastOp::getCanonicalizationPatterns(canoPatterns_3, ctx); + (void)applyPatternsAndFoldGreedily(module, std::move(canoPatterns_3)); + ConversionTarget target(getContext()); target.addIllegalDialect(); target.addLegalDialect(); diff --git a/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir b/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir index c503712a5..a81839f62 100644 --- a/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir +++ b/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir @@ -455,10 +455,10 @@ module { // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 64, 0][4, 8, 64, 256][0, 256, 2048]) {id = 1 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 128, 0][4, 8, 64, 256][0, 256, 2048]) {id = 2 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 192, 0][4, 8, 64, 256][0, 256, 2048]) {id = 3 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 
64][64, 8192, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG2]][0, 0, 0, 0][4, 4, 64, 64][131072, 64, 2048]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2048x2048xi32> #map = affine_map<()[s0] -> (s0 * 64)> @@ -701,8 +701,8 @@ module { // CHECK-SAME: %[[VAL_0:.*]]: memref<262144xi32>, %[[VAL_1:.*]]: memref<262144xi32>, %[[VAL_2:.*]]: memref<131072xi32>) { // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][2, 4, 256, 128][0, 128, 512]) {id = 0 : i64, metadata = @airMemcpyId7} : memref<262144xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][2, 4, 256, 128][0, 128, 512]) {id = 1 : i64, metadata = @airMemcpyId7} : memref<262144xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 3 : i64, metadata 
= @airMemcpyId12} : memref<262144xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][2, 2, 64, 128][65536, 128, 256]) {id = 4 : i64, metadata = @airMemcpyId45} : memref<131072xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 16384][2, 2, 64, 128][65536, 128, 256]) {id = 5 : i64, metadata = @airMemcpyId46} : memref<131072xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 32768][2, 2, 64, 128][65536, 128, 256]) {id = 0 : i64, metadata = @airMemcpyId47} : memref<131072xi32> @@ -930,3 +930,43 @@ module { return } } + +// ----- + +// Outermost wrap must be in range [1:64] for AIE2. + +// CHECK-LABEL: func21 +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][38, 2, 64, 32][77824, 32, 1216]) {id = 0 : i64, metadata = @airMemcpyId10} : memref<11829248xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 2957312][38, 2, 64, 32][77824, 32, 1216]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<11829248xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 5914624][38, 2, 64, 32][77824, 32, 1216]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<11829248xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 8871936][38, 2, 64, 32][77824, 32, 1216]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<11829248xi32> +// CHECK: return + +#map = affine_map<()[s0] -> (s0 * 128)> +module { + aie.device(npu1_4col) { + aie.shim_dma_allocation @airMemcpyId10(MM2S, 1, 0) + memref.global "public" @airMemcpyId10 : memref<1x2x64x64xbf16, 1 : i32> + } {sym_name = "matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0"} + airrt.module_metadata{ + } + func.func @func21(%arg0: memref<9728x2432xbf16>) 
{ + %c2_i64 = arith.constant 2 : i64 + %c2432_i64 = arith.constant 2432 : i64 + %c155648_i64 = arith.constant 155648 : i64 + %c152_i64 = arith.constant 152 : i64 + %c64_i64 = arith.constant 64 : i64 + %c10_i32 = arith.constant 10 : i32 + %c0_i64 = arith.constant 0 : i64 + affine.for %arg3 = 0 to 1 { + affine.for %arg4 = 0 to 1 { + %0 = affine.apply #map()[%arg4] + %1 = arith.index_cast %arg3 : index to i64 + %2 = arith.index_cast %arg4 : index to i64 + %3 = arith.index_cast %0 : index to i64 + %4 = airrt.dma_memcpy_nd(%c10_i32, %1, %2, %arg0[%c0_i64, %c0_i64, %c0_i64, %3], [%c152_i64, %c2_i64, %c64_i64, %c64_i64], [%c155648_i64, %c64_i64, %c2432_i64]) {metadata = @airMemcpyId10} : (i32, i64, i64, memref<9728x2432xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event + } + } + return + } +} diff --git a/mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir b/mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir index 0f4f13dda..e04a2becc 100644 --- a/mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir +++ b/mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir @@ -122,10 +122,10 @@ module { // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][4, 8, 128, 128][0, 128, 1024]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 262144][4, 8, 128, 128][0, 128, 1024]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 393216][4, 8, 128, 128][0, 128, 1024]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 
0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][4, 4, 128, 64][131072, 64, 1024]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2097152xi32> module {