Handle full reduce of reshape
wsmoses committed Mar 15, 2024
1 parent b3674f8 commit 0a2d2ec
Showing 2 changed files with 135 additions and 16 deletions.
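The new `FullReduceReshape` pattern matches a `stablehlo.reduce` that spans every dimension of an input produced by elementwise ops over identically shaped reshapes; since a full reduction is insensitive to shape, the elementwise computation and the reduction are re-emitted directly on the pre-reshape operands, as the lit test below checks.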
124 changes: 108 additions & 16 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -574,6 +574,98 @@ struct ReduceConcat final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
}
};

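// If a reduce spans every dimension of its input, any reshape feeding it
// (possibly through elementwise ops) cannot affect the result: rewrite the
// elementwise computation and the full reduction to act directly on the
// reshapes' common operand shape.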
struct FullReduceReshape final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
PatternRewriter &rewriter) const override {
if (op.getInputs().size() != 1)
return failure();

auto inpTy = op.getInputs()[0].getType().cast<RankedTensorType>();
if (op.getDimensions().size() != inpTy.getShape().size())
return failure();

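// Walk backward from the reduce input. Leaves must all be reshapes from the
// same operand shape; every intermediate op must be elementwise and free of
// memory effects. Record how many operands each intermediate op waits on
// before it can be cloned.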
SmallVector<stablehlo::ReshapeOp> reshapes;
DenseMap<Operation *, int> toclone;
{
SmallVector<Value> todo = {op.getInputs()[0]};
while (!todo.empty()) {
auto cur = todo.pop_back_val();
auto curOp = cur.getDefiningOp();
if (!curOp)
return failure();
if (auto rs = dyn_cast<stablehlo::ReshapeOp>(curOp)) {
reshapes.push_back(rs);
if (reshapes[0].getOperand().getType().getShape() !=
rs.getOperand().getType().getShape())
return failure();
continue;
}
if (!curOp->hasTrait<mlir::OpTrait::Elementwise>())
return failure();
if (!isMemoryEffectFree(curOp))
return failure();
for (auto operand : curOp->getOperands())
todo.push_back(operand);
toclone[curOp] = curOp->getNumOperands();
}
}

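// Clone the producers on the pre-reshape values: seed the mapping with each
// reshape replaced by its operand, then clone an op once all of its operands
// have been remapped (a ready-list topological traversal).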
IRMapping map;
SmallVector<Operation *> todo;
for (auto reshape : reshapes) {
map.map(reshape, reshape.getOperand());
for (auto u : reshape.getResult().getUsers()) {
if (toclone.contains(u)) {
toclone[u]--;
if (toclone[u] == 0) {
todo.push_back(u);
toclone.erase(u);
}
}
}
}
while (!todo.empty()) {
auto cur = todo.pop_back_val();

SmallVector<Value> vals;
for (auto operand : cur->getOperands())
vals.push_back(map.lookup(operand));

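// Re-create the op generically, swapping in the pre-reshape tensor type as
// the result type; elementwise ops are shape-polymorphic, so this is valid.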
auto res =
rewriter.create(cur->getLoc(), cur->getName().getIdentifier(), vals,
TypeRange(reshapes[0].getOperand().getType()),
cur->getAttrs(), {}, {});

map.map(cur->getResult(0), res->getResult(0));

for (auto u : cur->getResult(0).getUsers()) {
if (toclone.contains(u)) {
toclone[u]--;
if (toclone[u] == 0) {
todo.push_back(u);
toclone.erase(u);
}
}
}
}

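// The cloned ops now yield the reshape operand's shape, so the full
// reduction must instead span every dimension of that shape.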
SmallVector<int64_t> newReduceDimensions;
for (size_t i = 0,
end = reshapes[0].getOperand().getType().getShape().size();
i < end; i++)
newReduceDimensions.push_back(i);

auto newReduction = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), op->getResultTypes(), map.lookup(op.getInputs()[0]),
op.getInitValues(), newReduceDimensions);
newReduction.getRegion().takeBody(op.getRegion());
rewriter.replaceOp(op, newReduction);
return success();
}
};

struct SliceConcat final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;

@@ -2194,22 +2286,22 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
void runOnOperation() override {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns.add<ConcatToPad, ConcatAppendingReshape, ReshapeIota, ReshapePad,
ConvertConcat, DynamicSliceToStatic, DynamicUpdateSliceElim,
DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad,
SliceSlice, AddPad, PadSimplify, DotReshapeDot,
ConcatConstProp, ConcatFuse, ConcatPushBinop<stablehlo::AddOp>,
ConcatPushBinop<stablehlo::MulOp>,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ConvertSimplify, ReshapeSimplify, SliceSimplify, ReduceConcat,
SliceConcat, NoopSlice, CosSimplify, SinSimplify, SqrtSimplify,
AddSimplify, SubSimplify, AndSimplify, MaxSimplify,
MinSimplify, OrSimplify, NegateSimplify, MulSimplify,
DivSimplify, PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
TransposeConvert, BroadcastReduce>(context);
patterns.add<
FullReduceReshape, ConcatToPad, ConcatAppendingReshape, ReshapeIota,
ReshapePad, ConvertConcat, DynamicSliceToStatic, DynamicUpdateSliceElim,
DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad, SliceSlice,
AddPad, PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
ConcatPushBinop<stablehlo::AddOp>, ConcatPushBinop<stablehlo::MulOp>,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape, ConvertSimplify,
ReshapeSimplify, SliceSimplify, ReduceConcat, SliceConcat, NoopSlice,
CosSimplify, SinSimplify, SqrtSimplify, AddSimplify, SubSimplify,
AndSimplify, MaxSimplify, MinSimplify, OrSimplify, NegateSimplify,
MulSimplify, DivSimplify, PowSimplify,
BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
TransposeConvert, BroadcastReduce>(context);
patterns.add<IotaSimplify, BroadcastInDimSimplify>(max_constant_expansion,
context);
if (all_finite)
27 changes: 27 additions & 0 deletions test/lit_tests/reducereshape.mlir
@@ -0,0 +1,27 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
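// A full reduce of elementwise ops over reshaped operands should be rewritten
// to compute and reduce directly on the original 1-D values.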

module {

func.func @main(%a : tensor<20xf32>, %b : tensor<20xf32>) -> tensor<f32> {
%c0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%ar = stablehlo.reshape %a : (tensor<20xf32>) -> tensor<4x5xf32>
%br = stablehlo.reshape %b : (tensor<20xf32>) -> tensor<4x5xf32>

%ma = stablehlo.add %ar, %br : tensor<4x5xf32>
%mb = stablehlo.multiply %ma, %ma : tensor<4x5xf32>


%1308 = stablehlo.reduce(%mb init: %c0) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x5xf32>, tensor<f32>) -> tensor<f32>

return %1308 : tensor<f32>

}
}

// CHECK: func.func @main(%arg0: tensor<20xf32>, %arg1: tensor<20xf32>) -> tensor<f32> {
// CHECK-NEXT: %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %1 = stablehlo.add %arg0, %arg1 : tensor<20xf32>
// CHECK-NEXT: %2 = stablehlo.multiply %1, %1 : tensor<20xf32>
// CHECK-NEXT: %3 = stablehlo.reduce(%2 init: %0) applies stablehlo.add across dimensions = [0] : (tensor<20xf32>, tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return %3 : tensor<f32>
// CHECK-NEXT: }
