diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index ed2a36e30..d3a44db7f 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -1600,7 +1600,7 @@ struct ReshapeSimplify : public OpRewritePattern { if (inp) { rewriter.replaceOpWithNewOp( op, op.getType(), - inp.isSplat() ? inp.resizeSplat(inp.getType()) + inp.isSplat() ? inp.resizeSplat(op.getType()) : inp.reshape(op.getType())); return success(); } diff --git a/test/lit_tests/reshapeconst.mlir b/test/lit_tests/reshapeconst.mlir new file mode 100644 index 000000000..02db14710 --- /dev/null +++ b/test/lit_tests/reshapeconst.mlir @@ -0,0 +1,14 @@ +// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s + +module { + func.func @main() -> tensor<1xf32> { + %concat = stablehlo.constant dense<3.140000e+00> : tensor + %conv = stablehlo.reshape %concat : (tensor) -> tensor<1xf32> + return %conv : tensor<1xf32> + } +} + +// CHECK: func.func @main() -> tensor<1xf32> { +// CHECK-NEXT: %0 = stablehlo.constant dense<3.140000e+00> : tensor<1xf32> +// CHECK-NEXT: return %0 : tensor<1xf32> +// CHECK-NEXT: }