From 6588de746116edaae6451b750f1f1d62404e7edd Mon Sep 17 00:00:00 2001 From: eedalong Date: Wed, 19 Jun 2024 15:32:33 +0800 Subject: [PATCH] Add DiscRematerializationPass to reduce peak memory --- tao_compiler/mlir/disc/BUILD | 30 ++ tao_compiler/mlir/disc/disc_compiler.cc | 2 + .../disc/transforms/disc_rematerialization.cc | 486 ++++++++++++++++++ .../mlir/disc/transforms/mhlo_disc_passes.td | 5 + tao_compiler/mlir/disc/transforms/passes.h | 2 + 5 files changed, 525 insertions(+) create mode 100644 tao_compiler/mlir/disc/transforms/disc_rematerialization.cc mode change 100644 => 100755 tao_compiler/mlir/disc/transforms/passes.h diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index d3dfd3fbc9b..9d7ca87aa00 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -1011,6 +1011,35 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "disc_rematerialization", + srcs = ["transforms/disc_rematerialization.cc"], + hdrs = [ + "transforms/passes.h", + "transforms/rewriters.h", + ], + deps = [ + ":lmhlo_disc", + ":pass_details", + ":placement_utils", + ":shape_utils", + ":fusion_utils", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:ShapeTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:SCFDialect", + ], + alwayslink = 1, +) + cc_library( name = "disc_lower_to_library_call", srcs = ["transforms/disc_lower_to_library_call.cc"], @@ -2490,6 +2519,7 @@ cc_library( ":disc_optimization_barrier_expand", ":disc_parallel_loop_collapsing", ":disc_parallel_loop_tiling", + ":disc_rematerialization", ":disc_remove_dead_buffer", ":disc_remove_shape_constraints", ":disc_shape_optimization", diff --git 
a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index e8f0eb06ae3..1cab16968b1 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -544,6 +544,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass(bufferization::createBufferDeallocationPass()); pm.addNestedPass(disc_ral::createDiscBufferDeallocationPass()); + pm.addPass(mhlo_disc::createDiscRematerializationPass()); + pm.addPass(disc_ral::createRalInjectExecutionContextPass()); pm.addNestedPass( disc_ral::createDiscLowerToLibraryCallPass(gpu_enabled)); diff --git a/tao_compiler/mlir/disc/transforms/disc_rematerialization.cc b/tao_compiler/mlir/disc/transforms/disc_rematerialization.cc new file mode 100644 index 00000000000..7870847c550 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_rematerialization.cc @@ -0,0 +1,486 @@ +// Copyright 2021 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements the DISC rematerialization pass, which tracks live +// buffers and recomputes/compresses/offloads them to reduce peak memory.
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "lhlo/IR/lhlo_ops.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/disc/IR/disc_shape_ops.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +#include "mlir/disc/disc_util.h" +#include "mlir/disc/transforms/PassDetail.h" +#include "mlir/disc/transforms/fusion_utils.h" +#include "mlir/disc/transforms/placement_utils.h" +#include "mlir/disc/transforms/rewriters.h" +#include "mlir/disc/transforms/shape_utils.h" + +namespace mlir { +using placement_utils::kDiscPlaceAssignment; +using placement_utils::kGpu; + +namespace mhlo_disc { +namespace { + +bool IsRematerializable(const Operation* op) { + return true; +} + +enum class RematStrategy{ + // Recompute the node at a later program point. + kRecompute, + // Change the layout into a compact form and uncompress it back at a later + // program point. + kCompress, + // Copy the data off the device to the host to be copied back later. + kHostOffload, + + // Combination of different strategies. 
+ kRecomputeAndCompress, + kRecomputeAndHostOffload, + kCompressAndHostOffload, + kAll, + kNoAction, +}; + +struct CompactShape {}; +struct Item { + Value memref; + std::vector live_range; + bool live_out; + bool inplace_reuse; + // compressed format of this memref + CompactShape compact_shape; + // peak memory until this item + size_t current_memory_usage; +}; + +class LivingItems{ + public: + LivingItems() = default; + + Item ConstructItemFromValue(const Value memref, std::unordered_map& op_position_map) { + // Add memrefs and their live range. + std::vector live_range; + live_range.push_back(op_position_map[reinterpret_cast(memref.getDefiningOp())]); + bool live_out = false, inplace_reuse = false; + for(auto user : memref.getUsers()) { + if (auto parent_op = user->getParentOfType()) { + user = parent_op; + } + + if(isa(user)) { + live_out = true; + continue; + } + + if(isa(user)) { + inplace_reuse = true; + continue; + } + live_range.push_back(op_position_map[reinterpret_cast(user)]); + } + + // Sort and remove duplicates. + // Duplicates happen inside a fusion block + std::sort(live_range.begin(), live_range.end()); + auto new_end = std::unique(live_range.begin(), live_range.end()); + live_range.erase(new_end, live_range.end()); + + return Item{memref, live_range, live_out, inplace_reuse}; + } + + void Add(const Value memref, std::unordered_map& op_position_map) { + // Add memrefs and their live range. 
+ auto item = this->ConstructItemFromValue(memref, op_position_map); + Add(item); + } + + void Add(const Item& item) { + int64_t key = reinterpret_cast(item.memref.getAsOpaquePointer()); + live_range_map_[key] = living_items_.size(); + living_items_.push_back(item); + } + + void Remove(const Value memref) { + int64_t key = reinterpret_cast(memref.getAsOpaquePointer()); + int index = live_range_map_[key]; + live_range_map_.erase(key); + //living_items_.erase(std::advance(living_items_.begin(), index)); + } + + bool IsExist(const Value memref) { + int64_t key = reinterpret_cast(memref.getAsOpaquePointer()); + return live_range_map_.find(key) != live_range_map_.end(); + } + + void AddBefore(const Item& item, Item& target_item) { + + } + + void AddAfter(const Item& item, Item& target_item) { + + } + + void Update(const Item& item, Item& target_item) { + + } + + std::list& GetLivingItems() + { + return living_items_; + } + private: + std::list living_items_; + std::map live_range_map_; +}; + +class MemoryUsageTracker { + public: + MemoryUsageTracker() = default; + + void SetRematStrategy(RematStrategy strategy) { + remat_strategy_ = strategy; + } + void SetAllOperationPositionInfo(const std::unordered_map& operation_position_map, const std::unordered_map& reverse_operation_position_map) { + operation_position_map_ = operation_position_map; + reverse_operation_position_map_ = reverse_operation_position_map; + } + + void ProcessAlloc(memref::AllocOp op) { + auto memref = op.getResult(); + if (NeedSkip(memref)) { + return; + } + + auto item = living_items_.ConstructItemFromValue(memref, operation_position_map_); + if (NeedSkip(item)) { + return; + } + + current_memory_usage_ += GetMemoryUsageForValue(memref); + current_peak_memory_ = (current_memory_usage_ > current_peak_memory_)? 
current_memory_usage_ : current_peak_memory_; + item.current_memory_usage = current_memory_usage_; + living_items_.Add(item); + } + + void ProcessDealloc(memref::DeallocOp op) { + auto memref = op.getOperand(); + if(!living_items_.IsExist(memref)) { + return; + } + current_memory_usage_ -= GetMemoryUsageForValue(memref); + living_items_.Remove(memref); + } + + void ProcessCustomCallV2(lmhlo_disc::CustomCallV2Op op) { + for(auto memref : op.getResults()) { + if(NeedSkip(memref)) { + continue; + } + current_memory_usage_ += GetMemoryUsageForValue(memref); + current_peak_memory_ = (current_memory_usage_ > current_peak_memory_)? current_memory_usage_ : current_peak_memory_; + living_items_.Add(memref, operation_position_map_); + } + } + + size_t GetMemoryUsageForValue(Value memref) { + auto memref_ty = memref.getType().dyn_cast_or_null(); + if(!memref_ty) { + return 0; + } + assert(memref_ty.getLayout().isIdentity()); + if(memref_ty.hasStaticShape()) { + int byte_width = memref_ty.getElementTypeBitWidth() / 8; + auto shape = memref_ty.getShape(); + size_t logical_size = byte_width; + for (size_t dimSize : shape) { + logical_size *= dimSize; + } + return logical_size; + } else { + throw std::logic_error("GetMemoryUsageForValue for dynamic shape memref not implemented"); + } + } + + CompactShape GetCompactShape(Value memref) { + throw std::logic_error("GetCompactShape not implemented"); + } + size_t GetCompactedMemoryUsageForItem(const Item& item) { + throw std::logic_error("GetCompactedMemoryUsageForItem not implemented"); + } + + void Recompute(int op_position, const Item& root_item) { + int start_position; + for(int idx = 0; idx < root_item.live_range.size(); idx++) { + if(root_item.live_range[idx] > op_position) { + start_position = root_item.live_range[idx - 1]; + break; + } + } + llvm::dbgs() << "Recompute Memref " << root_item.memref << " from " << *(reinterpret_cast(reverse_operation_position_map_[start_position])) << " at position " << op_position << "\n"; + } + 
+ void Offload(int op_position, const Item& root_item) { + throw std::logic_error("Offload not implemented"); + } + + void Compress(int op_position, const Item& item) { + throw std::logic_error("Compress not implemented"); + } + + size_t GetOffloadScore(int op_position, const Item& item, size_t target_peak_memory=-1) { + throw std::logic_error("GetOffloadScore not implemented"); + } + + size_t GetCompressionScore(int op_position, const Item& item, size_t target_peak_memory=-1) { + throw std::logic_error(" GetCompressionScore not implemented"); + } + + size_t GetRecomputationScore(int op_position, const Item& item, size_t target_peak_memory=-1) { + + // We want this interval to be large + int interval = 1, start_position = 0, end_position = 0; + for(int idx = 0; idx < item.live_range.size(); idx++) { + if(item.live_range[idx] > op_position) { + if(idx == 1 ) return kInvalidScore; + end_position = item.live_range[idx]; + start_position = item.live_range[idx - 1]; + interval = end_position - start_position; + break; + } + } + // remove meaningless ops + int temp_position = start_position; + while(temp_position < end_position) { + if(isa(reinterpret_cast(reverse_operation_position_map_[temp_position]))) { + interval -= 1; + } + temp_position += 1; + } + + // We want memory_saving to be large + size_t memory_saving = GetMemoryUsageForValue(item.memref); + + return memory_saving; + } + std::pair GetRematEvaluation(int op_position, const Item& item) { + switch(remat_strategy_) { + case RematStrategy::kRecompute: + return std::make_pair(RematStrategy::kRecompute, GetRecomputationScore(op_position, item)); + case RematStrategy::kCompress: + return std::make_pair(RematStrategy::kCompress, GetCompressionScore(op_position, item)); + case RematStrategy::kHostOffload: + return std::make_pair(RematStrategy::kHostOffload, GetOffloadScore(op_position, item)); + default: + return std::make_pair(RematStrategy::kRecompute, GetRecomputationScore(op_position, item)); + } + } + + bool 
RematerializeToTargetMemoryUsage(int op_position, size_t peak_memory_target) { + + int try_count = 0; + while(current_peak_memory_ > peak_memory_target && try_count++ < kMaxTryCount) { + std::vector> items_eval_res; + for(auto item : living_items_.GetLivingItems()) { + auto eval_res = GetRematEvaluation(op_position, item); + if(eval_res.second != kInvalidScore) { + items_eval_res.push_back(std::make_tuple(item, eval_res.first, eval_res.second)); + } + } + // Sort by score + std::sort(items_eval_res.begin(), items_eval_res.end(), [](auto& a, auto& b) { + return std::get<2>(a) > std::get<2>(b); + }); + + auto best_item = items_eval_res[0]; + + switch(std::get<1>(best_item)) { + case RematStrategy::kRecompute: + Recompute(op_position, std::get<0>(best_item)); + break; + case RematStrategy::kCompress: + Compress(op_position, std::get<0>(best_item)); + break; + case RematStrategy::kHostOffload: + Offload(op_position, std::get<0>(best_item)); + break; + default: + ; + } + } + return current_peak_memory_ <= peak_memory_target; + } + + void RematerializeToLowestMemoryUsage() { + // Iterate until we cannot get more memory-saving benefit + throw std::logic_error("RematerializeToLowestMemoryUsage not implemented"); + } + + size_t GetCurrentPeakMemoryUsage() { return current_peak_memory_; } + + bool NeedSkip(const Value memref) { + // We also need to handle those buffers which are used only inside a fusion block + // We escape them because they will be removed later + return GetMemoryUsageForValue(memref) <= kSmallMemrefSize; + } + bool NeedSkip(const Item& item) { + // We also need to handle those buffers which are used only inside a fusion block + // We escape them because they will be removed later + if(GetMemoryUsageForValue(item.memref) <= kSmallMemrefSize) return true; + /* + The pattern is like: + buffer = alloc() + fusion { + use(buffer) + } + dealloc(buffer) + */ + if(item.live_range.size() == 3 && !item.live_out){ + 
if(isa(reinterpret_cast(reverse_operation_position_map_[item.live_range[0]])) && + isa(reinterpret_cast(reverse_operation_position_map_[item.live_range[2]]))) { + return true; + } + } + + return false; + } + private: + LivingItems living_items_; + size_t current_peak_memory_ = 0; + size_t current_memory_usage_ = 0; + + const size_t kSmallMemrefSize = 0 * 1024ll * 1024ll; // memoryrefs under kSmallMemrefSize are not considered when remat; + const int kMaxTryCount = 1; + const int kInvalidScore = -1; + + const int kMaxRecomputeBlockSize = 1; + std::unordered_map operation_position_map_; + std::unordered_map reverse_operation_position_map_; + RematStrategy remat_strategy_ = RematStrategy::kRecompute; +}; + +struct DiscRematerializationPass : public DiscRematerializationPassBase { + using DiscRematerializationPassBase::DiscRematerializationPassBase; + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + private: + MemoryUsageTracker memory_usage_tracker_; + + public: + DiscRematerializationPass() = default; + + bool IsDynmaicShapeGraph() { + return false; + } + + size_t GetPeakMemoryLimit() { + if(IsDynmaicShapeGraph()) { + return -1; + } + return 10ll * 1024ll * 1024ll * 1024ll; // 10GB + } + + void runOnOperation() override { + auto& context = getContext(); + RewritePatternSet patterns(&context); + ConversionTarget target(context); + target.addLegalDialect(); + + ModuleOp module = getOperation(); + auto main_func = module.lookupSymbol("main"); + std::unordered_map op_position_map; + std::unordered_map reverse_op_position_map; + + int op_position = 0; + for (auto& block : main_func.getBody()) { + for (auto& op : block) { + op_position_map[reinterpret_cast(&op)] = op_position; + reverse_op_position_map[op_position] = reinterpret_cast(&op); + //if(!isa(op)) { + //op_position += 1; + //} + op_position += 1; + } + } + + memory_usage_tracker_.SetAllOperationPositionInfo(op_position_map, reverse_op_position_map); + // iterate 
over op_position_map + op_position = 0; + for (auto& block : main_func.getBody()) { + for (auto& op : block) { + if(isa(op)) { + memory_usage_tracker_.ProcessAlloc(cast(op)); + if(!IsDynmaicShapeGraph() && memory_usage_tracker_.GetCurrentPeakMemoryUsage() > GetPeakMemoryLimit()) { + memory_usage_tracker_.RematerializeToTargetMemoryUsage(op_position, GetPeakMemoryLimit()); + } + } else if(isa(op)) { + memory_usage_tracker_.ProcessDealloc(cast(op)); + } else if(isa(op)) { + memory_usage_tracker_.ProcessCustomCallV2(cast(op)); + if(!IsDynmaicShapeGraph() && memory_usage_tracker_.GetCurrentPeakMemoryUsage() > GetPeakMemoryLimit()) { + memory_usage_tracker_.RematerializeToTargetMemoryUsage(op_position, GetPeakMemoryLimit()); + } + } + op_position += 1; + llvm::dbgs() << "Current Peak Memory Usage: " << memory_usage_tracker_.GetCurrentPeakMemoryUsage() / 1024.0 / 1024.0 / 1024.0 << "\n"; + } + } + + // Dynamic Shape Graph Processing + if(IsDynmaicShapeGraph()) { + memory_usage_tracker_.RematerializeToLowestMemoryUsage(); + } + return; + } +}; +} // namespace + +std::unique_ptr> createDiscRematerializationPass() { + return std::make_unique(); +} + +} // namespace mhlo_disc +} // namespace mlir diff --git a/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td b/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td index da3859ef866..86508599839 100755 --- a/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td +++ b/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td @@ -39,3 +39,8 @@ def DiscOpSchedulePass : Pass<"disc-op-schedule", "ModuleOp"> { let summary = "Schedule ops in a function"; let constructor = "createDiscOpSchedulePass()"; } + +def DiscRematerializationPass : Pass<"disc-rematerialization", "ModuleOp"> { + let summary = "Remat to reduce peak memory"; + let constructor = "createDiscRematerializationPass()"; +} diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h old mode 100644 new mode 100755 index 
b09b0bef411..3e35be98d59 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -354,6 +354,8 @@ createDiscOptimizationBarrierExpandPass(); std::unique_ptr<OperationPass<ModuleOp>> createDiscOpSchedulePass(); +std::unique_ptr<OperationPass<ModuleOp>> createDiscRematerializationPass(); + std::unique_ptr<OperationPass<ModuleOp>> createDiscArgsMutationExpandPass(); } // namespace mhlo_disc