Skip to content

Commit

Permalink
Limit max expansion in parameter to pass
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 12, 2024
1 parent 98f5dbd commit 6340d61
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 16 deletions.
51 changes: 35 additions & 16 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1492,13 +1492,18 @@ struct PowSimplify : public OpRewritePattern<mlir::stablehlo::PowOp> {

struct IotaSimplify : public OpRewritePattern<mlir::stablehlo::IotaOp> {
using OpRewritePattern<mlir::stablehlo::IotaOp>::OpRewritePattern;

size_t max_constant_expansion;
IotaSimplify(size_t max_constant_expansion, MLIRContext *context,
PatternBenefit benefit = 1,
ArrayRef<StringRef> generatedNames = {})
: OpRewritePattern(context, benefit, generatedNames),
max_constant_expansion(max_constant_expansion) {}
LogicalResult matchAndRewrite(mlir::stablehlo::IotaOp op,
PatternRewriter &rewriter) const final {
size_t size = 1;
for (auto sz : op.getType().getShape())
size *= sz;
if (size >= 100000)
if (size >= max_constant_expansion)
return failure();

auto out = mlir::stablehlo::evalIotaOp(op.getIotaDimension(), op.getType());
Expand Down Expand Up @@ -1567,6 +1572,13 @@ struct BroadcastInDimSimplify
: public OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
using OpRewritePattern<mlir::stablehlo::BroadcastInDimOp>::OpRewritePattern;

size_t max_constant_expansion;
BroadcastInDimSimplify(size_t max_constant_expansion, MLIRContext *context,
PatternBenefit benefit = 1,
ArrayRef<StringRef> generatedNames = {})
: OpRewritePattern(context, benefit, generatedNames),
max_constant_expansion(max_constant_expansion) {}

LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr inp;
Expand All @@ -1576,6 +1588,11 @@ struct BroadcastInDimSimplify
if (inp.isSplat()) {
out = inp.resizeSplat(op.getType());
} else {
size_t size = 1;
for (auto sz : op.getType().getShape())
size *= sz;
if (size >= max_constant_expansion)
return failure();
auto ten = mlir::stablehlo::evalConstantOp(inp);
out = fromTensor(mlir::stablehlo::evalBroadcastInDimOp(
ten, mlir::stablehlo::Axes(op.getBroadcastDimensions()),
Expand Down Expand Up @@ -1818,20 +1835,22 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
void runOnOperation() override {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns.add<
ConvertConcat, DynamicSliceToStatic, DynamicUpdateSliceElim,
DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad, SliceSlice,
AddPad, PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
ConcatPushBinop<stablehlo::AddOp>, ConcatPushBinop<stablehlo::MulOp>,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape, IotaSimplify,
ConvertSimplify, ReshapeSimplify, BroadcastInDimSimplify, SliceSimplify,
ReduceConcat, SliceConcat, NoopSlice, CosSimplify, SinSimplify,
SqrtSimplify, AddSimplify, SubSimplify, AndSimplify, MaxSimplify,
MinSimplify, OrSimplify, NegateSimplify, MulSimplify, DivSimplify,
PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
patterns.add<ConvertConcat, DynamicSliceToStatic, DynamicUpdateSliceElim,
DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad,
SliceSlice, AddPad, PadSimplify, DotReshapeDot,
ConcatConstProp, ConcatFuse, ConcatPushBinop<stablehlo::AddOp>,
ConcatPushBinop<stablehlo::MulOp>,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ConvertSimplify, ReshapeSimplify, SliceSimplify, ReduceConcat,
SliceConcat, NoopSlice, CosSimplify, SinSimplify, SqrtSimplify,
AddSimplify, SubSimplify, AndSimplify, MaxSimplify,
MinSimplify, OrSimplify, NegateSimplify, MulSimplify,
DivSimplify, PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
patterns.add<IotaSimplify, BroadcastInDimSimplify>(max_constant_expansion,
context);
if (all_finite)
patterns.add<AllFinite>(context);
if (no_nan || all_finite)
Expand Down
7 changes: 7 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ def EnzymeHLOOptPass : Pass<"enzyme-hlo-opt"> {
/*type=*/"bool",
/*default=*/"false",
/*description=*/"Whether to raise to assume no variables are nan"
>,
Option<
/*C++ variable name=*/"max_constant_expansion",
/*CLI argument=*/"max_constant_expansion",
/*type=*/"size_t",
/*default=*/"1024",
/*description=*/"Maximum size to expand constants into"
>
];
}
Expand Down

0 comments on commit 6340d61

Please sign in to comment.