Skip to content

Commit

Permalink
Add pad folding optimization (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Mar 10, 2024
1 parent c339731 commit ae03e32
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
48 changes: 45 additions & 3 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,27 @@ struct DotReshapeDot final : OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
}
};

struct PadSimplify final : OpRewritePattern<mlir::stablehlo::PadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::PadOp op,
PatternRewriter &rewriter) const override {

for (auto &&[low, high, inner] :
llvm::zip(op.getEdgePaddingLow(), op.getEdgePaddingHigh(),
op.getInteriorPadding())) {
if (low != 0)
return failure();
if (high != 0)
return failure();
if (inner != 0)
return failure();
}
rewriter.replaceOp(op, op.getOperand());
return success();
}
};

/*
%1192 = stablehlo.pad %1189, %cst_0, low = [0], high = [1], interior = [0] :
Expand Down Expand Up @@ -925,15 +946,29 @@ struct MulSimplify : public OpRewritePattern<mlir::stablehlo::MulOp> {
LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
PatternRewriter &rewriter) const final {

// 0 * x -> x
if (matchPattern(op.getLhs(), m_AnyZeroFloat())) {
rewriter.replaceOp(op, op.getLhs());
return success();
}
if (matchPattern(op.getLhs(), m_AnyZeroFloat())) {
// x * 0 -> x
if (matchPattern(op.getRhs(), m_AnyZeroFloat())) {
rewriter.replaceOp(op, op.getRhs());
return success();
}

// 1 * x -> x
if (matchPattern(op.getLhs(), m_OneFloat())) {
rewriter.replaceOp(op, op.getRhs());
return success();
}

// x * 1 -> x
if (matchPattern(op.getRhs(), m_OneFloat())) {
rewriter.replaceOp(op, op.getLhs());
return success();
}

SmallVector<Attribute> constants;
constants.assign(op->getNumOperands(), Attribute());
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
Expand Down Expand Up @@ -963,11 +998,18 @@ struct DivSimplify : public OpRewritePattern<mlir::stablehlo::DivOp> {
LogicalResult matchAndRewrite(mlir::stablehlo::DivOp op,
PatternRewriter &rewriter) const final {

// 0 / x -> 0 [assume non nan here]
if (matchPattern(op.getLhs(), m_AnyZeroFloat())) {
rewriter.replaceOp(op, op.getLhs());
return success();
}

// x / 1 -> x
if (matchPattern(op.getRhs(), m_OneFloat())) {
rewriter.replaceOp(op, op.getLhs());
return success();
}

SmallVector<Attribute> constants;
constants.assign(op->getNumOperands(), Attribute());
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
Expand Down Expand Up @@ -1147,8 +1189,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
void runOnOperation() override {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns.add<SlicePad, SliceSlice, AddPad, DotReshapeDot, ConcatConstProp,
ConcatFuse,
patterns.add<SlicePad, SliceSlice, AddPad, PadSimplify, DotReshapeDot,
ConcatConstProp, ConcatFuse,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
SinSimplify, SqrtSimplify, AddSimplify, SubSimplify,
Expand Down
14 changes: 14 additions & 0 deletions test/lit_tests/foldpad.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {

func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> {
%pv = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%pad = stablehlo.pad %a, %pv, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
return %pad : tensor<2x2xf32>
}
}

// CHECK: func.func @main(%arg0: tensor<2x2xf32>
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
// CHECK-NEXT: }

0 comments on commit ae03e32

Please sign in to comment.