Skip to content

Commit

Permalink
simplify mul-of-pad (#57)
Browse files Browse the repository at this point in the history
* simplify mul-of-pad

* dot_general(pad(?, 0))

* fixup

---------

Co-authored-by: William S. Moses <[email protected]>
  • Loading branch information
ftynse and wsmoses authored Mar 15, 2024
1 parent 5db59e2 commit a0cff21
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 1 deletion.
130 changes: 129 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2292,6 +2292,133 @@ struct BroadcastReduce : public OpRewritePattern<mlir::stablehlo::ReduceOp> {
}
};

/// Matches a binary `op` where one operand is produced by a stablehlo.pad
/// whose padding value is a floating-point zero and that has no other users.
/// Tries the LHS first and falls back to the RHS when the LHS candidate is
/// absent or fails the checks (the original code gave up if the LHS was a pad
/// that failed a check, even when the RHS pad was suitable). On success,
/// `pad` is the matched pad op, `otherArg` the non-pad operand, and
/// `isOtherArgLHS` is true iff `otherArg` is the LHS of `op`.
template <typename OpTy>
static LogicalResult getDefiningZeroPadding(OpTy op, PatternRewriter &rewriter,
                                            stablehlo::PadOp &pad,
                                            Value &otherArg,
                                            bool &isOtherArgLHS) {
  // A pad is foldable when the rewrite may consume it (it has a single user)
  // and its padding value is zero, so padded entries contribute nothing to a
  // multiplication or contraction.
  auto isFoldable = [](stablehlo::PadOp candidate) {
    return candidate && llvm::hasSingleElement(candidate->getUsers()) &&
           matchPattern(candidate.getPaddingValue(), m_AnyZeroFloat());
  };

  if (auto lhsPad = op.getLhs().template getDefiningOp<stablehlo::PadOp>();
      isFoldable(lhsPad)) {
    pad = lhsPad;
    otherArg = op.getRhs();
    isOtherArgLHS = false;
    return success();
  }
  if (auto rhsPad = op.getRhs().template getDefiningOp<stablehlo::PadOp>();
      isFoldable(rhsPad)) {
    pad = rhsPad;
    otherArg = op.getLhs();
    isOtherArgLHS = true;
    return success();
  }
  return rewriter.notifyMatchFailure(
      op, "neither operand is a single-use zero-padding stablehlo.pad");
}

/// Rewrites mul(pad(x, 0), y) as pad(mul(x, slice(y)), 0). Multiplying by a
/// zero-padded region yields zeros there anyway, so the multiplication can be
/// performed on the smaller unpadded operand and the result re-padded.
struct PadMultiply : public OpRewritePattern<mlir::stablehlo::MulOp> {
  using OpRewritePattern<mlir::stablehlo::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
                                PatternRewriter &rewriter) const final {
    stablehlo::PadOp pad;
    Value otherArg;
    bool otherIsLHS;
    if (failed(getDefiningZeroPadding(op, rewriter, pad, otherArg, otherIsLHS)))
      return failure();

    // Slice the non-padded operand so it lines up with the *unpadded* values
    // of the pad operand: start at the low edge padding, stop before the high
    // edge padding, and step over interior padding.
    auto otherArgType = otherArg.getType().cast<TensorType>();
    SmallVector<int64_t> limitDims = llvm::to_vector(otherArgType.getShape());
    for (auto &&[limit, highPad] :
         llvm::zip(limitDims, pad.getEdgePaddingHigh())) {
      limit -= highPad;
    }
    // An interior padding of k corresponds to a slice stride of k + 1.
    SmallVector<int64_t> strides = llvm::to_vector(pad.getInteriorPadding());
    for (int64_t &stride : strides) {
      stride += 1;
    }

    auto slice = rewriter.create<stablehlo::SliceOp>(
        pad.getLoc(), otherArg, pad.getEdgePaddingLow(), limitDims, strides);
    // Preserve the original operand order of the multiplication.
    auto mul = rewriter.create<stablehlo::MulOp>(
        op.getLoc(), otherIsLHS ? slice.getResult() : pad.getOperand(),
        otherIsLHS ? pad.getOperand() : slice.getResult());
    // Re-pad the smaller product with the same padding spec; padded entries
    // are zero either way since 0 * y == 0 matches the zero padding value.
    auto newPad = rewriter.create<stablehlo::PadOp>(
        pad.getLoc(), mul.getResult(), pad.getPaddingValue(),
        pad.getEdgePaddingLowAttr(), pad.getEdgePaddingHighAttr(),
        pad.getInteriorPaddingAttr());
    rewriter.replaceOp(op, newPad);
    // The matched pad had a single user (op), which is now gone.
    rewriter.eraseOp(pad);

    return success();
  }
};

/// Rewrites dot_general(pad(x, 0), y) (or the symmetric case) by removing the
/// pad and slicing the other operand along the paired contracting dimensions
/// instead: zero padding along a contracting dimension only adds zero terms
/// to the contraction sum.
struct PadDotGeneral : public OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
  using OpRewritePattern<mlir::stablehlo::DotGeneralOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op,
                                PatternRewriter &rewriter) const final {
    stablehlo::PadOp pad;
    Value otherArg;
    bool otherIsLHS;
    if (failed(getDefiningZeroPadding(op, rewriter, pad, otherArg, otherIsLHS)))
      return failure();

    auto dimensionNumbers = op.getDotDimensionNumbers();
    auto padContractingDimensions =
        dimensionNumbers.getLhsContractingDimensions();
    auto otherContractingDimensions =
        dimensionNumbers.getRhsContractingDimensions();
    if (otherIsLHS)
      std::swap(padContractingDimensions, otherContractingDimensions);

    // The rewrite below substitutes the *unpadded* pad operand into the
    // dot_general. That is only shape-correct if padding is confined to
    // contracting dimensions: padding a batching or free dimension would
    // change the result type and produce invalid IR.
    for (size_t dim = 0, rank = pad.getEdgePaddingLow().size(); dim < rank;
         ++dim) {
      if (pad.getEdgePaddingLow()[dim] == 0 &&
          pad.getEdgePaddingHigh()[dim] == 0 &&
          pad.getInteriorPadding()[dim] == 0)
        continue;
      if (!llvm::is_contained(padContractingDimensions,
                              static_cast<int64_t>(dim)))
        return rewriter.notifyMatchFailure(
            op, "pad affects a non-contracting dimension");
    }

    // Need to figure out which dimension(s) to slice. For this purpose,
    // look at the pairs of contracting dimensions.
    SmallVector<std::tuple<int64_t, int64_t, int64_t, int64_t>>
        otherDimsToSlice;
    for (auto &&[padDim, otherDim] :
         llvm::zip(padContractingDimensions, otherContractingDimensions)) {
      // If padding along the dim, mark the corresponding other dim for
      // slicing.
      int64_t low = pad.getEdgePaddingLow()[padDim];
      int64_t high = pad.getEdgePaddingHigh()[padDim];
      int64_t interior = pad.getInteriorPadding()[padDim];
      if (low == 0 && high == 0 && interior == 0)
        continue;
      otherDimsToSlice.emplace_back(otherDim, low, high, interior);
    }

    if (otherDimsToSlice.empty()) {
      return rewriter.notifyMatchFailure(op,
                                         "contracting dimensions not padded");
    }

    // Build the slice of the other operand so its contracting extents line up
    // with the unpadded values of the pad operand (stride = interior + 1).
    SmallVector<int64_t> sliceLow, sliceHigh, sliceStride;
    for (auto &&[pos, size] :
         llvm::enumerate(otherArg.getType().cast<TensorType>().getShape())) {
      auto it = llvm::find_if(
          otherDimsToSlice, [&](auto &tup) { return std::get<0>(tup) == pos; });
      if (it == otherDimsToSlice.end()) {
        // Dimension not paired with a padded contracting dim: keep it whole.
        sliceLow.push_back(0);
        sliceHigh.push_back(size);
        sliceStride.push_back(1);
        continue;
      }

      sliceLow.push_back(std::get<1>(*it));
      sliceHigh.push_back(size - std::get<2>(*it));
      sliceStride.push_back(std::get<3>(*it) + 1);
    }

    auto slice = rewriter.create<stablehlo::SliceOp>(
        op.getLoc(), otherArg, sliceLow, sliceHigh, sliceStride);
    rewriter.replaceOpWithNewOp<stablehlo::DotGeneralOp>(
        op, op.getResult().getType(),
        otherIsLHS ? slice.getResult() : pad.getOperand(),
        otherIsLHS ? pad.getOperand() : slice.getResult(),
        op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr());
    return success();
  }
};

struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {

void runOnOperation() override {
Expand All @@ -2313,7 +2440,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
TransposeConvert, BroadcastReduce>(context);
TransposeConvert, BroadcastReduce, PadMultiply, PadDotGeneral>(
context);
patterns.add<IotaSimplify, BroadcastInDimSimplify>(max_constant_expansion,
context);
if (all_finite)
Expand Down
14 changes: 14 additions & 0 deletions test/lit_tests/paddotgeneral.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// Exercises the PadDotGeneral pattern: the LHS is zero-padded only along its
// contracting dimension (dim 2, low = 1024), so the pad should be removed and
// the RHS sliced along the paired contracting dimension (dim 4) instead.
// CHECK-LABEL: @pad_dot_general
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3x1024x4xbf16>, %[[ARG1:.+]]: tensor<1x8x3x1024x2048xbf16>)
// CHECK: %[[SLICE:.+]] = stablehlo.slice %[[ARG1]] [0:1, 0:8, 0:3, 0:1024, 1024:2048]
// CHECK-NOT: pad
// CHECK: stablehlo.dot_general %[[ARG0]], %[[SLICE]], batching_dims = [0, 1] x [0, 2], contracting_dims = [2] x [4], precision = [DEFAULT, DEFAULT]
func.func @pad_dot_general(%4 : tensor<1x3x1024x4xbf16>, %6: tensor<1x8x3x1024x2048xbf16>) -> tensor<1x3x4x8x1024xbf16> {
%3 = stablehlo.constant dense<0.000000e+00> : tensor<bf16>
%5 = stablehlo.pad %4, %3, low = [0, 0, 1024, 0], high = [0, 0, 0, 0], interior = [0, 0, 0, 0] : (tensor<1x3x1024x4xbf16>, tensor<bf16>) -> tensor<1x3x2048x4xbf16>
%7 = stablehlo.dot_general %5, %6, batching_dims = [0, 1] x [0, 2], contracting_dims = [2] x [4], precision = [DEFAULT, DEFAULT] : (tensor<1x3x2048x4xbf16>, tensor<1x8x3x1024x2048xbf16>) -> tensor<1x3x4x8x1024xbf16>
return %7 : tensor<1x3x4x8x1024xbf16>
}

12 changes: 12 additions & 0 deletions test/lit_tests/padmultiply.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// Exercises the PadMultiply pattern: multiplying a zero-padded tensor by
// another tensor is rewritten to slice the other operand, multiply the
// smaller tensors, and re-apply the same zero padding to the product.
// CHECK-LABEL: @pad_multiply
// CHECK: %[[SLICE:.+]] = stablehlo.slice %{{.*}} [0:1, 0:3, 1024:2048]
// CHECK: %[[MUL:.+]] = stablehlo.multiply %{{.*}}, %[[SLICE]] : tensor<1x3x1024xf32>
// CHECK: stablehlo.pad %[[MUL]], %{{.*}}, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0]
func.func @pad_multiply(%4: tensor<1x3x1024xf32>, %2: tensor<1x3x2048xf32>) -> tensor<1x3x2048xf32> {
%constant_0 = stablehlo.constant dense<0.0> : tensor<f32>
%5 = stablehlo.pad %4, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
%7 = stablehlo.multiply %5, %2 : tensor<1x3x2048xf32>
return %7 : tensor<1x3x2048xf32>
}

0 comments on commit a0cff21

Please sign in to comment.