Reduction optimization (#47)
* Reduction optimization

* Fuse concat fixes
wsmoses authored Mar 10, 2024
1 parent e652346 commit 5b4f4ec
Showing 5 changed files with 199 additions and 6 deletions.
169 changes: 165 additions & 4 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -165,6 +165,129 @@ struct SlicePad final : OpRewritePattern<mlir::stablehlo::SliceOp> {
}
};

// From
// https://github.com/openxla/stablehlo/blob/5d1a9c892500c2e9fecbfedfa66ffe84ff1caf7b/stablehlo/dialect/StablehloOps.cpp#L1498C1-L1532C1
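// Returns true iff all of the op's operands and results share a single type.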
bool hasSameOperandAndResultTypes(Operation &op) {
Type expected;
if (op.getNumResults() != 0)
expected = op.getResult(0).getType();
if (op.getNumOperands() != 0)
expected = op.getOperand(0).getType();
if (!expected)
return false;

auto typeMatch = [&](Type actual) { return actual == expected; };
return llvm::all_of(op.getOperandTypes(), typeMatch) &&
llvm::all_of(op.getResultTypes(), typeMatch);
}

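// Returns true if the reduce body has the compact form that stablehlo prints
// as `applies <op> across dimensions = [...]`. Summarizing the checks below
// (adapted from the upstream helper linked above):
//   E1: the body contains exactly one op besides the terminator;
//   E2: that op belongs to the same dialect and is a binary, single-result,
//       commutative, region-free op with matching operand/result types;
//   E3: those types are the rank-0 tensor of the input element type;
//   E4: its operands are exactly the block arguments, in order;
//   E5: the terminator returns exactly that op's results.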
static bool isEligibleForCompactPrint(stablehlo::ReduceOp op) {
// Check E1.
auto &block = op.getBody().front();
if (!hasSingleElement(block.without_terminator()))
return false;

Operation &innerOp = *block.begin();

// Check E2.
if (innerOp.getDialect() != op->getDialect())
return false;

if (innerOp.getNumOperands() != 2 ||
!innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
!hasSameOperandAndResultTypes(innerOp) ||
!innerOp.hasTrait<mlir::hlo::OpTrait::IsCommutative>() ||
!innerOp.hasTrait<mlir::OpTrait::ZeroRegions>())
return false;

// Check E3.
if (op.getInputs().empty())
return false;

auto elemType =
op.getInputs()[0].getType().cast<ShapedType>().getElementType();
auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
if (innerOp.getOperands()[0].getType() != expectedInnerOpType)
return false;

// Check E4.
if (!llvm::equal(block.getArguments(), innerOp.getOperands()))
return false;

// Check E5.
auto retOp = dyn_cast<stablehlo::ReturnOp>(block.getTerminator());
if (!retOp)
return false;

return llvm::equal(innerOp.getResults(), retOp.getOperands());
}

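// Rewrites a total reduction whose reduced dimensions all have extent 1 into
// a reshape followed by a single application of the combiner. Illustrative
// example (not a checked test):
//   %r = stablehlo.reduce(%x init: %c) applies stablehlo.add across
//            dimensions = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<f32>
// becomes
//   %0 = stablehlo.reshape %x : (tensor<1xf32>) -> tensor<f32>
//   %r = stablehlo.add %c, %0 : tensor<f32>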
struct ReduceToReshape final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
PatternRewriter &rewriter) const override {
if (op.getInputs().size() != 1)
return failure();
if (!isEligibleForCompactPrint(op))
return failure();
auto inpTy = op.getInputs()[0].getType().cast<RankedTensorType>();
// The combiner is applied once on rank-0 values, so the reduce must cover
// every dimension (each of extent 1) and produce the rank-0 result type;
// otherwise the reshape below would be invalid.
if (inpTy.getRank() != (int64_t)op.getDimensions().size())
  return failure();
for (auto idx : op.getDimensions()) {
  if (inpTy.getShape()[idx] != 1)
    return failure();
}

auto reshaped = rewriter.create<stablehlo::ReshapeOp>(
op.getLoc(), op.getInitValues()[0].getType(), op.getInputs()[0]);

Operation &innerOp = op.getBody().front().front();

IRMapping map;
map.map(innerOp.getOperand(0), op.getInitValues()[0]);
map.map(innerOp.getOperand(1), reshaped);
auto res = rewriter.clone(innerOp, map)->getResult(0);

rewriter.replaceOp(op, res);
return success();
}
};

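// Splits a reduction over a concatenate along a reduced dimension into a
// chain of reductions over the concatenated pieces, threading each partial
// result through as the next init value. Illustratively:
//   reduce(concat(%a, %b), %init)  ==>  reduce(%b, reduce(%a, %init))
// See test/lit_tests/reduceconcat.mlir for the checked end-to-end form, where
// ReduceToReshape then collapses the reductions of unit-length pieces.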
struct ReduceConcat final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
PatternRewriter &rewriter) const override {
if (op.getInputs().size() != 1)
return failure();

auto concat = op.getInputs()[0].getDefiningOp<stablehlo::ConcatenateOp>();
if (!concat)
return failure();

auto dim = concat.getDimension();

if (!llvm::is_contained(op.getDimensions(), dim))
return failure();

if (!isEligibleForCompactPrint(op))
return failure();

Value prev = op.getInitValues()[0];

for (auto v : concat.getOperands()) {
IRMapping map;
map.map(op.getInitValues()[0], prev);
map.map(op.getInputs()[0], v);
prev = rewriter.clone(*op, map)->getResult(0);
}

rewriter.replaceOp(op, prev);
return success();
}
};

struct SliceConcat final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;

@@ -397,6 +520,42 @@ struct AddPad final : OpRewritePattern<mlir::stablehlo::AddOp> {
}
};

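// Simplifies concatenates: forwards a single-operand concatenate of matching
// type, inlines nested concatenates along the same dimension, and drops
// operands that are empty along the concatenation dimension. Illustratively:
//   concat(concat(%a, %b, dim = 0), %empty, %c, dim = 0)
//     ==>  concat(%a, %b, %c, dim = 0)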
struct ConcatFuse final : OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() == 1 &&
op->getOperand(0).getType() == op.getType()) {
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
SmallVector<Value> vals;
bool changed = false;
for (auto v : op->getOperands()) {
if (auto c2 = v.getDefiningOp<stablehlo::ConcatenateOp>()) {
if (c2.getDimension() == op.getDimension()) {
for (auto v2 : c2->getOperands())
vals.push_back(v2);
changed = true;
continue;
}
}
if (v.getType().cast<RankedTensorType>().getShape()[op.getDimension()] ==
0) {
changed = true;
continue;
}
vals.push_back(v);
}
if (!changed)
return failure();
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, op.getType(),
vals);
return success();
}
};

struct ConcatConstProp final
: OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
@@ -932,10 +1091,12 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns.add<SlicePad, SliceSlice, AddPad, DotReshapeDot, ConcatConstProp,
-              /*ScatterToPad, */ BroadcastToReshape, SliceConcat,
-              SliceSimplification, CosSimplify, SinSimplify, SqrtSimplify,
-              AddSimplify, SubSimplify, NegateSimplify, MulSimplify,
-              DivSimplify, PowSimplify>(context);
+              ConcatFuse,
+              /*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
+              ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
+              SinSimplify, SqrtSimplify, AddSimplify, SubSimplify,
+              NegateSimplify, MulSimplify, DivSimplify, PowSimplify>(
+              context);
mlir::stablehlo::populateStablehloCanonicalizationPatterns(context,
&patterns);

4 changes: 3 additions & 1 deletion test/BUILD
@@ -9,6 +9,7 @@ expand_template(
substitutions = {
"@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.",
"@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"),
"@ENZYMEXLA_BINARY_DIR@": "",
"@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"),
"@ENZYME_SOURCE_DIR@": "",
"@ENZYME_BINARY_DIR@": "",
@@ -30,6 +31,7 @@ exports_files(
":lit.cfg.py",
":lit_site_cfg_py",
"//src/enzyme_ad/jax:enzyme_jax_internal",
"//:enzymexlamlir-opt",
"@llvm-project//clang:builtin_headers_gen",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:count",
@@ -38,7 +40,7 @@
)
for src in glob(
[
"**/*.pyt",
"**/*.pyt", "**/*.mlir",
],
)
]
3 changes: 2 additions & 1 deletion test/lit.cfg.py
@@ -22,7 +22,7 @@
config.test_format = lit.formats.ShTest(execute_external)

# suffixes: A list of file extensions to treat as test files.
-config.suffixes = [".pyt"]
+config.suffixes = [".pyt", ".mlir"]

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
@@ -35,6 +35,7 @@
# Tweak the PATH to include the tools dir and the scripts dir.
base_paths = [
config.llvm_tools_dir,
config.enzymexla_tools_dir,
config.environment["PATH"],
]
path = os.path.pathsep.join(base_paths) # + config.extra_paths)
5 changes: 5 additions & 0 deletions test/lit.site.cfg.py.in
@@ -7,5 +7,10 @@ config.llvm_tools_dir = "@LLVM_TOOLS_BINARY_DIR@"
if len("@ENZYME_BINARY_DIR@") == 0:
    config.llvm_tools_dir = os.getcwd() + "/" + config.llvm_tools_dir

config.enzymexla_tools_dir = "@ENZYMEXLA_BINARY_DIR@"

if len(config.enzymexla_tools_dir) == 0:
    config.enzymexla_tools_dir = os.getcwd()

cfgfile = os.path.dirname(os.path.abspath(__file__)) + "/lit.cfg.py"
lit_config.load_config(config, cfgfile)
24 changes: 24 additions & 0 deletions test/lit_tests/reduceconcat.mlir
@@ -0,0 +1,24 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {

func.func @main(%a : tensor<2xf32>, %b : tensor<1xf32>, %c : tensor<1xf32>) -> tensor<f32> {
%cst0 = arith.constant dense<0.000000e+00> : tensor<f32>
%concat = stablehlo.concatenate %a, %b, %c, dim=0 : (tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4xf32>

%1308 = stablehlo.reduce(%concat init: %cst0) applies stablehlo.add across dimensions = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<f32>

return %1308 : tensor<f32>

}
}

// CHECK: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<f32> {
// CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<2xf32>, tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<1xf32>) -> tensor<f32>
// CHECK-NEXT: %2 = stablehlo.add %0, %1 : tensor<f32>
// CHECK-NEXT: %3 = stablehlo.reshape %arg2 : (tensor<1xf32>) -> tensor<f32>
// CHECK-NEXT: %4 = stablehlo.add %2, %3 : tensor<f32>
// CHECK-NEXT: return %4 : tensor<f32>
// CHECK-NEXT: }
