diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 0660bb0ed..fec32a207 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -2292,6 +2292,133 @@ struct BroadcastReduce : public OpRewritePattern<stablehlo::ReduceOp> {
   }
 };
 
+// If exactly one operand of `op` is produced by a single-use stablehlo.pad
+// with a zero padding value, returns that pad, the other operand, and
+// whether the other operand is the LHS.
+template <typename OpTy>
+static LogicalResult getDefiningZeroPadding(OpTy op, PatternRewriter &rewriter,
+                                            stablehlo::PadOp &pad,
+                                            Value &otherArg,
+                                            bool &isOtherArgLHS) {
+  pad = op.getLhs().template getDefiningOp<stablehlo::PadOp>();
+  otherArg = op.getRhs();
+  isOtherArgLHS = false;
+  if (!pad) {
+    pad = op.getRhs().template getDefiningOp<stablehlo::PadOp>();
+    otherArg = op.getLhs();
+    isOtherArgLHS = true;
+  }
+  if (!pad)
+    return rewriter.notifyMatchFailure(op, "operands not produced by pad");
+  if (!llvm::hasSingleElement(pad->getUsers()))
+    return rewriter.notifyMatchFailure(op, "pad has multiple users");
+
+  if (!matchPattern(pad.getPaddingValue(), m_AnyZeroFloat()))
+    return rewriter.notifyMatchFailure(op, "padding value not zero");
+  return success();
+}
+
+// Rewrites mul(pad(x, 0), y) as pad(mul(x, slice(y)), 0) so that the
+// multiplication happens on the smaller, unpadded shape.
+struct PadMultiply : public OpRewritePattern<mlir::stablehlo::MulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
+                                PatternRewriter &rewriter) const final {
+    stablehlo::PadOp pad;
+    Value otherArg;
+    bool otherIsLHS;
+    if (failed(getDefiningZeroPadding(op, rewriter, pad, otherArg, otherIsLHS)))
+      return failure();
+
+    auto otherArgType = otherArg.getType().cast<RankedTensorType>();
+    SmallVector<int64_t> limitDims = llvm::to_vector(otherArgType.getShape());
+    for (auto &&[limit, pad] : llvm::zip(limitDims, pad.getEdgePaddingHigh())) {
+      limit -= pad;
+    }
+    SmallVector<int64_t> interior = llvm::to_vector(pad.getInteriorPadding());
+    for (int64_t &value : interior) {
+      value += 1;
+    }
+
+    auto slice = rewriter.create<stablehlo::SliceOp>(
+        pad.getLoc(), otherArg, pad.getEdgePaddingLow(), limitDims, interior);
+    auto mul = rewriter.create<stablehlo::MulOp>(
+        op.getLoc(), otherIsLHS ? slice.getResult() : pad.getOperand(),
+        otherIsLHS ? pad.getOperand() : slice.getResult());
+    auto newPad = rewriter.create<stablehlo::PadOp>(
+        pad.getLoc(), mul.getResult(), pad.getPaddingValue(),
+        pad.getEdgePaddingLowAttr(), pad.getEdgePaddingHighAttr(),
+        pad.getInteriorPaddingAttr());
+    rewriter.replaceOp(op, newPad);
+    rewriter.eraseOp(pad);
+
+    return success();
+  }
+};
+
+// Rewrites dot_general(pad(x, 0), y) as dot_general(x, slice(y)) when only
+// contracting dimensions are padded: contracting against zero padding
+// contributes nothing to the result.
+struct PadDotGeneral : public OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const final {
+    stablehlo::PadOp pad;
+    Value otherArg;
+    bool otherIsLHS;
+    if (failed(getDefiningZeroPadding(op, rewriter, pad, otherArg, otherIsLHS)))
+      return failure();
+
+    auto dimensionNumbers = op.getDotDimensionNumbers();
+    auto padContractingDimensions =
+        dimensionNumbers.getLhsContractingDimensions();
+    auto otherContractingDimensions =
+        dimensionNumbers.getRhsContractingDimensions();
+    if (otherIsLHS)
+      std::swap(padContractingDimensions, otherContractingDimensions);
+
+    // Bail out if a non-contracting dimension is padded: dropping that
+    // padding would change the result shape of the dot_general.
+    for (size_t i = 0, e = pad.getEdgePaddingLow().size(); i < e; ++i) {
+      if (llvm::is_contained(padContractingDimensions,
+                             static_cast<int64_t>(i)))
+        continue;
+      if (pad.getEdgePaddingLow()[i] != 0 ||
+          pad.getEdgePaddingHigh()[i] != 0 || pad.getInteriorPadding()[i] != 0)
+        return rewriter.notifyMatchFailure(op,
+                                           "non-contracting dimension padded");
+    }
+
+    // Need to figure out which dimension(s) to slice. For this purpose, look
+    // at the pairs of contracting dimensions.
+    SmallVector<std::tuple<int64_t, int64_t, int64_t, int64_t>>
+        otherDimsToSlice;
+    for (auto &&[padDim, otherDim] :
+         llvm::zip(padContractingDimensions, otherContractingDimensions)) {
+      // If padding along the dim, mark the corresponding other dim for
+      // slicing.
+      int64_t low = pad.getEdgePaddingLow()[padDim];
+      int64_t high = pad.getEdgePaddingHigh()[padDim];
+      int64_t interior = pad.getInteriorPadding()[padDim];
+      if (low == 0 && high == 0 && interior == 0)
+        continue;
+      otherDimsToSlice.emplace_back(otherDim, low, high, interior);
+    }
+
+    if (otherDimsToSlice.empty()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "contracting dimensions not padded");
+    }
+
+    // Slice off the part of the other operand that would have been
+    // contracted against the padding.
+    SmallVector<int64_t> sliceLow, sliceHigh, sliceStride;
+    for (auto &&[pos, size] : llvm::enumerate(
+             otherArg.getType().cast<RankedTensorType>().getShape())) {
+      auto it = llvm::find_if(otherDimsToSlice, [&](auto &tup) {
+        return std::get<0>(tup) == pos;
+      });
+      if (it == otherDimsToSlice.end()) {
+        sliceLow.push_back(0);
+        sliceHigh.push_back(size);
+        sliceStride.push_back(1);
+        continue;
+      }
+
+      sliceLow.push_back(std::get<1>(*it));
+      sliceHigh.push_back(size - std::get<2>(*it));
+      sliceStride.push_back(std::get<3>(*it) + 1);
+    }
+
+    auto slice = rewriter.create<stablehlo::SliceOp>(
+        op.getLoc(), otherArg, sliceLow, sliceHigh, sliceStride);
+    rewriter.replaceOpWithNewOp<stablehlo::DotGeneralOp>(
+        op, op.getResult().getType(),
+        otherIsLHS ? slice.getResult() : pad.getOperand(),
+        otherIsLHS ? pad.getOperand() : slice.getResult(),
+        op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr());
+    return success();
+  }
+};
+
 struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
 
   void runOnOperation() override {
@@ -2313,7 +2440,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
                  BinBroadcastSplat,
                  BinBroadcastSplat,
                  BinBroadcastSplat, TransposeTranspose,
-                 TransposeConvert, BroadcastReduce>(context);
+                 TransposeConvert, BroadcastReduce, PadMultiply, PadDotGeneral>(
+        context);
     patterns.add(max_constant_expansion, context);
 
     if (all_finite)
diff --git a/test/lit_tests/paddotgeneral.mlir b/test/lit_tests/paddotgeneral.mlir
new file mode 100644
index 000000000..e3f613d43
--- /dev/null
+++ b/test/lit_tests/paddotgeneral.mlir
@@ -0,0 +1,14 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+
+// CHECK-LABEL: @pad_dot_general
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3x1024x4xbf16>, %[[ARG1:.+]]: tensor<1x8x3x1024x2048xbf16>)
+// CHECK: %[[SLICE:.+]] = stablehlo.slice %[[ARG1]] [0:1, 0:8, 0:3, 0:1024, 1024:2048]
+// CHECK-NOT: pad
+// CHECK: stablehlo.dot_general %[[ARG0]], %[[SLICE]], batching_dims = [0, 1] x [0, 2], contracting_dims = [2] x [4], precision = [DEFAULT, DEFAULT]
+func.func @pad_dot_general(%4 : tensor<1x3x1024x4xbf16>, %6: tensor<1x8x3x1024x2048xbf16>) -> tensor<1x3x4x8x1024xbf16> {
+  %3 = stablehlo.constant dense<0.000000e+00> : tensor<bf16>
+  %5 = stablehlo.pad %4, %3, low = [0, 0, 1024, 0], high = [0, 0, 0, 0], interior = [0, 0, 0, 0] : (tensor<1x3x1024x4xbf16>, tensor<bf16>) -> tensor<1x3x2048x4xbf16>
+  %7 = stablehlo.dot_general %5, %6, batching_dims = [0, 1] x [0, 2], contracting_dims = [2] x [4], precision = [DEFAULT, DEFAULT] : (tensor<1x3x2048x4xbf16>, tensor<1x8x3x1024x2048xbf16>) -> tensor<1x3x4x8x1024xbf16>
+  return %7 : tensor<1x3x4x8x1024xbf16>
+}
+
diff --git a/test/lit_tests/padmultiply.mlir b/test/lit_tests/padmultiply.mlir
new file mode 100644
index 000000000..8aad1d0a8
--- /dev/null
+++ b/test/lit_tests/padmultiply.mlir
@@ -0,0 +1,12 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+
+// CHECK-LABEL: @pad_multiply
+// CHECK: %[[SLICE:.+]] = stablehlo.slice %{{.*}} [0:1, 0:3, 1024:2048]
+// CHECK: %[[MUL:.+]] = stablehlo.multiply %{{.*}}, %[[SLICE]] : tensor<1x3x1024xf32>
+// CHECK: stablehlo.pad %[[MUL]], %{{.*}}, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0]
+func.func @pad_multiply(%4: tensor<1x3x1024xf32>, %2: tensor<1x3x2048xf32>) -> tensor<1x3x2048xf32> {
+  %constant_0 = stablehlo.constant dense<0.0> : tensor<f32>
+  %5 = stablehlo.pad %4, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
+  %7 = stablehlo.multiply %5, %2 : tensor<1x3x2048xf32>
+  return %7 : tensor<1x3x2048xf32>
+}