Skip to content

Commit

Permalink
Fix and test reshape of constants
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 12, 2024
1 parent 7833e8c commit 98f5dbd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1600,7 +1600,7 @@ struct ReshapeSimplify : public OpRewritePattern<mlir::stablehlo::ReshapeOp> {
if (inp) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(),
inp.isSplat() ? inp.resizeSplat(inp.getType())
inp.isSplat() ? inp.resizeSplat(op.getType())
: inp.reshape(op.getType()));
return success();
}
Expand Down
14 changes: 14 additions & 0 deletions test/lit_tests/reshapeconst.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {
func.func @main() -> tensor<1xf32> {
%concat = stablehlo.constant dense<3.140000e+00> : tensor<f32>
%conv = stablehlo.reshape %concat : (tensor<f32>) -> tensor<1xf32>
return %conv : tensor<1xf32>
}
}

// CHECK: func.func @main() -> tensor<1xf32> {
// CHECK-NEXT: %0 = stablehlo.constant dense<3.140000e+00> : tensor<1xf32>
// CHECK-NEXT: return %0 : tensor<1xf32>
// CHECK-NEXT: }

0 comments on commit 98f5dbd

Please sign in to comment.