Skip to content

Commit

Permalink
Add DiscRematerializationPass to reduce peak memory
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Jun 27, 2024
1 parent 4d35390 commit f1936c9
Show file tree
Hide file tree
Showing 5 changed files with 876 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tao_compiler/mlir/disc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,35 @@ cc_library(
alwayslink = 1,
)

# Library target for the DISC rematerialization pass, which this change
# introduces to reduce peak memory.
cc_library(
    name = "disc_rematerialization",
    srcs = ["transforms/disc_rematerialization.cc"],
    # Pass-declaration headers exposed through this target.
    hdrs = [
        "transforms/passes.h",
        "transforms/rewriters.h",
    ],
    deps = [
        # DISC-local targets from this package.
        ":lmhlo_disc",
        ":pass_details",
        ":placement_utils",
        ":shape_utils",
        ":fusion_utils",
        # External repositories: TensorFlow MLIR-HLO, LLVM, and MLIR.
        "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:ShapeTransforms",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:SCFDialect",
    ],
    # Always link this library's object files into dependents, even when no
    # symbol is referenced directly.
    # NOTE(review): presumably required so the pass stays registered/linked
    # like the sibling pass libraries — confirm against the other targets.
    alwayslink = 1,
)

cc_library(
name = "disc_lower_to_library_call",
srcs = ["transforms/disc_lower_to_library_call.cc"],
Expand Down Expand Up @@ -2490,6 +2519,7 @@ cc_library(
":disc_optimization_barrier_expand",
":disc_parallel_loop_collapsing",
":disc_parallel_loop_tiling",
":disc_rematerialization",
":disc_remove_dead_buffer",
":disc_remove_shape_constraints",
":disc_shape_optimization",
Expand Down
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/disc_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(bufferization::createBufferDeallocationPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscBufferDeallocationPass());

pm.addPass(mhlo_disc::createDiscRematerializationPass());

pm.addPass(disc_ral::createRalInjectExecutionContextPass());
pm.addNestedPass<FuncOp>(
disc_ral::createDiscLowerToLibraryCallPass(gpu_enabled));
Expand Down
Loading

0 comments on commit f1936c9

Please sign in to comment.