Skip to content

Commit

Permalink
Fix lit test + add one lit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Nov 28, 2024
1 parent ed82f00 commit d56acd2
Show file tree
Hide file tree
Showing 7 changed files with 488 additions and 434 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class AMDAIEFuseConsumerIntoLoopPass
public:
AMDAIEFuseConsumerIntoLoopPass() = default;
AMDAIEFuseConsumerIntoLoopPass(const AMDAIEFuseConsumerIntoLoopPass &pass) {}
AMDAIEFuseConsumerIntoLoopPass(
const AMDAIEFuseConsumerIntoLoopOptions &options)
: AMDAIEFuseConsumerIntoLoopBase(options) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect>();
Expand All @@ -36,7 +33,20 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
mlir::FunctionOpInterface funcOp = getOperation();
IRRewriter rewriter(context);

// Step 1. Find the first scf loop in postorder walk.
// Step 1. The depth until which we would keep fusing consumer chain.
// TODO(avarma): This should also be part of KernelDispatch logic.
unsigned fuseDepth = 1;
// Check if there is matmul-elementwise fusion opportunity. If so, overwrite
// the `fuseDepth` to be 2.
funcOp->walk<WalkOrder::PostOrder, ReverseIterator>([&](linalg::LinalgOp op) {
if (isMatmulProducerOfElementwise(op)) {
fuseDepth = 2;
return WalkResult::interrupt();
}
return WalkResult::advance();
});

// Step 2. Find the first scf loop in postorder walk.
Operation *scfLoopOp = nullptr;
funcOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](LoopLikeOpInterface op) {
Expand All @@ -54,7 +64,8 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
<< "There is no scf.for/forall loop to fuse with\n");
return;
}
// Step 2. Search the compute op and its consumer slices.

// Step 3. Search the compute op and its consumer slices.
linalg::LinalgOp linalgOp;
scfLoopOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](linalg::LinalgOp op) {
Expand All @@ -65,14 +76,6 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "Could not find any compute op\n");
return;
}
// Step 3. The depth until which we would keep fusing consumer chain.
// TODO(avarma): This should also be part of KernelDispatch logic.
unsigned fuseDepth = 1;
// Check if there is matmul-elementwise fusion opportunity. If so, overwrite
// the `fuseDepth` to be 2
if (isMatmulProducerOfElementwise(linalgOp)) {
fuseDepth = 2;
}

Operation *computeOp = linalgOp;
// Step 4. Based on the `fuseDepth`, we would greedily fuse the consumer ops.
Expand Down Expand Up @@ -111,9 +114,7 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {

} // namespace

std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
AMDAIEFuseConsumerIntoLoopOptions options) {
return std::make_unique<AMDAIEFuseConsumerIntoLoopPass>(options);
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass() {
return std::make_unique<AMDAIEFuseConsumerIntoLoopPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,7 @@ std::unique_ptr<Pass> createAMDAIEFlattenLogicalObjectFifoPass();
std::unique_ptr<Pass> createAMDAIELinalgFunctionOutliningPass();

/// Create a pass to fuse the consumer op into the innermost last scf loop.
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
AMDAIEFuseConsumerIntoLoopOptions options = {});
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass();

/// Create a pass to fuse the linalg.fill into the forall loops.
std::unique_ptr<Pass> createAMDAIEFuseFillIntoForallPass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,6 @@ def AMDAIEFuseConsumerIntoLoop :
InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> {
let summary = "Fuse the consumer operation into the innermost last scf loop.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFuseConsumerIntoLoopPass()";
let options = [
Option<"useSCFFor", "use-scf-for", "bool", /*default=*/"false",
"Set the innermost scf loop type to fuse consumer ops into">
];
}

def AMDAIEFuseFillIntoForall :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ iree_lit_test_suite(
"dma_to_circular_dma.mlir"
"flatten_logical_objectfifo.mlir"
"linalg_function_outlining.mlir"
"fuse_consumer_into_loop_scf_for.mlir"
"fuse_consumer_into_loop_scf_forall.mlir"
"fuse_consumer_into_loop.mlir"
"fuse_fill_into_forall.mlir"
"fuse_pack_into_loop.mlir"
"hoist_for_affine_apply.mlir"
Expand Down
Loading

0 comments on commit d56acd2

Please sign in to comment.