Skip to content

Commit

Permalink
Merge branch 'main' into compose_memref_ops_as_cano
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Oct 24, 2024
2 parents 554624f + 9160ca7 commit ef0436a
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 34 deletions.
11 changes: 6 additions & 5 deletions mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -652,7 +653,7 @@ LogicalResult lowerAirExecute(Operation *op) {
if (!module)
return failure();

SmallVector<Operation *, 8> erased;
llvm::SmallSet<Operation *, 8> erased;
module->walk([&](air::ExecuteOp exe) {
auto &bb = exe.getBody().front();
unsigned idx = 0;
Expand All @@ -670,7 +671,7 @@ LogicalResult lowerAirExecute(Operation *op) {
int resultIdx = 1;
for (auto r : t->getOperands())
exe.getResult(resultIdx++).replaceAllUsesWith(r);
erased.push_back(t);
erased.insert(t);
});
exe->getBlock()->getOperations().splice(Block::iterator(exe),
bb.getOperations());
Expand All @@ -680,7 +681,7 @@ LogicalResult lowerAirExecute(Operation *op) {
SmallVector<Value>{});
exe.getResult(0).replaceAllUsesWith(w.getResult(0));
}
erased.push_back(exe);
erased.insert(exe);
});
for (auto a : erased)
a->erase();
Expand Down Expand Up @@ -856,7 +857,7 @@ LogicalResult ScfParToAffineForConversion(Operation *op) {
if (!f)
return failure();

SmallVector<Operation *, 8> erased;
llvm::SmallSet<Operation *, 8> erased;
f.walk([&](scf::ParallelOp scf_par) {
for (auto v : scf_par.getLowerBound()) {
assert(dyn_cast<arith::ConstantIndexOp>(v.getDefiningOp()).value() == 0);
Expand Down Expand Up @@ -894,7 +895,7 @@ LogicalResult ScfParToAffineForConversion(Operation *op) {
builder.clone(o, remap);
}
}
erased.push_back(scf_par);
erased.insert(scf_par);
});
for (auto a : erased) {
if (a->getNumResults())
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Conversion/AIRRtToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -1297,10 +1298,10 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
signalPassFailure();
}

SmallVector<func::FuncOp> erased_extern;
llvm::SmallSet<func::FuncOp, 1> erased_extern;
for (auto func : module.getOps<func::FuncOp>()) {
if (func.isExternal() && func.symbolKnownUseEmpty(module))
erased_extern.push_back(func);
erased_extern.insert(func);
else
func->setAttr("llvm.emit_c_interface",
UnitAttr::get(func.getContext()));
Expand Down
24 changes: 9 additions & 15 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,13 +378,13 @@ class HostMemRefCopyOpConversion : public OpConversionPattern<memref::CopyOp> {
LogicalResult
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Operation *> erased;
llvm::SmallSet<Operation *, 1> erased;
if (auto alloc = op.getSource().getDefiningOp()) {
op.getSource().replaceAllUsesWith(op.getTarget());
erased.push_back(alloc);
erased.insert(alloc);
} else if (auto alloc = op.getTarget().getDefiningOp()) {
op.getTarget().replaceAllUsesWith(op.getSource());
erased.push_back(alloc);
erased.insert(alloc);
}
for (auto o : erased)
rewriter.eraseOp(o);
Expand Down Expand Up @@ -509,13 +509,6 @@ void hoistTargetOpsToNewAffineFor(OpBuilder builder, affine::AffineForOp for_op,
}
}

// Appends `entry` to the end of `vec` unless an element comparing equal to
// `entry` is already stored there, in which case `vec` is left untouched.
// Uniqueness is checked with a linear scan using operator==.
template <typename T>
void push_back_if_unique(SmallVector<T> &vec, T entry) {
  // Bail out as soon as a matching element is found.
  for (const auto &existing : vec)
    if (existing == entry)
      return;
  vec.push_back(entry);
}

void identifyTargetAffineForAndOps(
func::FuncOp f, SmallVector<llvm::SetVector<Operation *>> &target_ops_vec) {
// Identify the target for loops and their target child ops
Expand Down Expand Up @@ -563,7 +556,7 @@ void isolateAIRRtDmaLoopNests(ModuleOp module) {
}

// Hoist ops out of each scf.for.
SmallVector<Operation *> erased;
llvm::SmallSet<Operation *, 1> erased;
for (auto vec : target_ops_vec) {
affine::AffineForOp loop_nest_head =
vec[0]->getParentOfType<affine::AffineForOp>();
Expand All @@ -572,7 +565,7 @@ void isolateAIRRtDmaLoopNests(ModuleOp module) {
}
OpBuilder builder(loop_nest_head);
hoistTargetOpsToNewAffineFor(builder, loop_nest_head, vec);
push_back_if_unique<Operation *>(erased, loop_nest_head.getOperation());
erased.insert(loop_nest_head.getOperation());
}
for (auto o : erased)
o->erase();
Expand Down Expand Up @@ -931,17 +924,18 @@ specializeAffineForInAIRRtDmaWrapAndStride(OpBuilder builder,
void specializeAffineForInAIRRtDmaWrapAndStride(ModuleOp module) {
SmallVector<func::FuncOp> funcOps;
module.walk([&](func::FuncOp f) { funcOps.push_back(f); });
SmallVector<Operation *> erased;
llvm::SmallSet<Operation *, 1> erased;
SmallVector<affine::AffineForOp> unroll_outer_dim;
auto specialzeAllAffineFors =
[&](SmallVector<func::FuncOp> funcOps, SmallVector<Operation *> &erased,
[&](SmallVector<func::FuncOp> funcOps,
llvm::SmallSet<Operation *, 1> &erased,
SmallVector<affine::AffineForOp> &unroll_outer_dim) {
for (auto f : funcOps) {
for (auto for_op : f.getOps<affine::AffineForOp>()) {
OpBuilder builder(for_op);
if (specializeAffineForInAIRRtDmaWrapAndStride(builder, for_op)
.succeeded())
erased.push_back(for_op);
erased.insert(for_op);
else {
// Wait list to be unrolled one outer dimension, and then try
// specializing the wraps and strides again.
Expand Down
9 changes: 6 additions & 3 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5105,14 +5105,14 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
remap.map(forOp.getRegionIterArgs()[j],
new_for_op.getRegionIterArgs()[j]);
remap.map(forOp.getInductionVar(), new_for_op.getInductionVar());
SmallVector<Operation *> erased;
llvm::SmallSet<Operation *, 1> erased;
Value yielded_token = nullptr;
for (auto &o : forOp.getOps()) {
if (&o != new_for_op && &o != forOp.getBody()->getTerminator()) {
auto new_o = builder.clone(o, remap);
if (isAsyncOp(new_o)) {
yielded_token = new_o->getResult(0);
erased.push_back(&o);
erased.insert(&o);
}
}
}
Expand All @@ -5123,7 +5123,10 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
builder.create<scf::YieldOp>(loc);
}
for (auto o : erased) {
o->getResult(0).replaceAllUsesWith(new_for_op->getResult(0));
// Replace all remaining uses of erased op's token with the new for op's.
for (auto res : o->getResults())
if (isa<air::AsyncTokenType>(res.getType()) && !res.use_empty())
res.replaceAllUsesWith(new_for_op->getResult(0));
o->erase();
}

Expand Down
40 changes: 33 additions & 7 deletions mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"

#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <list>
Expand Down Expand Up @@ -1436,9 +1437,23 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
// none, then memref splitting is not needed, as no routings or channels can
// be saved if only allocating to a single memtile.
auto getTileCountInSegment = [](air::SegmentOp seg) {
DenseMap<StringRef, uint64_t>
herdNumTiles; // Herds with the same name are assumed to be different
// time phases of the same physical herd.
unsigned tileCount = 0;
seg.walk(
[&](air::HerdOp h) { tileCount += h.getNumCols() * h.getNumRows(); });
seg.walk([&](air::HerdOp h) {
if (!h.getSymName()) {
tileCount += h.getNumCols() * h.getNumRows();
return;
}
StringRef herdSym = *h.getSymName();
herdNumTiles[herdSym] =
herdNumTiles.count(herdSym)
? std::max(herdNumTiles[herdSym], h.getNumCols() * h.getNumRows())
: h.getNumCols() * h.getNumRows();
});
for (const auto &[herdSym, count] : herdNumTiles)
tileCount += count;
return tileCount;
};
if (llvm::none_of(allocOps, [&](memref::AllocOp a) {
Expand All @@ -1459,7 +1474,7 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
return;

// Tile memrefs.
SmallVector<Operation *> erased;
llvm::SmallSet<Operation *, 1> erased;
for (auto allocOp : targetMemrefs) {
int targetColTilingFactor =
findGCD(targetMemrefsToColTilingFactors[allocOp]);
Expand All @@ -1482,7 +1497,7 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
IRMapping remap;
(void)air::unrollAIRChannelPutGetInScfParallel(builder, par, user,
remap);
erased.push_back(par);
erased.insert(par);
} else if ((isa<air::ChannelPutOp>(user) &&
splitTypeAttr.str() == "MM2SChannels") ||
(isa<air::ChannelGetOp>(user) &&
Expand Down Expand Up @@ -1608,14 +1623,25 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
dyn_cast<air::AsyncOpInterface>(theOtherChanOp[0].getOperation())
.getAsyncToken();
oldToken.replaceAllUsesWith(newWaitAll1);
erased.push_back(theOtherChanOp[0]);
erased.push_back(chanUserOp);
erased.insert(theOtherChanOp[0]);
erased.insert(chanUserOp);
}
}
}

for (auto e : erased)
for (auto e : erased) {
// Replace all remaining uses of erased op's token with a new wait_all.
for (auto res : e->getResults()) {
if (isa<air::AsyncTokenType>(res.getType()) && !res.use_empty()) {
OpBuilder b(e);
res.replaceAllUsesWith(
b.create<air::WaitAllOp>(e->getLoc(), air::AsyncTokenType::get(ctx),
SmallVector<Value>{})
.getAsyncToken());
}
}
e->erase();
}

auto context = &getContext();
RewritePatternSet canoPatterns(context);
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "air/Util/Dependency.h"
#include "air/Util/Util.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include <sys/stat.h>

#include <iostream>
Expand Down Expand Up @@ -1852,7 +1853,7 @@ void dependencyCanonicalizer::removeDepListRepetition(func::FuncOp func) {

// Remove unused air.execute ops which have no side effects
void dependencyCanonicalizer::removeUnusedExecuteOp(func::FuncOp func) {
SmallVector<air::ExecuteOp, 1> erased_ops;
llvm::SmallSet<air::ExecuteOp, 1> erased_ops;
func.walk([&](air::ExecuteOp op) {
// Check the type of op inside the execute. Only remove ops with no side
// effects
Expand All @@ -1863,7 +1864,7 @@ void dependencyCanonicalizer::removeUnusedExecuteOp(func::FuncOp func) {
if (op->getNumResults() == 2) {
auto result = op->getResult(1);
if (result.use_empty()) {
erased_ops.push_back(op);
erased_ops.insert(op);
}
}
}
Expand Down

0 comments on commit ef0436a

Please sign in to comment.