Skip to content

Commit

Permalink
Add RemoveIV pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
BuildKite committed Jan 29, 2025
1 parent d89468e commit 4fdc97b
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 0 deletions.
121 changes: 121 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
Expand Down Expand Up @@ -7557,6 +7559,125 @@ void mlir::transform::addConcatenateOpCanon(RewritePatternSet &patterns,
patterns.insert<ConcatenateOpCanon>(maxConstantExpansion, &context, benefit);
}

/////////////// End stablehlo patterns

struct RemoveIVs : public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
if (!forOp.getRegion().hasOneBlock())
return failure();
unsigned numIterArgs = forOp.getNumRegionIterArgs();
auto loc = forOp->getLoc();
bool changed = false;
llvm::SetVector<unsigned> removed;
llvm::MapVector<unsigned, Value> steps;
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (unsigned i = 0; i < numIterArgs; i++) {
auto ba = forOp.getRegionIterArgs()[i];
auto init = forOp.getInits()[i];
auto next = yield->getOperand(i);

auto increment = next.getDefiningOp<arith::AddIOp>();
if (!increment)
continue;

Value step = nullptr;
if (increment.getLhs() == ba) {
step = increment.getRhs();
} else {
step = increment.getLhs();
}
if (!step)
continue;

// If it dominates the loop entry
if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion()))
continue;

rewriter.setInsertionPointToStart(forOp.getBody());
Value iterNum = rewriter.create<arith::SubIOp>(
loc, forOp.getInductionVar(), forOp.getLowerBound());
iterNum = rewriter.create<arith::DivSIOp>(loc, iterNum, forOp.getStep());

Value replacementIV = rewriter.create<arith::MulIOp>(loc, iterNum, step);
replacementIV = rewriter.create<arith::AddIOp>(loc, replacementIV, init);

rewriter.replaceAllUsesWith(ba, replacementIV);

removed.insert(i);
steps.insert({i, step});
changed = true;
}

if (!changed)
return failure();

SmallVector<Value> newInits;
for (unsigned i = 0; i < numIterArgs; i++)
if (!removed.contains(i))
newInits.push_back(forOp.getInits()[i]);

rewriter.setInsertionPoint(forOp);
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(), newInits);
if (!newForOp.getRegion().empty())
newForOp.getRegion().front().erase();
assert(newForOp.getRegion().empty());
rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(),
newForOp.getRegion().begin());

SmallVector<Value> newYields;
for (unsigned i = 0; i < numIterArgs; i++)
if (!removed.contains(i))
newYields.push_back(yield->getOperand(i));

rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<scf::YieldOp>(yield, newYields);

llvm::BitVector toDelete(numIterArgs + 1);
for (unsigned i = 0; i < numIterArgs; i++)
if (removed.contains(i))
toDelete[i + 1] = true;
newForOp.getBody()->eraseArguments(toDelete);

rewriter.setInsertionPoint(newForOp);
unsigned curNewRes = 0;
for (unsigned i = 0; i < numIterArgs; i++) {
auto result = forOp->getResult(i);
if (removed.contains(i)) {
if (result.use_empty())
continue;

rewriter.setInsertionPointAfter(forOp.getOperation());
Value iterNum = rewriter.create<arith::SubIOp>(
loc, forOp.getUpperBound(), forOp.getLowerBound());
iterNum =
rewriter.create<arith::DivSIOp>(loc, iterNum, forOp.getStep());

Value afterLoop =
rewriter.create<arith::MulIOp>(loc, iterNum, steps[i]);
afterLoop =
rewriter.create<arith::AddIOp>(loc, afterLoop, forOp.getInits()[i]);

rewriter.replaceAllUsesWith(result, afterLoop);
} else {
rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++));
}
}

forOp->getParentOp()->dump();
rewriter.eraseOp(forOp);

return success();
}
};

void mlir::transform::addRemoveIV(RewritePatternSet &patterns) {
patterns.insert<RemoveIVs>(patterns.getContext());
}

namespace {
struct EnzymeHLOOptPass
: public enzyme::impl::EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ void addSelectOpCanon(RewritePatternSet &patterns, int64_t maxConstantExpansion,
void addConcatenateOpCanon(RewritePatternSet &patterns,
int64_t maxConstantExpansion, MLIRContext &context,
PatternBenefit benefit);
void addRemoveIV(RewritePatternSet &patterns);
} // namespace mlir::transform
3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ void SelectOpCanonPatterns::populatePatterns(RewritePatternSet &patterns) {
addSelectOpCanon(patterns, getParameter(), *getContext(),
PatternBenefit(getBenefit().value_or(1)));
}
void ApplyRemoveIVPatterns::populatePatterns(RewritePatternSet &patterns) {
addRemoveIV(patterns);
}

} // namespace transform
} // namespace mlir
Expand Down
6 changes: 6 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -923,3 +923,9 @@ def ApplyPadDotGeneralPatterns : EnzymeHLOParameterizedPatternOp<
}
}];
}

def ApplyRemoveIVPatterns : EnzymeHLOParameterizedPatternOp<
"remove_iv"> {
let arguments = (ins);
let assemblyFormat = "attr-dict";
}
29 changes: 29 additions & 0 deletions test/lit_tests/raising/llvm_to_affine_access.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: enzymexlamlir-opt %s -split-input-file -allow-unregistered-dialect --transform-interpreter | FileCheck %s

module {
func.func @test_remove_ivs(%arg0: index, %arg1: index, %arg2: index) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%result = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%iter_arg = %arg0) -> (index) {
%next = arith.addi %iter_arg, %arg2 : index
scf.yield %next : index
}
return %result : index
}

builtin.module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg2: !transform.any_op) {
%4 = transform.structured.match ops{["func.func"]} in %arg2 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %4 {
transform.apply_patterns.enzyme_hlo.remove_iv
} : !transform.any_op
transform.yield
}
}
}

// CHECK-LABEL: func @test_remove_ivs(
// CHECK-SAME: %[[START:.*]]: index, %[[BOUND:.*]]: index, %[[STEP:.*]]: index
// CHECK: %[[MUL:.*]] = arith.muli %[[BOUND]], %[[STEP]]
// CHECK: %[[RET:.*]] = arith.addi %[[MUL]], %[[START]]
// CHECK: return %[[RET]]

0 comments on commit 4fdc97b

Please sign in to comment.