Skip to content

Commit

Permalink
Fix concat fuse
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 10, 2024
1 parent 3cd0cac commit c339731
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,8 @@ struct ConcatFuse final : OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
}
if (!changed)
return failure();
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, op.getType(),
vals);
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
op, op.getType(), vals, op.getDimensionAttr());
return success();
}
};
Expand Down
15 changes: 15 additions & 0 deletions test/lit_tests/concatfuse.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {

func.func @main(%a : tensor<2xf32>, %b : tensor<1xf32>, %c : tensor<1xf32>) -> tensor<4xf32> {
%concat = stablehlo.concatenate %a, %b, dim=0 : (tensor<2xf32>, tensor<1xf32>) -> tensor<3xf32>
%concat2 = stablehlo.concatenate %concat, %c, dim=0 : (tensor<3xf32>, tensor<1xf32>) -> tensor<4xf32>
return %concat2 : tensor<4xf32>
}
}

// CHECK: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %[[concat:.+]] = stablehlo.concatenate %arg0, %arg1, %arg2, dim = 0 : (tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %[[concat]] : tensor<4xf32>
// CHECK-NEXT: }

0 comments on commit c339731

Please sign in to comment.