From 640a0c94456f03bfd05c0aa41a741496ebbd7924 Mon Sep 17 00:00:00 2001 From: BuildKite Date: Wed, 29 Jan 2025 18:20:07 +0100 Subject: [PATCH] Add RemoveIV pattern --- src/enzyme_ad/jax/BUILD | 64 ++++++- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 + src/enzyme_ad/jax/RegistryUtils.cpp | 2 + .../jax/TransformOps/TestTransformOps.cpp | 167 ++++++++++++++++++ .../jax/TransformOps/TestTransformOps.h | 22 +++ .../jax/TransformOps/TestTransformOps.td | 19 ++ src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp | 38 ++-- test/lit_tests/patterns/remove_ivs.mlir | 29 +++ 8 files changed, 331 insertions(+), 12 deletions(-) create mode 100644 src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp create mode 100644 src/enzyme_ad/jax/TransformOps/TestTransformOps.h create mode 100644 src/enzyme_ad/jax/TransformOps/TestTransformOps.td create mode 100644 test/lit_tests/patterns/remove_ivs.mlir diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 0ae72caef..444f6e81d 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -121,6 +121,63 @@ gentbl_cc_library( tblgen = "//:enzymexlamlir-tblgen", ) +td_library( + name = "TestTransformOpsTdFiles", + srcs = [ + "TransformOps/TestTransformOps.td", + ], + deps = [ + "@llvm-project//mlir:TransformDialectTdFiles", + ] +) + +gentbl_cc_library( + name = "TestTransformOpsIncGen", + tbl_outs = [( + ["-gen-op-decls"], + "TransformOps/TestTransformOps.h.inc", + ), ( + ["-gen-op-defs"], + "TransformOps/TestTransformOps.cpp.inc", + ), + ], + td_file = "TransformOps/TestTransformOps.td", + deps = [ + ":TestTransformOpsTdFiles", + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", +) + +gentbl_cc_library( + name = "TestTransformOpsImplIncGen", + tbl_outs = [( + ["-gen-test-populate-patterns-interface-impl"], + "TransformOps/TestTransformOpsImpl.cpp.inc" + )], + td_file = "TransformOps/TestTransformOps.td", + deps = [ + ":TestTransformOpsTdFiles", + ], + tblgen = "//:enzymexlamlir-tblgen", +) + +gentbl_cc_library( + name = "TestTransformPatternsIncGen", + tbl_outs = [ + ( + ["-gen-test-populate-patterns-func-decls"], + "TransformOps/TestTransformPatterns.h.inc", + ), ( + ["-gen-test-populate-patterns-func-defs"], + "TransformOps/TestTransformPatterns.cpp.inc", + )], + td_file = "TransformOps/TestTransformOps.td", + deps = [ + ":TestTransformOpsTdFiles", + ], + tblgen = "//:enzymexlamlir-tblgen", +) + cc_library( name = "TransformOps", srcs = glob(["TransformOps/*.cpp"]), @@ -132,6 +189,9 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformDialectInterfaces", + ":TestTransformOpsIncGen", + ":TestTransformOpsImplIncGen", + ":TestTransformPatternsIncGen", ":TransformOpsIncGen", ":TransformOpsImplIncGen", ":XLADerivatives", @@ -273,7 +333,7 @@ gentbl_cc_library( ) gentbl_cc_library( - name = "EnzyeHLOPatternsIncGen", + name = "EnzymeHLOPatternsIncGen", tbl_outs = [ ( ["-gen-populate-patterns-func-decls"], @@ -312,7 +372,7 @@ cc_library( deps = [ ":EnzymeXLAOpsIncGen", ":EnzymeXLAPassesIncGen", - ":EnzyeHLOPatternsIncGen", + ":EnzymeHLOPatternsIncGen", "@llvm-project//mlir:DLTIDialect", "@llvm-project//mlir:GPUPipelines", "@llvm-project//llvm:Core", diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 7140012e8..0995aa915 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7557,6 +7557,8 @@ void mlir::transform::addConcatenateOpCanon(RewritePatternSet &patterns, patterns.insert(maxConstantExpansion, &context, benefit); } +/////////////// End stablehlo patterns + namespace { struct EnzymeHLOOptPass : public enzyme::impl::EnzymeHLOOptPassBase { diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 6f809043d..4a7518b29 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -52,6 +52,7 @@ namespace mlir { namespace enzyme { void registerEnzymeJaxTransformExtension(mlir::DialectRegistry ®istry); +void registerTestTransformExtension(mlir::DialectRegistry ®istry); } // namespace enzyme } // namespace mlir @@ -124,6 +125,7 @@ void prepareRegistry(mlir::DialectRegistry ®istry) { mlir::linalg::registerTransformDialectExtension(registry); mlir::enzyme::registerEnzymeJaxTransformExtension(registry); + mlir::enzyme::registerTestTransformExtension(registry); mlir::registerLLVMDialectImport(registry); mlir::registerNVVMDialectImport(registry); diff --git a/src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp b/src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp new file mode 100644 index 000000000..72fecac0a --- /dev/null +++ b/src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp @@ -0,0 +1,167 @@ +//===- TestTransformOps.cpp - Definition of test transform extension ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/enzyme_ad/jax/TransformOps/TestTransformOps.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp.inc" +#include "src/enzyme_ad/jax/TransformOps/TestTransformPatterns.h.inc" +#include "src/enzyme_ad/jax/TransformOps/TestTransformOpsImpl.cpp.inc" + +using namespace mlir; + +namespace mlir { +namespace transform { + +struct RemoveIVs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + if (!forOp.getRegion().hasOneBlock()) + return failure(); + unsigned numIterArgs = forOp.getNumRegionIterArgs(); + auto loc = forOp->getLoc(); + bool changed = false; + llvm::SetVector removed; + llvm::MapVector steps; + auto yield = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < numIterArgs; i++) { + auto ba = forOp.getRegionIterArgs()[i]; + auto init = forOp.getInits()[i]; + auto next = yield->getOperand(i); + + auto increment = next.getDefiningOp(); + if (!increment) + continue; + + Value step = nullptr; + if (increment.getLhs() == ba) { + step = increment.getRhs(); + } else { + step = increment.getLhs(); + } + if (!step) + continue; + + // If it dominates the loop entry + if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion())) + continue; + + rewriter.setInsertionPointToStart(forOp.getBody()); + Value iterNum = rewriter.create( + loc, forOp.getInductionVar(), forOp.getLowerBound()); + iterNum = rewriter.create(loc, iterNum, forOp.getStep()); + + Value replacementIV = rewriter.create(loc, iterNum, step); + replacementIV = rewriter.create(loc, replacementIV, init); + + rewriter.replaceAllUsesWith(ba, replacementIV); + + removed.insert(i); + steps.insert({i, step}); + changed = true; + } + + if (!changed) + return failure(); + + SmallVector newInits; + for (unsigned i = 0; i < numIterArgs; i++) + if (!removed.contains(i)) + newInits.push_back(forOp.getInits()[i]); + + rewriter.setInsertionPoint(forOp); + auto newForOp = rewriter.create(loc, forOp.getLowerBound(), + forOp.getUpperBound(), + forOp.getStep(), newInits); + if (!newForOp.getRegion().empty()) + newForOp.getRegion().front().erase(); + assert(newForOp.getRegion().empty()); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().begin()); + + SmallVector newYields; + for (unsigned i = 0; i < numIterArgs; i++) + if (!removed.contains(i)) + newYields.push_back(yield->getOperand(i)); + + rewriter.setInsertionPoint(yield); + rewriter.replaceOpWithNewOp(yield, newYields); + + llvm::BitVector toDelete(numIterArgs + 1); + for (unsigned i = 0; i < numIterArgs; i++) + if (removed.contains(i)) + toDelete[i + 1] = true; + newForOp.getBody()->eraseArguments(toDelete); + + rewriter.setInsertionPoint(newForOp); + unsigned curNewRes = 0; + for (unsigned i = 0; i < numIterArgs; i++) { + auto result = forOp->getResult(i); + if (removed.contains(i)) { + if (result.use_empty()) + continue; + + rewriter.setInsertionPointAfter(forOp.getOperation()); + Value iterNum = rewriter.create( + loc, forOp.getUpperBound(), forOp.getLowerBound()); + iterNum = + rewriter.create(loc, iterNum, forOp.getStep()); + + Value afterLoop = + rewriter.create(loc, iterNum, steps[i]); + afterLoop = + rewriter.create(loc, afterLoop, forOp.getInits()[i]); + + rewriter.replaceAllUsesWith(result, afterLoop); + } else { + rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++)); + } + } + + forOp->getParentOp()->dump(); + rewriter.eraseOp(forOp); + + return success(); + } +}; + +} // namespace transform +} // namespace mlir + +#include "src/enzyme_ad/jax/TransformOps/TestTransformPatterns.cpp.inc" + +namespace { +class TestTransformExtension + : public transform::TransformDialectExtension { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformExtension) + using Base::Base; + + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +void mlir::enzyme::registerTestTransformExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/src/enzyme_ad/jax/TransformOps/TestTransformOps.h b/src/enzyme_ad/jax/TransformOps/TestTransformOps.h new file mode 100644 index 000000000..6d01359a4 --- /dev/null +++ b/src/enzyme_ad/jax/TransformOps/TestTransformOps.h @@ -0,0 +1,22 @@ +//===- TestTransformOps.h - Declarations of Transform extension -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "src/enzyme_ad/jax/TransformOps/OpInterfaces.h.inc" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/TransformOps/TestTransformOps.h.inc" + +namespace mlir { +namespace enzyme { +void registerTestTransformExtension(mlir::DialectRegistry ®istry); + +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/TransformOps/TestTransformOps.td b/src/enzyme_ad/jax/TransformOps/TestTransformOps.td new file mode 100644 index 000000000..c935b6ab7 --- /dev/null +++ b/src/enzyme_ad/jax/TransformOps/TestTransformOps.td @@ -0,0 +1,19 @@ +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" + +class TestPatternOp traits = []> + : Op], + traits)> { + let arguments = (ins OptionalAttr:$benefit); + list patterns = []; + let assemblyFormat = "attr-dict"; +} + +def ApplyRemoveIVsPatterns : TestPatternOp< + "remove_ivs"> { + let patterns = ["RemoveIVs"]; +} diff --git a/src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp b/src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp index b58ee08a5..1e6e8b04d 100644 --- a/src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp +++ b/src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp @@ -18,6 +18,9 @@ enum ActionType { GenPopulatePatternsFuncDecl, GenPopulatePatternsFuncDef, GenPopulatePatternsInterfaceImpl, + GenTestPopulatePatternsFuncDecl, + GenTestPopulatePatternsFuncDef, + GenTestPopulatePatternsInterfaceImpl, }; static llvm::cl::opt action( @@ -27,7 +30,13 @@ static llvm::cl::opt action( llvm::cl::values(clEnumValN(GenPopulatePatternsFuncDef, "gen-populate-patterns-func-defs", "")), llvm::cl::values(clEnumValN(GenPopulatePatternsInterfaceImpl, - "gen-populate-patterns-interface-impl", ""))); + "gen-populate-patterns-interface-impl", "")), + llvm::cl::values(clEnumValN(GenTestPopulatePatternsFuncDecl, + "gen-test-populate-patterns-func-decls", "")), + llvm::cl::values(clEnumValN(GenTestPopulatePatternsFuncDef, + "gen-test-populate-patterns-func-defs", "")), + llvm::cl::values(clEnumValN(GenTestPopulatePatternsInterfaceImpl, + "gen-test-populate-patterns-interface-impl", ""))); llvm::StringRef getPopulateFunctionNameSuffix(const llvm::Record *rec) { return rec->getName().ends_with("Op") ? rec->getName().drop_back(2) @@ -35,9 +44,10 @@ llvm::StringRef getPopulateFunctionNameSuffix(const llvm::Record *rec) { } static bool emitPopulatePatterns(llvm::raw_ostream &os, - const llvm::RecordKeeper &records) { + const llvm::RecordKeeper &records, + llvm::StringRef patternOpStr) { for (const llvm::Record *rec : - records.getAllDerivedDefinitions("EnzymeHLOPatternOp")) { + records.getAllDerivedDefinitions(patternOpStr)) { os << "void "; llvm::StringRef ns = rec->getValueAsString("cppNamespace"); if (!ns.empty()) @@ -54,9 +64,10 @@ static bool emitPopulatePatterns(llvm::raw_ostream &os, } static bool emitPopulatePatternsFuncDecls(llvm::raw_ostream &os, - const llvm::RecordKeeper &records) { + const llvm::RecordKeeper &records, + llvm::StringRef patternOpStr) { for (const llvm::Record *rec : - records.getAllDerivedDefinitions("EnzymeHLOPatternOp")) { + records.getAllDerivedDefinitions(patternOpStr)) { llvm::StringRef ns = rec->getValueAsString("cppNamespace"); if (ns.starts_with("::")) ns = ns.drop_front(2); @@ -70,9 +81,10 @@ static bool emitPopulatePatternsFuncDecls(llvm::raw_ostream &os, } static bool emitPopulatePatternsFuncDefs(llvm::raw_ostream &os, - const llvm::RecordKeeper &records) { + const llvm::RecordKeeper &records, + llvm::StringRef patternOpStr) { for (const llvm::Record *rec : - records.getAllDerivedDefinitions("EnzymeHLOPatternOp")) { + records.getAllDerivedDefinitions(patternOpStr)) { os << "void "; llvm::StringRef ns = rec->getValueAsString("cppNamespace"); if (!ns.empty()) @@ -94,11 +106,17 @@ static bool tablegenMain(llvm::raw_ostream &os, const llvm::RecordKeeper &records) { switch (action) { case GenPopulatePatternsFuncDecl: - return emitPopulatePatternsFuncDecls(os, records); + return emitPopulatePatternsFuncDecls(os, records, "EnzymeHLOPatternOp"); case GenPopulatePatternsFuncDef: - return emitPopulatePatternsFuncDefs(os, records); + return emitPopulatePatternsFuncDefs(os, records, "EnzymeHLOPatternOp"); case GenPopulatePatternsInterfaceImpl: - return emitPopulatePatterns(os, records); + return emitPopulatePatterns(os, records, "EnzymeHLOPatternOp"); + case GenTestPopulatePatternsFuncDecl: + return emitPopulatePatternsFuncDecls(os, records, "TestPatternOp"); + case GenTestPopulatePatternsFuncDef: + return emitPopulatePatternsFuncDefs(os, records, "TestPatternOp"); + case GenTestPopulatePatternsInterfaceImpl: + return emitPopulatePatterns(os, records, "TestPatternOp"); default: llvm::report_fatal_error("unknown action"); return true; diff --git a/test/lit_tests/patterns/remove_ivs.mlir b/test/lit_tests/patterns/remove_ivs.mlir new file mode 100644 index 000000000..594a82eec --- /dev/null +++ b/test/lit_tests/patterns/remove_ivs.mlir @@ -0,0 +1,29 @@ +// RUN: enzymexlamlir-opt %s -split-input-file -allow-unregistered-dialect --transform-interpreter | FileCheck %s + +module { + func.func @test_remove_ivs(%arg0: index, %arg1: index, %arg2: index) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %result = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%iter_arg = %arg0) -> (index) { + %next = arith.addi %iter_arg, %arg2 : index + scf.yield %next : index + } + return %result : index + } + + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg2: !transform.any_op) { + %4 = transform.structured.match ops{["func.func"]} in %arg2 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %4 { + transform.apply_patterns.test.remove_ivs + } : !transform.any_op + transform.yield + } + } +} + +// CHECK-LABEL: func @test_remove_ivs( +// CHECK-SAME: %[[START:.*]]: index, %[[BOUND:.*]]: index, %[[STEP:.*]]: index +// CHECK: %[[MUL:.*]] = arith.muli %[[BOUND]], %[[STEP]] +// CHECK: %[[RET:.*]] = arith.addi %[[MUL]], %[[START]] +// CHECK: return %[[RET]]