diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td
index efb743164..274bd4c0c 100644
--- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td
+++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td
@@ -391,6 +391,145 @@ def ConvBatchGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
   op.getBatchGroupCountAttr();
 }]>;
 
+def GradDataFilterReshape1 : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto dimensionNumbers = op.getDimensionNumbers();
+  auto groupCount = op.getFeatureGroupCount() * op.getBatchGroupCount();
+  auto Ty = cast<RankedTensorType>(rhs.getType());
+  auto shape = Ty.getShape();
+
+  auto odim = dimensionNumbers.getKernelOutputFeatureDimension();
+
+  SmallVector<int64_t> newShape;
+  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+    if (i == odim) {
+      newShape.push_back(groupCount);
+      newShape.push_back(shape[i] / groupCount);
+    } else {
+      newShape.push_back(shape[i]);
+    }
+  }
+
+  RankedTensorType::get(newShape, Ty.getElementType());
+}]>;
+
+def GradDataFilterTranspose : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  SmallVector<int64_t> transposes;
+  auto dimensionNumbers = op.getDimensionNumbers();
+  auto idim = dimensionNumbers.getKernelInputFeatureDimension();
+  auto odim = dimensionNumbers.getKernelOutputFeatureDimension();
+
+  if (odim < idim)
+    idim++;
+
+  int64_t i = 0, N = op.getType().getShape().size();
+  while (i <= N) {
+    if (i == idim) {
+      transposes.push_back(odim);
+      transposes.push_back(idim);
+    } else if (i != odim) {
+      transposes.push_back(i);
+    }
+    i++;
+  }
+
+  getI64Attr(builder, transposes);
+}]>;
+
+def GradDataConvOutputType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto Ty = cast<RankedTensorType>(op.getLhs().getType());
+  auto batchGroupCount = op.getBatchGroupCount();
+  if (batchGroupCount > 1) {
+    SmallVector<int64_t> shape(Ty.getShape().begin(), Ty.getShape().end());
+    auto dimensionNumbers = op.getDimensionNumbers();
+    shape[dimensionNumbers.getInputFeatureDimension()] *= batchGroupCount;
+    shape[dimensionNumbers.getInputBatchDimension()] /= batchGroupCount;
+    Ty = Ty.clone(shape);
+  }
+
+  Ty;
+}]>;
+
+def GradDataConvBatchGroupCountType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto Ty = cast<RankedTensorType>(op.getLhs().getType());
+  auto shape = Ty.getShape();
+  auto batchGroupCount = op.getBatchGroupCount();
+  auto dimensionNumbers = op.getDimensionNumbers();
+  auto fdim = dimensionNumbers.getInputFeatureDimension();
+  auto bdim = dimensionNumbers.getInputBatchDimension();
+
+  SmallVector<int64_t> newShape;
+  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+    if (i == fdim) {
+      newShape.push_back(batchGroupCount);
+      newShape.push_back(shape[i]);
+    } else if (i == bdim) {
+      newShape.push_back(shape[i] / batchGroupCount);
+    } else {
+      newShape.push_back(shape[i]);
+    }
+  }
+
+  Ty.clone(newShape);
+}]>;
+
+def GradDataConvBatchGroupPerm : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  SmallVector<int64_t> transposes;
+
+  auto dimensionNumbers = op.getDimensionNumbers();
+  auto fdim = dimensionNumbers.getInputFeatureDimension();
+  auto bdim = dimensionNumbers.getInputBatchDimension();
+
+  if (fdim < bdim)
+    bdim++;
+
+  int64_t i = 0, N = op.getType().getShape().size();
+  while (i <= N) {
+    if (i == bdim) {
+      transposes.push_back(fdim);
+      transposes.push_back(bdim);
+    } else if (i != fdim) {
+      transposes.push_back(i);
+    }
+    i++;
+  }
+
+  getI64Attr(builder, transposes);
+}]>;
+
+def GradDataFilterReshape2 : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto dimensionNumbers = op.getDimensionNumbers();
+  auto groupCount = op.getFeatureGroupCount() * op.getBatchGroupCount();
+  auto Ty = cast<RankedTensorType>(rhs.getType());
+  auto shape = Ty.getShape();
+
+  auto odim = dimensionNumbers.getKernelOutputFeatureDimension();
+  auto idim = dimensionNumbers.getKernelInputFeatureDimension();
+
+  SmallVector<int64_t> newShape;
+  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+    if (i == idim) {
+      newShape.push_back(shape[i] * groupCount);
+    } else if (i == odim) {
+      newShape.push_back(shape[i] / groupCount);
+    } else {
+      newShape.push_back(shape[i]);
+    }
+  }
+
+  RankedTensorType::get(newShape, Ty.getElementType());
+}]>;
+
 def GradDataConvWindowStrides : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
   int64_t N = op.getDimensionNumbers().getInputSpatialDimensions().size();
   SmallVector<int64_t> windowStrides(N, 1);
@@ -410,6 +549,15 @@ def GradDataConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
   if (padding.has_value()) {
     newPaddingValues = llvm::to_vector(padding.value().getValues<int64_t>());
   }
+  auto dilateShape = [](int64_t shape, int64_t dilation) {
+    if (dilation == 1) return shape;
+    int64_t dilated = 1 + dilation * (shape - 1);
+    return dilated < 0 ? 0 : dilated;
+  };
+
+  auto lhsDilations = op.getLhsDilation();
+  auto rhsDilations = op.getRhsDilation();
+  auto windowStrides = op.getWindowStrides();
   for (int i = 0; i < N; ++i) {
     auto weightDim = dimensionNumbers.getKernelSpatialDimensions()[i];
     auto dataDim = dimensionNumbers.getInputSpatialDimensions()[i];
@@ -418,9 +566,19 @@ def GradDataConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
     auto outputDim = dimensionNumbers.getOutputSpatialDimensions()[i];
     auto padBefore = newPaddingValues[2 * i];
     auto padAfter = newPaddingValues[2 * i + 1];
-    auto rhsShape = op.getRhs().getType().getShape()[weightDim];
-    auto lhsShape = op.getLhs().getType().getShape()[dataDim];
-    auto outShape = op.getType().getShape()[outputDim];
+    auto lhsDilation = lhsDilations.has_value() ?
+                           getI64Value(lhsDilations.value(), i) :
+                           1;
+    auto rhsDilation = rhsDilations.has_value() ?
+                           getI64Value(rhsDilations.value(), i) :
+                           1;
+    auto windowStride = windowStrides.has_value() ?
+                            getI64Value(windowStrides.value(), i) :
+                            1;
+
+    auto lhsShape = dilateShape(op.getLhs().getType().getShape()[dataDim], lhsDilation);
+    auto rhsShape = dilateShape(op.getRhs().getType().getShape()[weightDim], rhsDilation);
+    auto outShape = dilateShape(op.getType().getShape()[outputDim], windowStride);
     newPaddingValues[2 * i] = rhsShape - padBefore - 1;
     newPaddingValues[2 * i + 1] =
         lhsShape + rhsShape - 1 - outShape - newPaddingValues[2 * i];
@@ -437,11 +595,14 @@ def GradDataConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
 }]>;
 
 def GradDataConvFeatureGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
-  op.getFeatureGroupCountAttr();
+  // The data gradient folds batch grouping into feature grouping, so the
+  // transposed convolution always runs with batch_group_count == 1.
+  builder.getI64IntegerAttr(op.getFeatureGroupCount() *
+                            op.getBatchGroupCount());
 }]>;
 
 def GradDataConvBatchGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
-  op.getBatchGroupCountAttr();
+  builder.getI64IntegerAttr(1);
 }]>;
 
 // GradFilter
@@ -500,17 +661,36 @@ def GradFilterConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
   if (padding.has_value()) {
     newPaddingValues = llvm::to_vector(padding.value().getValues<int64_t>());
   }
+  auto dilateShape = [](int64_t shape, int64_t dilation) {
+    if (dilation == 1) return shape;
+    int64_t dilated = 1 + dilation * (shape - 1);
+    return dilated < 0 ? 0 : dilated;
+  };
+
+  auto lhsDilations = op.getLhsDilation();
+  auto rhsDilations = op.getRhsDilation();
+  auto windowStrides = op.getWindowStrides();
   for (int i = 0; i < N; ++i) {
-    auto weightDim = dimensionNumbers.getKernelSpatialDimensions()[i];
     auto dataDim = dimensionNumbers.getInputSpatialDimensions()[i];
+    auto weightDim = dimensionNumbers.getKernelSpatialDimensions()[i];
     auto outputDim = dimensionNumbers.getOutputSpatialDimensions()[i];
     auto padBefore = newPaddingValues[2 * i];
     auto padAfter = newPaddingValues[2 * i + 1];
-    auto rhsShape = op.getRhs().getType().getShape()[weightDim];
-    auto lhsShape = op.getLhs().getType().getShape()[dataDim];
-    auto outShape = op.getType().getShape()[outputDim];
+    auto lhsDilation = lhsDilations.has_value() ?
+                           getI64Value(lhsDilations.value(), i) :
+                           1;
+    auto rhsDilation = rhsDilations.has_value() ?
+                           getI64Value(rhsDilations.value(), i) :
+                           1;
+    auto windowStride = windowStrides.has_value() ?
+                            getI64Value(windowStrides.value(), i) :
+                            1;
+
+    auto lhsShape = dilateShape(op.getLhs().getType().getShape()[dataDim], lhsDilation);
+    auto rhsShape = dilateShape(op.getRhs().getType().getShape()[weightDim], rhsDilation);
+    auto outShape = dilateShape(op.getType().getShape()[outputDim], windowStride);
 
     newPaddingValues[2 * i] = padBefore;
     newPaddingValues[2 * i + 1] =
         outShape - lhsShape + rhsShape - padBefore - 1;
   }
@@ -522,6 +702,22 @@ def GradFilterConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
   getI64Attr(builder, newPaddingValues);
 }]>;
 
+def GradFilterConvReverseDims : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto windowReversals = op.getWindowReversal();
+
+  SmallVector<int64_t> reverseDims;
+
+  if (windowReversals.has_value()) {
+    for (auto it : llvm::enumerate(getBoolIter(windowReversals.value()))) {
+      if (it.value()) {
+        reverseDims.push_back(it.index());
+      }
+    }
+  }
+
+  getI64Attr(builder, reverseDims);
+}]>;
+
 def GradFilterConvLhsDilation : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
   op.getLhsDilationAttr();
 }]>;
@@ -552,8 +748,8 @@ def GradFilterConvDimensionNumbers : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
 }]>;
 
 def GradFilterConvFeatureGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
-  op.getFeatureGroupCountAttr();
+  op.getBatchGroupCountAttr();
 }]>;
 
 def GradFilterConvBatchGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
-  op.getBatchGroupCountAttr();
+  op.getFeatureGroupCountAttr();
 }]>;
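The heart of the padding change is the dilateShape lambda: an axis of logical extent s under dilation d spans 1 + d*(s - 1) elements, and the gradient padding must be computed from these effective extents rather than the raw shapes. The standalone C++ sketch below (ours, not part of the patch) replays the GradFilterConvPadding arithmetic against the numbers from convolution_dilation.mlir further down.

#include <cassert>
#include <cstdint>

// Effective extent of an axis of size `shape` under dilation `dilation`,
// exactly as the dilateShape lambda in the hunks above computes it.
int64_t dilateShape(int64_t shape, int64_t dilation) {
  if (dilation == 1)
    return shape;
  int64_t dilated = 1 + dilation * (shape - 1);
  return dilated < 0 ? 0 : dilated;
}

int main() {
  // convolution_dilation.mlir: lhs extent 4, kernel extent 2 with
  // rhs_dilate = 2, output extent 6, stride 1, pad (2, 2).
  int64_t lhsShape = dilateShape(4, /*lhsDilation=*/1);  // 4
  int64_t rhsShape = dilateShape(2, /*rhsDilation=*/2);  // 3: the dilated
                                                         // kernel acts like
                                                         // a size-3 window
  int64_t outShape = dilateShape(6, /*windowStride=*/1); // 6
  int64_t padBefore = 2;
  // GradFilterConvPadding: padAfter = outShape - lhsShape + rhsShape
  //                                   - padBefore - 1 == 2, matching
  // pad = [[2, 2], [2, 2]] on the REVERSE filter-gradient convolution.
  assert(outShape - lhsShape + rhsShape - padBefore - 1 == 2);
  return 0;
}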
diff --git a/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
--- a/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
@@ -34,6 +34,10 @@ static mlir::DenseIntElementsAttr getI64Attr(OpBuilder &builder,
                                              llvm::ArrayRef<int64_t> vals) {
   return builder.getI64TensorAttr(vals);
 }
+static int64_t getI64Value(mlir::DenseIntElementsAttr attr, size_t pos) {
+  return attr.getValues<int64_t>()[pos];
+}
+
 static mlir::DenseElementsAttr getBoolAttr(OpBuilder &builder,
                                            llvm::ArrayRef<bool> vals) {
   return builder.getBoolVectorAttr(vals);
diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
index 851612b1f..fe4def6e5 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -38,6 +38,10 @@ static mlir::DenseI64ArrayAttr getI64Attr(OpBuilder &builder,
                                           llvm::ArrayRef<int64_t> vals) {
   return builder.getDenseI64ArrayAttr(vals);
 }
+static int64_t getI64Value(llvm::ArrayRef<int64_t> attr, size_t pos) {
+  return attr[pos];
+}
+
 static mlir::DenseBoolArrayAttr getBoolAttr(OpBuilder &builder,
                                             llvm::ArrayRef<bool> vals) {
   return builder.getDenseBoolArrayAttr(vals);
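For reference, here is a shape-level replay (ours, not part of the patch; helper names are hypothetical) of the kernel rearrangement performed by GradDataFilterReshape1, GradDataFilterTranspose, and GradDataFilterReshape2 in the HLODerivatives.td hunks above, using the kernel from convolution_batch_group_count.mlir.

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Split the kernel output-feature axis o into (g, o/g), as
// GradDataFilterReshape1 does.
Shape splitOutputFeatures(const Shape &s, int64_t odim, int64_t g) {
  Shape r;
  for (int64_t i = 0, e = s.size(); i < e; ++i) {
    if (i == odim) {
      r.push_back(g);
      r.push_back(s[i] / g);
    } else {
      r.push_back(s[i]);
    }
  }
  return r;
}

// Apply a permutation, as the transpose built by GradDataFilterTranspose.
Shape permute(const Shape &s, const Shape &perm) {
  Shape r;
  for (int64_t p : perm)
    r.push_back(s[p]);
  return r;
}

int main() {
  // Kernel 3x3x4x4 ([0, 1, i, o]) with group count 4:
  // 3x3x4x4 -> 3x3x4x(4x1) -> transpose dims [0, 1, 3, 2, 4] ->
  // merged by GradDataFilterReshape2 into 3x3x16x1, as the REVERSE
  // checks in the batch-group-count test show.
  Shape split = splitOutputFeatures({3, 3, 4, 4}, /*odim=*/3, /*g=*/4);
  assert((split == Shape{3, 3, 4, 4, 1}));
  Shape moved = permute(split, {0, 1, 3, 2, 4});
  assert((moved == Shape{3, 3, 4, 4, 1}));
  // Merge: i * g = 4 * 4 = 16, o / g = 4 / 4 = 1.
  Shape merged = {moved[0], moved[1], moved[2] * moved[3], moved[4]};
  assert((merged == Shape{3, 3, 16, 1}));
  return 0;
}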
diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 642c4a35a..3c877fbf2 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -6248,6 +6248,35 @@ struct GetDimensionSizeOpCanon final
   }
 };
 
+struct NoopReverse final : OpRewritePattern<mlir::stablehlo::ReverseOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ReverseOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t> newDimensions;
+    auto dimensions = op.getDimensions();
+    auto shape = op.getResult().getType().getShape();
+
+    for (auto dim : dimensions) {
+      auto size = shape[dim];
+      if (size != 1)
+        newDimensions.push_back(dim);
+    }
+
+    if (newDimensions.empty()) {
+      rewriter.replaceAllUsesWith(op.getResult(), op.getOperand());
+      return success();
+    }
+
+    if (newDimensions.size() == dimensions.size())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<mlir::stablehlo::ReverseOp>(op, op.getOperand(),
+                                                            newDimensions);
+    return success();
+  }
+};
+
 /// Converts gather ops to slice ops in case we have a single set of constant
 /// indices.
 struct GatherOpCanon final : OpRewritePattern<mlir::stablehlo::GatherOp> {
@@ -6655,7 +6684,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
     GreedyRewriteConfig config;
     RewritePatternSet patterns(context);
     patterns
-        .add<NoopSlice,
+        .add<NoopSlice, NoopReverse,
             SliceSlice,
             GatherOpCanon,
             GetDimensionSizeOpCanon>(context);
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -468,6 +468,10 @@ def ApplyNoopSlicePatterns : EnzymeHLOPatternOp<
     "noop_slice"> {
   let patterns = ["NoopSlice"];
 }
+def ApplyNoopReversePatterns : EnzymeHLOPatternOp<
+    "noop_reverse"> {
+  let patterns = ["NoopReverse"];
+}
 def ApplySliceSlicePatterns : EnzymeHLOPatternOp<
     "slice_slice"> {
   let patterns = ["SliceSlice"];
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index a2f2ecad7..7c8ded31e 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -307,6 +307,7 @@ def hlo_opts():
 cos_simplify<16>;
 sin_simplify<16>;
 noop_slice<16>;
+noop_reverse<16>;
 const_prop_through_barrier<16>;
 slice_slice<16>;
 shift_right_logical_simplify<16>;
diff --git a/test/lit_tests/diffrules/stablehlo/convolution.mlir b/test/lit_tests/diffrules/stablehlo/convolution.mlir
index d42d5b1ce..8ac2fd3f1 100644
--- a/test/lit_tests/diffrules/stablehlo/convolution.mlir
+++ b/test/lit_tests/diffrules/stablehlo/convolution.mlir
@@ -8,7 +8,6 @@ module {
   }
 }
 
-
 // FORWARD: func.func @main(%arg0: tensor<8x66x66x512xf32>, %arg1: tensor<8x66x66x512xf32>, %arg2: tensor<3x3x512x512xf32>, %arg3: tensor<3x3x512x512xf32>) -> (tensor<8x64x64x512xf32>, tensor<8x64x64x512xf32>) {
 // FORWARD-NEXT: %0 = stablehlo.convolution(%arg1, %arg2) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<8x66x66x512xf32>, tensor<3x3x512x512xf32>) -> tensor<8x64x64x512xf32>
 // FORWARD-NEXT: %1 = stablehlo.convolution(%arg0, %arg3) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<8x66x66x512xf32>, tensor<3x3x512x512xf32>) -> tensor<8x64x64x512xf32>
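NoopReverse relies on the fact that reversing an axis of extent 1 is the identity, so such axes can be dropped from the dimension list; if none survive, the reverse folds away entirely, and if nothing changed the pattern bails out. A minimal standalone sketch of that filter (ours, with a hypothetical helper name):

#include <cstdint>
#include <vector>

// Keep only the dims worth reversing: an axis of extent 1 is unchanged by
// reversal, which is exactly the condition NoopReverse checks.
std::vector<int64_t> nontrivialReverseDims(const std::vector<int64_t> &shape,
                                           const std::vector<int64_t> &dims) {
  std::vector<int64_t> kept;
  for (int64_t d : dims)
    if (shape[d] != 1)
      kept.push_back(d);
  return kept;
}

int main() {
  // reverse(dims = [0, 2]) on tensor<1x4x8xf32> shrinks to dims = [2];
  // reverse(dims = [0]) on the same type folds away entirely.
  auto kept = nontrivialReverseDims({1, 4, 8}, {0, 2});
  return (kept.size() == 1 && kept[0] == 2) ? 0 : 1;
}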
diff --git a/test/lit_tests/diffrules/stablehlo/convolution_batch_group_count.mlir b/test/lit_tests/diffrules/stablehlo/convolution_batch_group_count.mlir
new file mode 100644
index 000000000..529994049
--- /dev/null
+++ b/test/lit_tests/diffrules/stablehlo/convolution_batch_group_count.mlir
@@ -0,0 +1,20 @@
+// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_active argTys=enzyme_active,enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --enzyme-hlo-opt | FileCheck %s --check-prefix=REVERSE
+
+module {
+  func.func @main(%arg0: tensor<8x4x4x4xf32>, %arg1: tensor<3x3x4x4xf32>) -> tensor<2x4x2x2xf32> {
+    %2 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {feature_group_count = 1 : i64, batch_group_count = 4 : i64} : (tensor<8x4x4x4xf32>, tensor<3x3x4x4xf32>) -> tensor<2x4x2x2xf32>
+    return %2 : tensor<2x4x2x2xf32>
+  }
+}
+
+// REVERSE{LITERAL}: func.func @main(%arg0: tensor<8x4x4x4xf32>, %arg1: tensor<3x3x4x4xf32>, %arg2: tensor<2x4x2x2xf32>) -> (tensor<8x4x4x4xf32>, tensor<3x3x4x4xf32>) {
+// REVERSE-NEXT{LITERAL}: %0 = stablehlo.reshape %arg1 : (tensor<3x3x4x4xf32>) -> tensor<3x3x4x4x1xf32>
+// REVERSE-NEXT{LITERAL}: %1 = stablehlo.transpose %0, dims = [0, 1, 3, 2, 4] : (tensor<3x3x4x4x1xf32>) -> tensor<3x3x4x4x1xf32>
+// REVERSE-NEXT{LITERAL}: %2 = stablehlo.reshape %1 : (tensor<3x3x4x4x1xf32>) -> tensor<3x3x16x1xf32>
+// REVERSE-NEXT{LITERAL}: %3 = stablehlo.convolution(%arg2, %2) dim_numbers = [b, f, 1, 0]x[0, 1, o, i]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [true, true]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<2x4x2x2xf32>, tensor<3x3x16x1xf32>) -> tensor<2x16x4x4xf32>
+// REVERSE-NEXT{LITERAL}: %4 = stablehlo.reshape %3 : (tensor<2x16x4x4xf32>) -> tensor<2x4x4x4x4xf32>
+// REVERSE-NEXT{LITERAL}: %5 = stablehlo.transpose %4, dims = [1, 0, 2, 3, 4] : (tensor<2x4x4x4x4xf32>) -> tensor<4x2x4x4x4xf32>
+// REVERSE-NEXT{LITERAL}: %6 = stablehlo.reshape %5 : (tensor<4x2x4x4x4xf32>) -> tensor<8x4x4x4xf32>
+// REVERSE-NEXT{LITERAL}: %7 = stablehlo.convolution(%arg0, %arg2) dim_numbers = [f, b, 1, 0]x[i, o, 1, 0]->[0, 1, b, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<8x4x4x4xf32>, tensor<2x4x2x2xf32>) -> tensor<3x3x4x4xf32>
+// REVERSE-NEXT{LITERAL}: return %6, %7 : tensor<8x4x4x4xf32>, tensor<3x3x4x4xf32>
+// REVERSE-NEXT{LITERAL}: }
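The REVERSE checks above are easiest to follow as pure shape bookkeeping; the sketch below (ours, not part of the patch) traces how the data gradient's 2x16x4x4 convolution result is regrouped back to the 8x4x4x4 input shape.

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

int main() {
  // With batch_group_count = 4 the data-gradient convolution first yields a
  // tensor whose feature dim is scaled up and batch dim scaled down
  // (GradDataConvOutputType): lhs 8x4x4x4 ([b, f, 1, 0]) -> 2x16x4x4.
  Shape lhs = {8, 4, 4, 4};
  int64_t bgc = 4, bdim = 0, fdim = 1;
  Shape convOut = lhs;
  convOut[fdim] *= bgc; // 16
  convOut[bdim] /= bgc; // 2
  assert((convOut == Shape{2, 16, 4, 4}));

  // GradDataConvBatchGroupCountType then views it as 2x(4x4)x4x4, the
  // transpose dims = [1, 0, 2, 3, 4] moves the group axis in front of the
  // batch axis (4x2x4x4x4), and the final reshape merges them back to 8.
  Shape moved = {4, 2, 4, 4, 4};
  assert(moved[0] * moved[1] == lhs[bdim]); // 4 * 2 == 8
  return 0;
}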
diff --git a/test/lit_tests/diffrules/stablehlo/convolution_dilation.mlir b/test/lit_tests/diffrules/stablehlo/convolution_dilation.mlir
new file mode 100644
index 000000000..a0686cd07
--- /dev/null
+++ b/test/lit_tests/diffrules/stablehlo/convolution_dilation.mlir
@@ -0,0 +1,23 @@
+// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_dup,enzyme_dup retTys=enzyme_dup mode=ForwardMode" | FileCheck %s --check-prefix=FORWARD
+// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_active,enzyme_active retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --enzyme-hlo-opt | FileCheck %s --check-prefix=REVERSE
+
+module {
+  func.func @main(%arg0: tensor<3x1x4x4xf64>, %arg1: tensor<2x2x1x2xf64>) -> tensor<3x2x6x6xf64> {
+    %2 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], rhs_dilate = [2, 2]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<3x1x4x4xf64>, tensor<2x2x1x2xf64>) -> tensor<3x2x6x6xf64>
+    return %2 : tensor<3x2x6x6xf64>
+  }
+}
+
+// FORWARD: func.func @main(%arg0: tensor<3x1x4x4xf64>, %arg1: tensor<3x1x4x4xf64>, %arg2: tensor<2x2x1x2xf64>, %arg3: tensor<2x2x1x2xf64>) -> (tensor<3x2x6x6xf64>, tensor<3x2x6x6xf64>) {
+// FORWARD-NEXT{LITERAL}: %0 = stablehlo.convolution(%arg1, %arg2) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], rhs_dilate = [2, 2]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<3x1x4x4xf64>, tensor<2x2x1x2xf64>) -> tensor<3x2x6x6xf64>
+// FORWARD-NEXT{LITERAL}: %1 = stablehlo.convolution(%arg0, %arg3) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], rhs_dilate = [2, 2]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<3x1x4x4xf64>, tensor<2x2x1x2xf64>) -> tensor<3x2x6x6xf64>
+// FORWARD-NEXT{LITERAL}: %2 = stablehlo.add %0, %1 : tensor<3x2x6x6xf64>
+// FORWARD-NEXT{LITERAL}: %3 = stablehlo.convolution(%arg0, %arg2) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], rhs_dilate = [2, 2]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<3x1x4x4xf64>, tensor<2x2x1x2xf64>) -> tensor<3x2x6x6xf64>
+// FORWARD-NEXT{LITERAL}: return %3, %2 : tensor<3x2x6x6xf64>, tensor<3x2x6x6xf64>
+// FORWARD-NEXT{LITERAL}: }
+
+// REVERSE: func.func @main(%arg0: tensor<3x1x4x4xf64>, %arg1: tensor<2x2x1x2xf64>, %arg2: tensor<3x2x6x6xf64>) -> (tensor<3x1x4x4xf64>, tensor<2x2x1x2xf64>) {
+// REVERSE-NEXT{LITERAL}: %0 = stablehlo.convolution(%arg2, %arg1) dim_numbers = [b, f, 1, 0]x[0, 1, o, i]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [2, 2], reverse = [true, true]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<3x2x6x6xf64>, tensor<2x2x1x2xf64>) -> tensor<3x1x4x4xf64>
+// REVERSE-NEXT{LITERAL}: %1 = stablehlo.convolution(%arg0, %arg2) dim_numbers = [f, b, 1, 0]x[i, o, 1, 0]->[0, 1, b, f], window = {stride = [2, 2], pad = [[2, 2], [2, 2]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<3x1x4x4xf64>, tensor<3x2x6x6xf64>) -> tensor<2x2x1x2xf64>
+// REVERSE-NEXT{LITERAL}: return %0, %1 : tensor<3x1x4x4xf64>, tensor<2x2x1x2xf64>
+// REVERSE-NEXT{LITERAL}: }
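The FORWARD half of this test needs no window changes at all: convolution is bilinear in data and filter, so the tangent is

\[
\frac{d}{dt}\,\mathrm{conv}(x, w) \;=\; \mathrm{conv}(\dot{x}, w) \;+\; \mathrm{conv}(x, \dot{w}),
\]

with stride, padding, and dilations copied verbatim from the primal call, which is exactly what the two FORWARD-NEXT convolutions and the stablehlo.add show. Only the reverse-mode rules have to rework the window attributes.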
diff --git a/test/lit_tests/diffrules/stablehlo/convolution_feature_group_count.mlir b/test/lit_tests/diffrules/stablehlo/convolution_feature_group_count.mlir
new file mode 100644
index 000000000..cc3f79f27
--- /dev/null
+++ b/test/lit_tests/diffrules/stablehlo/convolution_feature_group_count.mlir
@@ -0,0 +1,24 @@
+// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_dup argTys=enzyme_dup,enzyme_dup mode=ForwardMode" | FileCheck %s --check-prefix=FORWARD
+// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_active argTys=enzyme_active,enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --enzyme-hlo-opt | FileCheck %s --check-prefix=REVERSE
+
+module {
+  func.func @main(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<3x3x1x8xf32>) -> tensor<1x8x2x2xf32> {
+    %2 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<1x4x4x4xf32>, tensor<3x3x1x8xf32>) -> tensor<1x8x2x2xf32>
+    return %2 : tensor<1x8x2x2xf32>
+  }
+}
+
+// FORWARD: func.func @main(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x4x4x4xf32>, %arg2: tensor<3x3x1x8xf32>, %arg3: tensor<3x3x1x8xf32>) -> (tensor<1x8x2x2xf32>, tensor<1x8x2x2xf32>) {
+// FORWARD-NEXT{LITERAL}: %0 = stablehlo.convolution(%arg1, %arg2) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<1x4x4x4xf32>, tensor<3x3x1x8xf32>) -> tensor<1x8x2x2xf32>
+// FORWARD-NEXT{LITERAL}: %1 = stablehlo.convolution(%arg0, %arg3) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<1x4x4x4xf32>, tensor<3x3x1x8xf32>) -> tensor<1x8x2x2xf32>
+// FORWARD-NEXT{LITERAL}: %2 = stablehlo.add %0, %1 : tensor<1x8x2x2xf32>
+// FORWARD-NEXT{LITERAL}: %3 = stablehlo.convolution(%arg0, %arg2) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<1x4x4x4xf32>, tensor<3x3x1x8xf32>) -> tensor<1x8x2x2xf32>
+// FORWARD-NEXT{LITERAL}: return %3, %2 : tensor<1x8x2x2xf32>, tensor<1x8x2x2xf32>
+// FORWARD-NEXT{LITERAL}: }
+
+// REVERSE{LITERAL}: func.func @main(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<3x3x1x8xf32>, %arg2: tensor<1x8x2x2xf32>) -> (tensor<1x4x4x4xf32>, tensor<3x3x1x8xf32>) {
+// REVERSE-NEXT{LITERAL}: %0 = stablehlo.reshape %arg1 : (tensor<3x3x1x8xf32>) -> tensor<3x3x4x2xf32>
+// REVERSE-NEXT{LITERAL}: %1 = stablehlo.convolution(%arg2, %0) dim_numbers = [b, f, 1, 0]x[0, 1, o, i]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [true, true]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<1x8x2x2xf32>, tensor<3x3x4x2xf32>) -> tensor<1x4x4x4xf32>
+// REVERSE-NEXT{LITERAL}: %2 = stablehlo.convolution(%arg0, %arg2) dim_numbers = [f, b, 1, 0]x[i, o, 1, 0]->[0, 1, b, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 4 : i64, feature_group_count = 1 : i64} : (tensor<1x4x4x4xf32>, tensor<1x8x2x2xf32>) -> tensor<3x3x1x8xf32>
+// REVERSE-NEXT{LITERAL}: return %1, %2 : tensor<1x4x4x4xf32>, tensor<3x3x1x8xf32>
+// REVERSE-NEXT{LITERAL}: }
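Taken together, the two group-count tests pin down how grouping moves through the derivatives: the data gradient folds everything into feature_group_count, while the filter gradient swaps the two counts. The sketch below (ours, not part of the patch) states the rule and checks it against both REVERSE outputs; note that the StableHLO convolution constraints allow at most one of the two counts to exceed 1.

#include <cassert>
#include <cstdint>

struct GroupCounts {
  int64_t feature;
  int64_t batch;
};

// Data gradient: all grouping moves to feature_group_count.
GroupCounts gradDataCounts(GroupCounts fwd) {
  return {fwd.feature * fwd.batch, 1};
}

// Filter gradient: feature and batch group counts swap roles.
GroupCounts gradFilterCounts(GroupCounts fwd) {
  return {fwd.batch, fwd.feature};
}

int main() {
  // convolution_feature_group_count.mlir: forward {f = 4, b = 1}.
  assert(gradDataCounts({4, 1}).feature == 4);   // %1: feature_group_count = 4
  assert(gradFilterCounts({4, 1}).batch == 4);   // %2: batch_group_count = 4
  // convolution_batch_group_count.mlir: forward {f = 1, b = 4}.
  assert(gradDataCounts({1, 4}).feature == 4);   // %3: feature_group_count = 4
  assert(gradFilterCounts({1, 4}).feature == 4); // %7: feature_group_count = 4
  return 0;
}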