Skip to content

Commit

Permalink
Further simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 12, 2024
1 parent 4a4245a commit 2bce99e
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 31 deletions.
237 changes: 206 additions & 31 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ template <typename T> Attribute makeAttr(mlir::Type elemType, T val) {

namespace {

struct SliceSimplification final : OpRewritePattern<mlir::stablehlo::SliceOp> {
struct NoopSlice final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::SliceOp op,
Expand Down Expand Up @@ -500,6 +500,29 @@ struct ReduceToReshape final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
}
};

struct ConvertConcat final : OpRewritePattern<mlir::stablehlo::ConvertOp> {
  using OpRewritePattern::OpRewritePattern;

  // Distribute a convert over concatenate:
  //   convert(concat(a, b, ...)) -> concat(convert(a), convert(b), ...)
  // Each operand keeps its shape; only the element type changes to the
  // convert's result element type.
  LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op,
                                PatternRewriter &rewriter) const override {
    auto concat = op.getOperand().getDefiningOp<stablehlo::ConcatenateOp>();
    if (!concat)
      return failure();

    auto elemTy = op.getType().getElementType();
    SmallVector<Value> converted;
    for (auto operand : concat.getOperands()) {
      auto shape = operand.getType().cast<RankedTensorType>().getShape();
      converted.push_back(rewriter.create<stablehlo::ConvertOp>(
          op.getLoc(), RankedTensorType::get(shape, elemTy), operand));
    }
    rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
        op, converted, concat.getDimension());
    return success();
  }
};

struct ReduceConcat final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -910,24 +933,6 @@ struct BroadcastToReshape final
return failure();
assert(op.getBroadcastDimensions().size() ==
op.getOperand().getType().getShape().size());
DenseElementsAttr inp;
matchPattern(op->getOperand(0), m_Constant(&inp));
if (inp) {
if (inp.isSplat()) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(),
mlir::SplatElementsAttr::get(op.getType(),
inp.getSplatValue<mlir::Attribute>()));
return success();
}
auto inp0 = mlir::stablehlo::evalConstantOp(inp);
auto out = mlir::stablehlo::evalBroadcastInDimOp(
inp0, mlir::stablehlo::Axes(op.getBroadcastDimensions()),
op.getType());
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
fromTensor(out));
return success();
}

// Ensure these are sorted
for (auto en : llvm::enumerate(op.getBroadcastDimensions())) {
Expand Down Expand Up @@ -1426,6 +1431,175 @@ struct PowSimplify : public OpRewritePattern<mlir::stablehlo::PowOp> {
}
};

struct IotaSimplify : public OpRewritePattern<mlir::stablehlo::IotaOp> {
  using OpRewritePattern<mlir::stablehlo::IotaOp>::OpRewritePattern;

  // Constant-fold small iota ops by materializing their values.
  // Folding is capped so we never build a huge dense constant.
  LogicalResult matchAndRewrite(mlir::stablehlo::IotaOp op,
                                PatternRewriter &rewriter) const final {
    // Maximum number of elements we are willing to materialize.
    constexpr size_t kMaxElems = 100000;
    size_t size = 1;
    for (auto sz : op.getType().getShape()) {
      if (sz == 0) {
        // A zero-sized dimension makes the whole tensor empty.
        size = 0;
        break;
      }
      // Fix: the original `size *= sz` could wrap size_t for very large
      // (or dynamic-sentinel) dimensions, sneaking under the cap. Use a
      // division-based check so the product can never overflow.
      if (size >= kMaxElems || static_cast<size_t>(sz) > kMaxElems / size)
        return failure();
      size *= static_cast<size_t>(sz);
    }
    if (size >= kMaxElems)
      return failure();

    auto out = mlir::stablehlo::evalIotaOp(op.getIotaDimension(), op.getType());
    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
                                                       fromTensor(out));
    return success();
  }
};

struct ConvertSimplify : public OpRewritePattern<mlir::stablehlo::ConvertOp> {
  using OpRewritePattern<mlir::stablehlo::ConvertOp>::OpRewritePattern;

  // Constant-fold convert(constant). A splat input is handled by
  // converting a single rank-0 scalar and re-splatting, which avoids
  // materializing every element of the tensor.
  LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op,
                                PatternRewriter &rewriter) const final {
    DenseElementsAttr input;
    matchPattern(op->getOperand(0), m_Constant(&input));
    if (!input)
      return failure();

    const bool splat = input.isSplat();
    stablehlo::Tensor tensor;
    RankedTensorType evalType = op.getType();
    if (splat) {
      auto scalarIn =
          RankedTensorType::get({}, input.getType().getElementType());
      tensor = stablehlo::makeTensor(input.resizeSplat(scalarIn));
      evalType = RankedTensorType::get({}, op.getType().getElementType());
    } else {
      tensor = mlir::stablehlo::evalConstantOp(input);
    }
    auto folded = fromTensor(mlir::stablehlo::evalConvertOp(tensor, evalType));
    if (splat)
      folded = folded.resizeSplat(op.getType());

    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
                                                       folded);
    return success();
  }
};

struct SliceSimplify : public OpRewritePattern<mlir::stablehlo::SliceOp> {
  using OpRewritePattern<mlir::stablehlo::SliceOp>::OpRewritePattern;

  // Constant-fold slice(constant). Any slice of a splat is the same
  // splat at the result type, so that case needs no evaluation.
  LogicalResult matchAndRewrite(mlir::stablehlo::SliceOp op,
                                PatternRewriter &rewriter) const final {
    DenseElementsAttr input;
    matchPattern(op->getOperand(0), m_Constant(&input));
    if (!input)
      return failure();

    DenseElementsAttr folded;
    if (input.isSplat()) {
      folded = input.resizeSplat(op.getType());
    } else {
      auto tensor = mlir::stablehlo::evalConstantOp(input);
      folded = fromTensor(mlir::stablehlo::evalSliceOp(
          tensor, stablehlo::Sizes(op.getStartIndices()),
          stablehlo::Sizes(op.getStrides()), op.getType()));
    }

    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
                                                       folded);
    return success();
  }
};

struct BroadcastInDimSimplify
    : public OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
  using OpRewritePattern<mlir::stablehlo::BroadcastInDimOp>::OpRewritePattern;

  // Constant-fold broadcast_in_dim(constant). Broadcasting a splat is
  // just the same splat at the broadcast result type.
  LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op,
                                PatternRewriter &rewriter) const final {
    DenseElementsAttr input;
    matchPattern(op->getOperand(0), m_Constant(&input));
    if (!input)
      return failure();

    DenseElementsAttr folded;
    if (input.isSplat()) {
      folded = input.resizeSplat(op.getType());
    } else {
      auto tensor = mlir::stablehlo::evalConstantOp(input);
      folded = fromTensor(mlir::stablehlo::evalBroadcastInDimOp(
          tensor, mlir::stablehlo::Axes(op.getBroadcastDimensions()),
          op.getType()));
    }

    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
                                                       folded);
    return success();
  }
};

struct ReshapeSimplify : public OpRewritePattern<mlir::stablehlo::ReshapeOp> {
  using OpRewritePattern<mlir::stablehlo::ReshapeOp>::OpRewritePattern;

  // Constant-fold reshape(constant) by retyping the attribute to the
  // reshaped result type.
  LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op,
                                PatternRewriter &rewriter) const final {
    DenseElementsAttr inp;
    matchPattern(op->getOperand(0), m_Constant(&inp));
    if (inp) {
      // Fix: the splat must be resized to the *result* type. The original
      // used inp.resizeSplat(inp.getType()), a no-op that left the
      // constant's attribute typed with the operand's shape, mismatching
      // the op's reshaped result type.
      rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
          op, op.getType(),
          inp.isSplat() ? inp.resizeSplat(op.getType())
                        : inp.reshape(op.getType()));
      return success();
    }

    return failure();
  }
};

struct MaxSimplify : public OpRewritePattern<mlir::stablehlo::MaxOp> {
  using OpRewritePattern<mlir::stablehlo::MaxOp>::OpRewritePattern;

  // max(x, x) -> x, and constant-fold max of float constants.
  LogicalResult matchAndRewrite(mlir::stablehlo::MaxOp op,
                                PatternRewriter &rewriter) const final {
    if (op.getOperand(0) == op.getOperand(1)) {
      rewriter.replaceOp(op, op.getOperand(0));
      return success();
    }
    SmallVector<Attribute> constants;
    constants.assign(op->getNumOperands(), Attribute());
    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
      matchPattern(op->getOperand(i), m_Constant(&constants[i]));
    // NOTE(review): (a > b) ? a : b yields b whenever a is NaN; confirm
    // this matches the intended stablehlo.maximum NaN semantics.
    if (auto res =
            constFoldBinaryOpConditional<FloatAttr, FloatAttr::ValueType, void>(
                constants,
                [](const APFloat &a, const APFloat &b)
                    -> std::optional<APFloat> { return (a > b) ? a : b; })) {
      rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
          op, op.getType(), res.cast<ElementsAttr>());
      return success();
    }
    // Fix: the original flowed off the end of this non-void function when
    // folding failed, which is undefined behavior.
    return failure();
  }
};

struct MinSimplify : public OpRewritePattern<mlir::stablehlo::MinOp> {
  using OpRewritePattern<mlir::stablehlo::MinOp>::OpRewritePattern;

  // min(x, x) -> x, and constant-fold min of float constants.
  LogicalResult matchAndRewrite(mlir::stablehlo::MinOp op,
                                PatternRewriter &rewriter) const final {
    if (op.getOperand(0) == op.getOperand(1)) {
      rewriter.replaceOp(op, op.getOperand(0));
      return success();
    }
    SmallVector<Attribute> constants;
    constants.assign(op->getNumOperands(), Attribute());
    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
      matchPattern(op->getOperand(i), m_Constant(&constants[i]));
    // NOTE(review): (a < b) ? a : b yields b whenever a is NaN; confirm
    // this matches the intended stablehlo.minimum NaN semantics.
    if (auto res =
            constFoldBinaryOpConditional<FloatAttr, FloatAttr::ValueType, void>(
                constants,
                [](const APFloat &a, const APFloat &b)
                    -> std::optional<APFloat> { return (a < b) ? a : b; })) {
      rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
          op, op.getType(), res.cast<ElementsAttr>());
      return success();
    }
    // Fix: the original flowed off the end of this non-void function when
    // folding failed, which is undefined behavior.
    return failure();
  }
};

struct CosSimplify : public OpRewritePattern<mlir::stablehlo::CosineOp> {
using OpRewritePattern<mlir::stablehlo::CosineOp>::OpRewritePattern;

Expand Down Expand Up @@ -1583,18 +1757,19 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
void runOnOperation() override {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns
.add<DynamicSliceToStatic, DynamicUpdateSliceElim,
DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad, SliceSlice,
AddPad, PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
SinSimplify, SqrtSimplify, AddSimplify, SubSimplify, AndSimplify,
OrSimplify, NegateSimplify, MulSimplify, DivSimplify, PowSimplify,
BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
patterns.add<
ConvertConcat, DynamicSliceToStatic, DynamicUpdateSliceElim,
DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad, SliceSlice,
AddPad, PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape, IotaSimplify,
ConvertSimplify, ReshapeSimplify, BroadcastInDimSimplify, SliceSimplify,
ReduceConcat, SliceConcat, NoopSlice, CosSimplify, SinSimplify,
SqrtSimplify, AddSimplify, SubSimplify, AndSimplify, MaxSimplify,
MinSimplify, OrSimplify, NegateSimplify, MulSimplify, DivSimplify,
PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
if (all_finite)
patterns.add<AllFinite>(context);
if (no_nan || all_finite)
Expand Down
19 changes: 19 additions & 0 deletions test/lit_tests/convertconcat.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// Verifies the ConvertConcat pattern: a convert applied to the result of a
// concatenate is rewritten so each concatenated operand is converted
// individually, then re-concatenated at the converted element type.
module {

  func.func @main(%a : tensor<2xf32>, %b : tensor<1xf32>, %c : tensor<1xf32>) -> tensor<4xbf16> {
    %concat = stablehlo.concatenate %a, %b, %c, dim=0 : (tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4xf32>
    %conv = stablehlo.convert %concat : (tensor<4xf32>) -> tensor<4xbf16>
    return %conv : tensor<4xbf16>

  }
}

// CHECK: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<4xbf16> {
// CHECK-NEXT: %0 = stablehlo.convert %arg0 : (tensor<2xf32>) -> tensor<2xbf16>
// CHECK-NEXT: %1 = stablehlo.convert %arg1 : (tensor<1xf32>) -> tensor<1xbf16>
// CHECK-NEXT: %2 = stablehlo.convert %arg2 : (tensor<1xf32>) -> tensor<1xbf16>
// CHECK-NEXT: %3 = stablehlo.concatenate %0, %1, %2, dim = 0 : (tensor<2xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<4xbf16>
// CHECK-NEXT: return %3 : tensor<4xbf16>
// CHECK-NEXT: }

0 comments on commit 2bce99e

Please sign in to comment.