diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 7140012e8..e66688dbe 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -14,6 +14,8 @@
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
@@ -7557,6 +7559,144 @@ void mlir::transform::addConcatenateOpCanon(RewritePatternSet &patterns,
   patterns.insert<ConcatenateOpCanon>(maxConstantExpansion, &context, benefit);
 }
 
+/////////////// End stablehlo patterns
+
+/// Rewrites `scf.for` iter_args that advance by a loop-invariant amount on
+/// every iteration (`next = iter_arg + step`) into closed-form expressions of
+/// the loop induction variable, then drops them from the loop.
+///
+/// In the body the iter_arg becomes `init + ((iv - lb) / loopStep) * step`;
+/// a used loop result becomes `init + tripCount * step`, where
+/// `tripCount = max(ceildiv(ub - lb, loopStep), 0)`.
+struct RemoveIVs : public OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    if (!forOp.getRegion().hasOneBlock())
+      return failure();
+    unsigned numIterArgs = forOp.getNumRegionIterArgs();
+    auto loc = forOp->getLoc();
+    Type ivType = forOp.getInductionVar().getType();
+    bool changed = false;
+    llvm::SetVector<unsigned> removed;
+    llvm::MapVector<unsigned, Value> steps;
+    auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    for (unsigned i = 0; i < numIterArgs; i++) {
+      auto ba = forOp.getRegionIterArgs()[i];
+      auto init = forOp.getInits()[i];
+      auto next = yield->getOperand(i);
+
+      // The closed form mixes the iter_arg with the induction variable in
+      // arith ops, which require identical operand types.
+      if (ba.getType() != ivType)
+        continue;
+
+      auto increment = next.getDefiningOp<arith::AddIOp>();
+      if (!increment)
+        continue;
+
+      // One addend must be the iter_arg itself; otherwise the yielded value
+      // is not `iter_arg + step` and the rewrite would change semantics.
+      Value step = nullptr;
+      if (increment.getLhs() == ba)
+        step = increment.getRhs();
+      else if (increment.getRhs() == ba)
+        step = increment.getLhs();
+      else
+        continue;
+
+      // The step must be defined outside the loop so that it is
+      // loop-invariant and usable after the loop.
+      if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion()))
+        continue;
+
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      Value iterNum = rewriter.create<arith::SubIOp>(
+          loc, forOp.getInductionVar(), forOp.getLowerBound());
+      iterNum = rewriter.create<arith::DivUIOp>(loc, iterNum, forOp.getStep());
+
+      Value replacementIV = rewriter.create<arith::MulIOp>(loc, iterNum, step);
+      replacementIV = rewriter.create<arith::AddIOp>(loc, replacementIV, init);
+
+      rewriter.replaceAllUsesWith(ba, replacementIV);
+
+      removed.insert(i);
+      steps.insert({i, step});
+      changed = true;
+    }
+
+    if (!changed)
+      return failure();
+
+    SmallVector<Value> newInits;
+    for (unsigned i = 0; i < numIterArgs; i++)
+      if (!removed.contains(i))
+        newInits.push_back(forOp.getInits()[i]);
+
+    rewriter.setInsertionPoint(forOp);
+    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                                forOp.getUpperBound(),
+                                                forOp.getStep(), newInits);
+    if (!newForOp.getRegion().empty())
+      newForOp.getRegion().front().erase();
+    assert(newForOp.getRegion().empty());
+    rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(),
+                                newForOp.getRegion().begin());
+
+    SmallVector<Value> newYields;
+    for (unsigned i = 0; i < numIterArgs; i++)
+      if (!removed.contains(i))
+        newYields.push_back(yield->getOperand(i));
+
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(yield, newYields);
+
+    // Block argument 0 is the induction variable, so iter_arg i is at i + 1.
+    llvm::BitVector toDelete(numIterArgs + 1);
+    for (unsigned i = 0; i < numIterArgs; i++)
+      if (removed.contains(i))
+        toDelete[i + 1] = true;
+    newForOp.getBody()->eraseArguments(toDelete);
+
+    unsigned curNewRes = 0;
+    for (unsigned i = 0; i < numIterArgs; i++) {
+      auto result = forOp->getResult(i);
+      if (removed.contains(i)) {
+        if (result.use_empty())
+          continue;
+
+        // Trip count is ceildiv(ub - lb, step), clamped at zero for loops
+        // whose upper bound does not exceed the lower bound.
+        rewriter.setInsertionPointAfter(forOp.getOperation());
+        Value range = rewriter.create<arith::SubIOp>(
+            loc, forOp.getUpperBound(), forOp.getLowerBound());
+        Value tripCount =
+            rewriter.create<arith::CeilDivSIOp>(loc, range, forOp.getStep());
+        Value zero = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getZeroAttr(ivType));
+        tripCount = rewriter.create<arith::MaxSIOp>(loc, tripCount, zero);
+
+        Value afterLoop =
+            rewriter.create<arith::MulIOp>(loc, tripCount, steps[i]);
+        afterLoop =
+            rewriter.create<arith::AddIOp>(loc, afterLoop, forOp.getInits()[i]);
+
+        rewriter.replaceAllUsesWith(result, afterLoop);
+      } else {
+        rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++));
+      }
+    }
+
+    rewriter.eraseOp(forOp);
+
+    return success();
+  }
+};
+
+void mlir::transform::addRemoveIV(RewritePatternSet &patterns) {
+  patterns.insert<RemoveIVs>(patterns.getContext());
+}
+
 namespace {
 struct EnzymeHLOOptPass
     : public enzyme::impl::EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
index cfbc1ed94..4e1e46366 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
@@ -32,4 +32,5 @@ void addSelectOpCanon(RewritePatternSet &patterns, int64_t maxConstantExpansion,
 void addConcatenateOpCanon(RewritePatternSet &patterns,
                            int64_t maxConstantExpansion, MLIRContext &context,
                            PatternBenefit benefit);
+void addRemoveIV(RewritePatternSet &patterns);
 } // namespace mlir::transform
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.cpp b/src/enzyme_ad/jax/TransformOps/TransformOps.cpp
index 57385d5be..6d1a11301 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.cpp
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.cpp
@@ -49,6 +49,9 @@ void SelectOpCanonPatterns::populatePatterns(RewritePatternSet &patterns) {
   addSelectOpCanon(patterns, getParameter(), *getContext(),
                    PatternBenefit(getBenefit().value_or(1)));
 }
+void ApplyRemoveIVPatterns::populatePatterns(RewritePatternSet &patterns) {
+  addRemoveIV(patterns);
+}
 
 } // namespace transform
 } // namespace mlir
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
index fe84cb719..74259afea 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -923,3 +923,9 @@ def ApplyPadDotGeneralPatterns : EnzymeHLOParameterizedPatternOp<
     }
   }];
 }
+
+def ApplyRemoveIVPatterns : EnzymeHLOParameterizedPatternOp<
+    "remove_iv"> {
+  let arguments = (ins);
+  let assemblyFormat = "attr-dict";
+}
diff --git a/test/lit_tests/raising/llvm_to_affine_access.mlir b/test/lit_tests/raising/llvm_to_affine_access.mlir
new file mode 100644
index 000000000..8523bf65a
--- /dev/null
+++ b/test/lit_tests/raising/llvm_to_affine_access.mlir
@@ -0,0 +1,30 @@
+// RUN: enzymexlamlir-opt %s -split-input-file -allow-unregistered-dialect --transform-interpreter | FileCheck %s
+
+module {
+  func.func @test_remove_ivs(%arg0: index, %arg1: index, %arg2: index) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %result = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%iter_arg = %arg0) -> (index) {
+      %next = arith.addi %iter_arg, %arg2 : index
+      scf.yield %next : index
+    }
+    return %result : index
+  }
+
+  builtin.module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg2: !transform.any_op) {
+      %4 = transform.structured.match ops{["func.func"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+      transform.apply_patterns to %4 {
+        transform.apply_patterns.enzyme_hlo.remove_iv
+      } : !transform.any_op
+      transform.yield
+    }
+  }
+}
+
+// CHECK-LABEL: func @test_remove_ivs(
+// CHECK-SAME: %[[START:.*]]: index, %[[BOUND:.*]]: index, %[[STEP:.*]]: index
+// CHECK: %[[TRIP:.*]] = arith.maxsi %[[BOUND]], %{{.*}}
+// CHECK: %[[MUL:.*]] = arith.muli %[[TRIP]], %[[STEP]]
+// CHECK: %[[RET:.*]] = arith.addi %[[MUL]], %[[START]]
+// CHECK: return %[[RET]]