Skip to content

Commit

Permalink
Add RemoveIV pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
BuildKite committed Jan 31, 2025
1 parent d89468e commit 640a0c9
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 12 deletions.
64 changes: 62 additions & 2 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,63 @@ gentbl_cc_library(
tblgen = "//:enzymexlamlir-tblgen",
)

td_library(
name = "TestTransformOpsTdFiles",
srcs = [
"TransformOps/TestTransformOps.td",
],
deps = [
"@llvm-project//mlir:TransformDialectTdFiles",
]
)

gentbl_cc_library(
name = "TestTransformOpsIncGen",
tbl_outs = [(
["-gen-op-decls"],
"TransformOps/TestTransformOps.h.inc",
), (
["-gen-op-defs"],
"TransformOps/TestTransformOps.cpp.inc",
),
],
td_file = "TransformOps/TestTransformOps.td",
deps = [
":TestTransformOpsTdFiles",
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
)

gentbl_cc_library(
name = "TestTransformOpsImplIncGen",
tbl_outs = [(
["-gen-test-populate-patterns-interface-impl"],
"TransformOps/TestTransformOpsImpl.cpp.inc"
)],
td_file = "TransformOps/TestTransformOps.td",
deps = [
":TestTransformOpsTdFiles",
],
tblgen = "//:enzymexlamlir-tblgen",
)

gentbl_cc_library(
name = "TestTransformPatternsIncGen",
tbl_outs = [
(
["-gen-test-populate-patterns-func-decls"],
"TransformOps/TestTransformPatterns.h.inc",
), (
["-gen-test-populate-patterns-func-defs"],
"TransformOps/TestTransformPatterns.cpp.inc",
)],
td_file = "TransformOps/TestTransformOps.td",
deps = [
":TestTransformOpsTdFiles",
],
tblgen = "//:enzymexlamlir-tblgen",
)

cc_library(
name = "TransformOps",
srcs = glob(["TransformOps/*.cpp"]),
Expand All @@ -132,6 +189,9 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialectInterfaces",
":TestTransformOpsIncGen",
":TestTransformOpsImplIncGen",
":TestTransformPatternsIncGen",
":TransformOpsIncGen",
":TransformOpsImplIncGen",
":XLADerivatives",
Expand Down Expand Up @@ -273,7 +333,7 @@ gentbl_cc_library(
)

gentbl_cc_library(
name = "EnzyeHLOPatternsIncGen",
name = "EnzymeHLOPatternsIncGen",
tbl_outs = [
(
["-gen-populate-patterns-func-decls"],
Expand Down Expand Up @@ -312,7 +372,7 @@ cc_library(
deps = [
":EnzymeXLAOpsIncGen",
":EnzymeXLAPassesIncGen",
":EnzyeHLOPatternsIncGen",
":EnzymeHLOPatternsIncGen",
"@llvm-project//mlir:DLTIDialect",
"@llvm-project//mlir:GPUPipelines",
"@llvm-project//llvm:Core",
Expand Down
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7557,6 +7557,8 @@ void mlir::transform::addConcatenateOpCanon(RewritePatternSet &patterns,
patterns.insert<ConcatenateOpCanon>(maxConstantExpansion, &context, benefit);
}

/////////////// End stablehlo patterns

namespace {
struct EnzymeHLOOptPass
: public enzyme::impl::EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
Expand Down
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/RegistryUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
namespace mlir {
namespace enzyme {
void registerEnzymeJaxTransformExtension(mlir::DialectRegistry &registry);
void registerTestTransformExtension(mlir::DialectRegistry &registry);
} // namespace enzyme
} // namespace mlir

Expand Down Expand Up @@ -124,6 +125,7 @@ void prepareRegistry(mlir::DialectRegistry &registry) {
mlir::linalg::registerTransformDialectExtension(registry);

mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
mlir::enzyme::registerTestTransformExtension(registry);

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);
Expand Down
167 changes: 167 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
//===- TestTransformOps.cpp - Definition of test transform extension ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/enzyme_ad/jax/TransformOps/TestTransformOps.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"

#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp.inc"
#include "src/enzyme_ad/jax/TransformOps/TestTransformPatterns.h.inc"
#include "src/enzyme_ad/jax/TransformOps/TestTransformOpsImpl.cpp.inc"

using namespace mlir;

namespace mlir {
namespace transform {

struct RemoveIVs : public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
if (!forOp.getRegion().hasOneBlock())
return failure();
unsigned numIterArgs = forOp.getNumRegionIterArgs();
auto loc = forOp->getLoc();
bool changed = false;
llvm::SetVector<unsigned> removed;
llvm::MapVector<unsigned, Value> steps;
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (unsigned i = 0; i < numIterArgs; i++) {
auto ba = forOp.getRegionIterArgs()[i];
auto init = forOp.getInits()[i];
auto next = yield->getOperand(i);

auto increment = next.getDefiningOp<arith::AddIOp>();
if (!increment)
continue;

Value step = nullptr;
if (increment.getLhs() == ba) {
step = increment.getRhs();
} else {
step = increment.getLhs();
}
if (!step)
continue;

// If it dominates the loop entry
if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion()))
continue;

rewriter.setInsertionPointToStart(forOp.getBody());
Value iterNum = rewriter.create<arith::SubIOp>(
loc, forOp.getInductionVar(), forOp.getLowerBound());
iterNum = rewriter.create<arith::DivSIOp>(loc, iterNum, forOp.getStep());

Value replacementIV = rewriter.create<arith::MulIOp>(loc, iterNum, step);
replacementIV = rewriter.create<arith::AddIOp>(loc, replacementIV, init);

rewriter.replaceAllUsesWith(ba, replacementIV);

removed.insert(i);
steps.insert({i, step});
changed = true;
}

if (!changed)
return failure();

SmallVector<Value> newInits;
for (unsigned i = 0; i < numIterArgs; i++)
if (!removed.contains(i))
newInits.push_back(forOp.getInits()[i]);

rewriter.setInsertionPoint(forOp);
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(), newInits);
if (!newForOp.getRegion().empty())
newForOp.getRegion().front().erase();
assert(newForOp.getRegion().empty());
rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(),
newForOp.getRegion().begin());

SmallVector<Value> newYields;
for (unsigned i = 0; i < numIterArgs; i++)
if (!removed.contains(i))
newYields.push_back(yield->getOperand(i));

rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<scf::YieldOp>(yield, newYields);

llvm::BitVector toDelete(numIterArgs + 1);
for (unsigned i = 0; i < numIterArgs; i++)
if (removed.contains(i))
toDelete[i + 1] = true;
newForOp.getBody()->eraseArguments(toDelete);

rewriter.setInsertionPoint(newForOp);
unsigned curNewRes = 0;
for (unsigned i = 0; i < numIterArgs; i++) {
auto result = forOp->getResult(i);
if (removed.contains(i)) {
if (result.use_empty())
continue;

rewriter.setInsertionPointAfter(forOp.getOperation());
Value iterNum = rewriter.create<arith::SubIOp>(
loc, forOp.getUpperBound(), forOp.getLowerBound());
iterNum =
rewriter.create<arith::DivSIOp>(loc, iterNum, forOp.getStep());

Value afterLoop =
rewriter.create<arith::MulIOp>(loc, iterNum, steps[i]);
afterLoop =
rewriter.create<arith::AddIOp>(loc, afterLoop, forOp.getInits()[i]);

rewriter.replaceAllUsesWith(result, afterLoop);
} else {
rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++));
}
}

forOp->getParentOp()->dump();
rewriter.eraseOp(forOp);

return success();
}
};

} // namespace transform
} // namespace mlir

#include "src/enzyme_ad/jax/TransformOps/TestTransformPatterns.cpp.inc"

namespace {
class TestTransformExtension
: public transform::TransformDialectExtension<TestTransformExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformExtension)
using Base::Base;

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "src/enzyme_ad/jax/TransformOps/TestTransformOps.cpp.inc"
>();
}
};
} // namespace

void mlir::enzyme::registerTestTransformExtension(
DialectRegistry &registry) {
registry.addExtensions<TestTransformExtension>();
}
22 changes: 22 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TestTransformOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- TestTransformOps.h - Declarations of Transform extension -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "src/enzyme_ad/jax/TransformOps/OpInterfaces.h.inc"

#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/TransformOps/TestTransformOps.h.inc"

namespace mlir {
namespace enzyme {
void registerTestTransformExtension(mlir::DialectRegistry &registry);

} // namespace enzyme
} // namespace mlir
19 changes: 19 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TestTransformOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"

class TestPatternOp<string mnemonic, list<Trait> traits = []>
: Op<Transform_Dialect,
"apply_patterns.test." # mnemonic,
// For some reason, inherited methods are not getting declared...
!listconcat(
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>],
traits)> {
let arguments = (ins OptionalAttr<I64Attr>:$benefit);
list<string> patterns = [];
let assemblyFormat = "attr-dict";
}

def ApplyRemoveIVsPatterns : TestPatternOp<
"remove_ivs"> {
let patterns = ["RemoveIVs"];
}
Loading

0 comments on commit 640a0c9

Please sign in to comment.