From 24cb14e6d2233e819a5455928e4237ef319e6fc8 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Mon, 21 Oct 2024 23:17:33 -0700 Subject: [PATCH] AIRSegmentLoopFusion: A number of fixups and improvements around async dependency (#748) * Trace dep token users through air.wait_all; move fused loop to before the last loop being fused for ssa dominance; more informative failure message when broken dependence is detected after fusion * Unit test checking for a complex case which was failing before --- mlir/include/air/Util/Dependency.h | 1 + .../Transform/AIRDependencyScheduleOpt.cpp | 30 +++++++- mlir/lib/Util/Dependency.cpp | 22 ++++++ .../segment_loop_fusion.mlir | 72 +++++++++++++++++++ 4 files changed, 122 insertions(+), 3 deletions(-) diff --git a/mlir/include/air/Util/Dependency.h b/mlir/include/air/Util/Dependency.h index da5f5c152..5d6387d99 100644 --- a/mlir/include/air/Util/Dependency.h +++ b/mlir/include/air/Util/Dependency.h @@ -69,6 +69,7 @@ Value getAsyncTokenFromOp(Operation *op); void addAsyncDependencyIfNew(Operation *op, Value token); bool isAsyncOp(Operation *op); bool areAsyncDependent(Operation *a, Operation *b); +bool isAsyncDependent(Operation *a, Operation *b); scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, SmallVector target_ops); LogicalResult unrollAIRChannelPutGetInScfParallel(OpBuilder builder, diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index dc0bb6909..9f127556a 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -4759,6 +4759,22 @@ struct ShrinkMemrefSizesByAccessPattern } }; +// Get all users to the async op's async token, with type T. +template +SmallVector getTokenUsersOfType(air::AsyncOpInterface asyncOp) { + SmallVector tokenUsers; + Value token = asyncOp.getAsyncToken(); + for (auto token_user : token.getUsers()) { + if (auto token_user_of_type = dyn_cast(token_user)) + tokenUsers.push_back(token_user_of_type); + else if (auto token_user_wait_all = dyn_cast(token_user)) + for (auto wa_user : token_user_wait_all.getAsyncToken().getUsers()) + if (auto token_user_of_type = dyn_cast(wa_user)) + tokenUsers.push_back(token_user_of_type); + } + return tokenUsers; +} + struct AIRSegmentLoopFusionPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -4866,7 +4882,7 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern { if (llvm::any_of( alloc_dealloc_execs, [&](std::pair exec_pair) { - return exec_pair.first == iaDefOp; + return isAsyncDependent(exec_pair.first, iaDefOp); })) fusableForOps.push_back(forOp); } @@ -4874,7 +4890,7 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern { if (fusableForOps.empty()) return failure(); - rewriter.setInsertionPoint(equalIterationForOps[0]); + rewriter.setInsertionPoint(equalIterationForOps.back()); auto new_loop_op_init_arg = rewriter .create( @@ -4891,7 +4907,8 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern { for (auto execOpPair : alloc_dealloc_execs) { bool canMove = false; air::ExecuteOp alloc_exec = execOpPair.first; - for (auto token_user : alloc_exec.getAsyncToken().getUsers()) + auto token_users = getTokenUsersOfType(alloc_exec); + for (auto token_user : token_users) if (llvm::any_of(equalIterationForOps, [&](scf::ForOp fusableForOp) { return fusableForOp == token_user; })) @@ -4970,6 +4987,13 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern { put_parent = put_parent->getParentOp(); } Operation *get_parent = getOp; + if (!get_parent) { + putOp->emitOpError( + "is producing data for memref in the fused scf.for loop, but no " + "consumer is found for this data within the fused loop. This " + "likely indicates a failure in the compiler pass."); + return; + } while (get_parent->getParentOp() != new_loop_op) { get_parent = get_parent->getParentOp(); } diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index a61bf76d8..724b36625 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -628,6 +628,28 @@ bool areAsyncDependent(Operation *a, Operation *b) { return false; } +// Returns true if b is asynchronously dependent on a. This function performs a +// deep dependency tracing that propagates through air.wait_all ops. +bool isAsyncDependent(Operation *a, Operation *b) { + if (a == b) + return true; + Value token_a = getAsyncTokenFromOp(a); + SmallVector dep_b = getAsyncDependenciesFromOp(b); + if (!token_a) + return false; + if (dep_b.empty()) + return false; + for (auto dep : dep_b) { + if (dep == token_a) + return true; + else if (auto dep_wa_defop = dep.getDefiningOp()) { + if (isAsyncDependent(a, dep_wa_defop)) + return true; + } + } + return false; +} + // Splits an SCF for loop into two for loops, by hoisting target operations in // for loop to a new for loop located at the same scope. scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir index 65fc14001..28abbfa7f 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir @@ -862,3 +862,75 @@ func.func @func9(%arg0: memref<512x256xi8>, %arg1: memref<256x32xi8>) { } return } + +// Scf.parallel unrolling pre-proc., with loop tiling. + +// CHECK-LABEL: func.func @func10 +// CHECK: air.segment +// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c512{{.*}} step %c256{{.*}} +// CHECK: air.channel.get async [{{.*}}] @channel_2[] +// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c32{{.*}} +// CHECK-NEXT: air.channel.put async [{{.*}}] @channel_3[%c0{{.*}}, %c0{{.*}}] +// CHECK-NEXT: scf.yield +// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c32{{.*}} +// CHECK-NEXT: air.channel.put async [{{.*}}] @channel_3[%c0{{.*}}, %c1{{.*}}] +// CHECK-NEXT: scf.yield +// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c32{{.*}} +// CHECK-NEXT: air.channel.put async [{{.*}}] @channel_3[%c0{{.*}}, %c2{{.*}}] +// CHECK-NEXT: scf.yield +// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c32{{.*}} +// CHECK-NEXT: air.channel.put async [{{.*}}] @channel_3[%c0{{.*}}, %c3{{.*}}] +// CHECK-NEXT: scf.yield +// CHECK: scf.yield + +#map15 = affine_map<()[s0] -> (s0 * 32)> +#map16 = affine_map<()[s0] -> (s0 * 8)> +func.func @func10(%arg0: memref<8x512xi32>, %arg1: memref<256x512xi32>, %arg2: memref<8x256xi32>) { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %0 = air.launch async (%arg3, %arg4) in (%arg5=%c1, %arg6=%c2) attributes {id = 1 : i32} { + %1 = air.segment @segment_0 async attributes {id = 2 : i32} { + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + %c128 = arith.constant 128 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1_0 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c512 = arith.constant 512 : index + %c256 = arith.constant 256 : index + %async_token, %results = air.execute -> (memref<128x512xi32, 1 : i32>) { + %alloc = memref.alloc() : memref<128x512xi32, 1 : i32> + air.execute_terminator %alloc : memref<128x512xi32, 1 : i32> + } + %2 = scf.for %arg7 = %c0 to %c512 step %c256 iter_args(%arg8 = %async_token) -> (!air.async.token) { + %4 = air.channel.get async [%arg8] @channel_2[] (%results[%c0, %arg7] [%c128, %c256] [%c512, %c1_0]) {id = 5 : i32} : (memref<128x512xi32, 1 : i32>) + scf.yield %4 : !air.async.token + } + %3 = scf.parallel (%arg7) = (%c0) to (%c4) step (%c1_0) init (%async_token) -> !air.async.token { + %async_token_2, %results_3 = air.execute -> (index) { + %6 = affine.apply #map15()[%arg7] + air.execute_terminator %6 : index + } + %4 = air.wait_all async [%async_token, %async_token_2] + %5 = scf.for %arg8 = %c0 to %c64 step %c4 iter_args(%arg9 = %4) -> (!air.async.token) { + %async_token_4, %results_5 = air.execute [%arg9] -> (index) { + %7 = affine.apply #map16()[%arg8] + air.execute_terminator %7 : index + } + %6 = air.channel.put async [%async_token_4] @channel_3[%c0, %arg7] (%results[%c0, %c0, %results_3, %results_5] [%c4, %c8, %c4, %c8] [%c8, %c2048, %c512, %c1_0]) {id = 7 : i32} : (memref<128x512xi32, 1 : i32>) + scf.yield %6 : !air.async.token + } + scf.reduce(%5 : !air.async.token) { + ^bb0(%arg8: !air.async.token, %arg9: !air.async.token): + %6 = air.wait_all async [%arg8, %arg9] + scf.reduce.return %6 : !air.async.token + } + } + %async_token_1 = air.execute { + memref.dealloc %results : memref<128x512xi32, 1 : i32> + } + } + } + return +}