From 9e333ad329c639c1536ae0fb84e8150c5abd8180 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 23 Oct 2024 00:32:35 -0700 Subject: [PATCH 1/2] Safer op deletion using llvm::SmallSet (#750) --- mlir/lib/Conversion/AIRLoweringPass.cpp | 11 +++++---- mlir/lib/Conversion/AIRRtToLLVMPass.cpp | 5 ++-- mlir/lib/Conversion/AIRRtToNpuPass.cpp | 24 +++++++------------ .../Transform/AIRDependencyScheduleOpt.cpp | 9 ++++--- mlir/lib/Transform/AIRMiscPasses.cpp | 22 +++++++++++++---- mlir/lib/Util/Dependency.cpp | 5 ++-- 6 files changed, 44 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Conversion/AIRLoweringPass.cpp b/mlir/lib/Conversion/AIRLoweringPass.cpp index 64cc3620d..70cea89c1 100644 --- a/mlir/lib/Conversion/AIRLoweringPass.cpp +++ b/mlir/lib/Conversion/AIRLoweringPass.cpp @@ -33,6 +33,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -652,7 +653,7 @@ LogicalResult lowerAirExecute(Operation *op) { if (!module) return failure(); - SmallVector erased; + llvm::SmallSet erased; module->walk([&](air::ExecuteOp exe) { auto &bb = exe.getBody().front(); unsigned idx = 0; @@ -670,7 +671,7 @@ LogicalResult lowerAirExecute(Operation *op) { int resultIdx = 1; for (auto r : t->getOperands()) exe.getResult(resultIdx++).replaceAllUsesWith(r); - erased.push_back(t); + erased.insert(t); }); exe->getBlock()->getOperations().splice(Block::iterator(exe), bb.getOperations()); @@ -680,7 +681,7 @@ LogicalResult lowerAirExecute(Operation *op) { SmallVector{}); exe.getResult(0).replaceAllUsesWith(w.getResult(0)); } - erased.push_back(exe); + erased.insert(exe); }); for (auto a : erased) a->erase(); @@ -856,7 +857,7 @@ LogicalResult ScfParToAffineForConversion(Operation *op) { if (!f) return failure(); - SmallVector erased; + llvm::SmallSet erased; f.walk([&](scf::ParallelOp scf_par) { for (auto v : scf_par.getLowerBound()) { assert(dyn_cast(v.getDefiningOp()).value() == 0); @@ -894,7 +895,7 @@ LogicalResult ScfParToAffineForConversion(Operation *op) { builder.clone(o, remap); } } - erased.push_back(scf_par); + erased.insert(scf_par); }); for (auto a : erased) { if (a->getNumResults()) diff --git a/mlir/lib/Conversion/AIRRtToLLVMPass.cpp b/mlir/lib/Conversion/AIRRtToLLVMPass.cpp index 940a84cee..6af89305a 100644 --- a/mlir/lib/Conversion/AIRRtToLLVMPass.cpp +++ b/mlir/lib/Conversion/AIRRtToLLVMPass.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -1297,10 +1298,10 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase { signalPassFailure(); } - SmallVector erased_extern; + llvm::SmallSet erased_extern; for (auto func : module.getOps()) { if (func.isExternal() && func.symbolKnownUseEmpty(module)) - erased_extern.push_back(func); + erased_extern.insert(func); else func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext())); diff --git a/mlir/lib/Conversion/AIRRtToNpuPass.cpp b/mlir/lib/Conversion/AIRRtToNpuPass.cpp index 2c69aa82a..97e24371e 100644 --- a/mlir/lib/Conversion/AIRRtToNpuPass.cpp +++ b/mlir/lib/Conversion/AIRRtToNpuPass.cpp @@ -378,13 +378,13 @@ class HostMemRefCopyOpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector erased; + llvm::SmallSet erased; if (auto alloc = op.getSource().getDefiningOp()) { op.getSource().replaceAllUsesWith(op.getTarget()); - erased.push_back(alloc); + erased.insert(alloc); } else if (auto alloc = op.getTarget().getDefiningOp()) { op.getTarget().replaceAllUsesWith(op.getSource()); - erased.push_back(alloc); + erased.insert(alloc); } for (auto o : erased) rewriter.eraseOp(o); @@ -509,13 +509,6 @@ void hoistTargetOpsToNewAffineFor(OpBuilder builder, affine::AffineForOp for_op, } } -template -void push_back_if_unique(SmallVector &vec, T entry) { - if (std::find(vec.begin(), vec.end(), entry) == vec.end()) { - vec.push_back(entry); - } -} - void identifyTargetAffineForAndOps( func::FuncOp f, SmallVector> &target_ops_vec) { // Identify the target for loops and their target child ops @@ -563,7 +556,7 @@ void isolateAIRRtDmaLoopNests(ModuleOp module) { } // Hoist ops out of each scf.for. - SmallVector erased; + llvm::SmallSet erased; for (auto vec : target_ops_vec) { affine::AffineForOp loop_nest_head = vec[0]->getParentOfType(); @@ -572,7 +565,7 @@ void isolateAIRRtDmaLoopNests(ModuleOp module) { } OpBuilder builder(loop_nest_head); hoistTargetOpsToNewAffineFor(builder, loop_nest_head, vec); - push_back_if_unique(erased, loop_nest_head.getOperation()); + erased.insert(loop_nest_head.getOperation()); } for (auto o : erased) o->erase(); @@ -931,17 +924,18 @@ specializeAffineForInAIRRtDmaWrapAndStride(OpBuilder builder, void specializeAffineForInAIRRtDmaWrapAndStride(ModuleOp module) { SmallVector funcOps; module.walk([&](func::FuncOp f) { funcOps.push_back(f); }); - SmallVector erased; + llvm::SmallSet erased; SmallVector unroll_outer_dim; auto specialzeAllAffineFors = - [&](SmallVector funcOps, SmallVector &erased, + [&](SmallVector funcOps, + llvm::SmallSet &erased, SmallVector &unroll_outer_dim) { for (auto f : funcOps) { for (auto for_op : f.getOps()) { OpBuilder builder(for_op); if (specializeAffineForInAIRRtDmaWrapAndStride(builder, for_op) .succeeded()) - erased.push_back(for_op); + erased.insert(for_op); else { // Wait list to be unrolled one outer dimension, and then try // specializing the wraps and strides again. diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 9f127556a..e37401e65 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -5105,14 +5105,14 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern { remap.map(forOp.getRegionIterArgs()[j], new_for_op.getRegionIterArgs()[j]); remap.map(forOp.getInductionVar(), new_for_op.getInductionVar()); - SmallVector erased; + llvm::SmallSet erased; Value yielded_token = nullptr; for (auto &o : forOp.getOps()) { if (&o != new_for_op && &o != forOp.getBody()->getTerminator()) { auto new_o = builder.clone(o, remap); if (isAsyncOp(new_o)) { yielded_token = new_o->getResult(0); - erased.push_back(&o); + erased.insert(&o); } } } @@ -5123,7 +5123,10 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern { builder.create(loc); } for (auto o : erased) { - o->getResult(0).replaceAllUsesWith(new_for_op->getResult(0)); + // Replace all remaining uses of erased op's token with the new for op's. + for (auto res : o->getResults()) + if (isa(res.getType()) && !res.use_empty()) + res.replaceAllUsesWith(new_for_op->getResult(0)); o->erase(); } diff --git a/mlir/lib/Transform/AIRMiscPasses.cpp b/mlir/lib/Transform/AIRMiscPasses.cpp index d0bb2ea5f..2d6df39f8 100644 --- a/mlir/lib/Transform/AIRMiscPasses.cpp +++ b/mlir/lib/Transform/AIRMiscPasses.cpp @@ -33,6 +33,7 @@ #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/Support/Debug.h" #include @@ -1459,7 +1460,7 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() { return; // Tile memrefs. - SmallVector erased; + llvm::SmallSet erased; for (auto allocOp : targetMemrefs) { int targetColTilingFactor = findGCD(targetMemrefsToColTilingFactors[allocOp]); @@ -1482,7 +1483,7 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() { IRMapping remap; (void)air::unrollAIRChannelPutGetInScfParallel(builder, par, user, remap); - erased.push_back(par); + erased.insert(par); } else if ((isa(user) && splitTypeAttr.str() == "MM2SChannels") || (isa(user) && @@ -1608,14 +1609,25 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() { dyn_cast(theOtherChanOp[0].getOperation()) .getAsyncToken(); oldToken.replaceAllUsesWith(newWaitAll1); - erased.push_back(theOtherChanOp[0]); - erased.push_back(chanUserOp); + erased.insert(theOtherChanOp[0]); + erased.insert(chanUserOp); } } } - for (auto e : erased) + for (auto e : erased) { + // Replace all remaining uses of erased op's token with a new wait_all. + for (auto res : e->getResults()) { + if (isa(res.getType()) && !res.use_empty()) { + OpBuilder b(e); + res.replaceAllUsesWith( + b.create(e->getLoc(), air::AsyncTokenType::get(ctx), + SmallVector{}) + .getAsyncToken()); + } + } e->erase(); + } auto context = &getContext(); RewritePatternSet canoPatterns(context); diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index 724b36625..77fa58bf3 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -9,6 +9,7 @@ #include "air/Util/Dependency.h" #include "air/Util/Util.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" #include #include @@ -1852,7 +1853,7 @@ void dependencyCanonicalizer::removeDepListRepetition(func::FuncOp func) { // Remove unused air.execute ops which have no side effects void dependencyCanonicalizer::removeUnusedExecuteOp(func::FuncOp func) { - SmallVector erased_ops; + llvm::SmallSet erased_ops; func.walk([&](air::ExecuteOp op) { // Check the type of op inside the execute. Only remove ops with no side // effects @@ -1863,7 +1864,7 @@ void dependencyCanonicalizer::removeUnusedExecuteOp(func::FuncOp func) { if (op->getNumResults() == 2) { auto result = op->getResult(1); if (result.use_empty()) { - erased_ops.push_back(op); + erased_ops.insert(op); } } } From 9160ca7826c57a2a6f4f81cde6aed441ad772b74 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 23 Oct 2024 01:35:25 -0700 Subject: [PATCH 2/2] Take into account herds in peeled for loop (#751) --- mlir/lib/Transform/AIRMiscPasses.cpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Transform/AIRMiscPasses.cpp b/mlir/lib/Transform/AIRMiscPasses.cpp index 2d6df39f8..c7507045d 100644 --- a/mlir/lib/Transform/AIRMiscPasses.cpp +++ b/mlir/lib/Transform/AIRMiscPasses.cpp @@ -1437,9 +1437,23 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() { // none, then memref splitting is not needed, as no routings or channels can // be saved if only allocating to a single memtile. auto getTileCountInSegment = [](air::SegmentOp seg) { + DenseMap + herdNumTiles; // Herds with the same name are assumed to be different + // time phases of the same physical herd. unsigned tileCount = 0; - seg.walk( - [&](air::HerdOp h) { tileCount += h.getNumCols() * h.getNumRows(); }); + seg.walk([&](air::HerdOp h) { + if (!h.getSymName()) { + tileCount += h.getNumCols() * h.getNumRows(); + return; + } + StringRef herdSym = *h.getSymName(); + herdNumTiles[herdSym] = + herdNumTiles.count(herdSym) + ? std::max(herdNumTiles[herdSym], h.getNumCols() * h.getNumRows()) + : h.getNumCols() * h.getNumRows(); + }); + for (const auto &[herdSym, count] : herdNumTiles) + tileCount += count; return tileCount; }; if (llvm::none_of(allocOps, [&](memref::AllocOp a) {