From a64cec96c59d04597827f3b14310e6d07b56fa4e Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 1 Mar 2024 16:08:53 +0800 Subject: [PATCH 1/2] add reduce-buffer-live-range pass to reduce memory peak (#1279) move alloc op to the nearest location to the first consumer op. --- tao_compiler/mlir/disc/BUILD | 20 +++++ tao_compiler/mlir/disc/disc_compiler.cc | 1 + .../mlir/disc/transforms/disc_passes.td | 5 ++ .../disc_reduce_buffer_live_range.cc | 90 +++++++++++++++++++ tao_compiler/mlir/disc/transforms/passes.h | 1 + 5 files changed, 117 insertions(+) mode change 100755 => 100644 tao_compiler/mlir/disc/disc_compiler.cc create mode 100644 tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index 794feeaa2de..622d2d5a6e3 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -2090,6 +2090,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "disc_reduce_buffer_live_range", + srcs = ["transforms/disc_reduce_buffer_live_range.cc"], + deps = [ + ":lmhlo_disc", + ":disc_util", + ":pass_details", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:BufferizationTransforms", + ], + alwayslink = 1, +) + cc_library( name = "disc_bf16_expansion", srcs = ["transforms/disc_bf16_expansion.cc"], @@ -2306,6 +2325,7 @@ cc_library( ":disc_assign_memory_space", ":disc_bf16_expansion", ":disc_buffer_deallocation", + ":disc_reduce_buffer_live_range", ":disc_canonicalizer", ":disc_comp_intens_fusion_to_cuda_source", ":disc_comp_intens_fusion_to_func", diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc old mode 100755 new mode 100644 index 57a2fd8cd3f..582f6846d1e --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ 
b/tao_compiler/mlir/disc/disc_compiler.cc @@ -526,6 +526,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(disc_ral::createDiscReduceBufferLiveRangePass()); pm.addNestedPass(bufferization::createBufferDeallocationPass()); pm.addNestedPass(disc_ral::createDiscBufferDeallocationPass()); diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td index d96ae4e63e6..7b97ed610df 100755 --- a/tao_compiler/mlir/disc/transforms/disc_passes.td +++ b/tao_compiler/mlir/disc/transforms/disc_passes.td @@ -662,4 +662,9 @@ def DiscEraseBufferDeallocationPass : Pass<"disc-erase-buffer-deallocation", "ml def DiscInputOutputAliasPass : Pass<"disc-input-output-alias", "ModuleOp"> { let summary = "Input and output alias information for buffer reuse"; let constructor = "createDiscInputOutputAliasPass()"; +} + +def DiscReduceBufferLiveRangePass : Pass<"disc-reduce-buffer-live-range", "mlir::func::FuncOp"> { + let summary = "reduce buffer live range"; + let constructor = "createDiscReduceBufferLiveRangePass()"; } \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc b/tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc new file mode 100644 index 00000000000..c153ca25a76 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc @@ -0,0 +1,90 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "lhlo/IR/lhlo_ops.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/disc/disc_util.h" +#include "mlir/disc/transforms/PassDetail.h" + +namespace mlir { +namespace disc_ral { + +using lmhlo::FusionOp; +using memref::AllocOp; + +namespace { + +LogicalResult moveBufferAllocator(AllocOp allocOp) { + Value alloc = allocOp.getResult(); + BufferViewFlowAnalysis aliasAnalysis(allocOp); + PostDominanceInfo postDominators(allocOp); + auto aliasesSet = aliasAnalysis.resolve(alloc); + // Determine the actual block to place the alloc and get liveness + // information. 
+ Block* placementBlock = + bufferization::BufferPlacementTransformationBase::findCommonDominator( + alloc, aliasesSet, postDominators); + + Operation* toMoveBefore = nullptr; + for (auto user : alloc.getUsers()) { + if (isa(user)) continue; + // user maybe in the sub-block of the placementBlock, + // find the closest parent op inside of placementBlock + while (user->getBlock() != placementBlock) { + user = user->getParentOp(); + } + if (toMoveBefore == nullptr || user->isBeforeInBlock(toMoveBefore)) { + toMoveBefore = user; + } + } + allocOp->moveBefore(toMoveBefore); + return success(); +} + +struct DiscReduceBufferLiveRangePass + : public DiscReduceBufferLiveRangePassBase { + void runOnOperation() override { + SmallVector candidateBuffers; + func::FuncOp func = getOperation(); + + func.walk([&](AllocOp op) { candidateBuffers.push_back(op); }); + + for (int i = 0; i < candidateBuffers.size(); ++i) { + if (failed(moveBufferAllocator(candidateBuffers[i]))) { + return signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr> +createDiscReduceBufferLiveRangePass() { + return std::make_unique(); +} + +} // namespace disc_ral +} // namespace mlir diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h index c645856b9ec..05def8a81b8 100755 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -328,6 +328,7 @@ createDiscEraseBufferDeallocationPass(); // Insert ArgsMutationOp for buffer reuse std::unique_ptr> createDiscInputOutputAliasPass(); +std::unique_ptr> createDiscReduceBufferLiveRangePass(); } // namespace disc_ral } // namespace mlir From 30cf3d3b30756fb12191890b75fff11478c1ad5f Mon Sep 17 00:00:00 2001 From: Dalong Date: Mon, 4 Mar 2024 13:33:27 +0800 Subject: [PATCH 2/2] Support mhlo.custom_call op processing (#1283) --- .../transforms/disc_hlo_legalize_to_lhlo.cc | 125 ++++++++++++++++-- .../tests/disc-hlo-legalize-to-lhlo.mlir | 11 ++ 2 files changed, 
128 insertions(+), 8 deletions(-) diff --git a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc index 55138af799d..a01ad7e9025 100644 --- a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc +++ b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc @@ -21,6 +21,7 @@ limitations under the License. #include "lhlo/IR/lhlo_ops.h" #include "llvm/Support/Debug.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" @@ -177,12 +178,13 @@ struct HloToLhloArgsMutationOpConverter } }; -struct HloToLhloCustomCallOpConverter : public BaseOpConversion { +struct HloToLhloCustomCallOpConverter + : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - CustomCallOp hloOp, OpAdaptor adaptor, + mhlo_disc::CustomCallOp hloOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { Operation* op = hloOp.getOperation(); auto operands = adaptor.getOperands(); @@ -204,12 +206,12 @@ struct HloToLhloCustomCallOpConverter : public BaseOpConversion { }; struct HloToLhloCustomCallOpV2Converter - : public BaseOpConversion { + : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - CustomCallV2Op hloOp, OpAdaptor adaptor, + mhlo_disc::CustomCallV2Op hloOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { Location loc = hloOp->getLoc(); SmallVector resultTypes; @@ -218,8 +220,111 @@ struct HloToLhloCustomCallOpV2Converter resultTypes.push_back( MemRefType::get(ty.getShape(), ty.getElementType())); } + rewriter.replaceOpWithNewOp( hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs()); + + return success(); + } +}; + 
+struct CustomCallOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::CustomCallOp hloOp, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Operation* op = hloOp.getOperation(); + auto operands = adaptor.getOperands(); + + SmallVector resultTypes; + + std::string input_placements, output_placements; + std::string input_layouts, output_layouts; + for (int i = 0; i < operands.size(); i++) { + input_placements += "d,"; + input_layouts += "*,"; + } + + hloOp->setAttr("call_target_name", + rewriter.getStringAttr(hloOp.getCallTargetName())); + hloOp->setAttr("device", rewriter.getStringAttr("x")); + + SmallVector newAttrs; + if (hloOp.getBackendConfig().has_value()) { + newAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("backend_config"), + hloOp.getBackendConfig().value())); + } + auto newCustomAttrs = DictionaryAttr::get(hloOp->getContext(), newAttrs); + hloOp->setAttr("custom_attrs", newCustomAttrs); + + if (hloOp->getNumResults() == 1 && + hloOp->getResult(0).getType().dyn_cast()) { + auto tupleTy = hloOp->getResult(0).getType().dyn_cast(); + for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) { + output_placements += "d,"; + output_layouts += "*,"; + auto tensor_type = ty.cast(); + if (!tensor_type) { + op->emitOpError() << "Unsupported result type in disc for "; + } + resultTypes.push_back(tensor_type); + } + } else { + output_placements = "d,"; + output_layouts += "*,"; + for (Value v : hloOp->getResults()) { + auto ty = v.getType().cast(); + if (!ty) { + op->emitOpError() << "Unsupported result type in disc for "; + } + resultTypes.push_back(ty); + } + } + + if (!input_placements.empty()) { + input_placements.pop_back(); + input_layouts.pop_back(); + } + if (!output_placements.empty()) { + output_placements.pop_back(); + output_layouts.pop_back(); + } + + hloOp->setAttr("input_placements", + rewriter.getStringAttr(input_placements)); + 
hloOp->setAttr("output_placements", + rewriter.getStringAttr(output_placements)); + hloOp->setAttr("input_layouts", rewriter.getStringAttr(input_layouts)); + hloOp->setAttr("output_layouts", rewriter.getStringAttr(output_layouts)); + hloOp->setAttr("expected_input_layouts", + rewriter.getStringAttr(input_layouts)); + hloOp->setAttr("expected_output_layouts", + rewriter.getStringAttr(output_layouts)); + + auto custom_v2_op = rewriter.create( + hloOp.getLoc(), resultTypes, operands, hloOp->getAttrs()); + + if (hloOp->getNumResults() == 1 && + hloOp->getResult(0).getType().dyn_cast()) { + auto tupleValue = hloOp->getResult(0); + for (int index = 0; index < resultTypes.size(); index++) { + for (auto& use : tupleValue.getUses()) { + Operation* consumerOp = use.getOwner(); + if (auto getTupleElementOp = + llvm::dyn_cast(consumerOp)) { + if (getTupleElementOp.getIndex() == index) { + rewriter.replaceOp(consumerOp, {custom_v2_op.getResult(index)}); + break; + } + } + } + } + } + + rewriter.replaceOp(hloOp, custom_v2_op.getResults()); return success(); } }; @@ -257,7 +362,8 @@ struct DiscHloLegalizeToLhlo void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); + lmhlo::LmhloDialect, mhlo::MhloDialect, + mhlo_disc::MhloDiscDialect>(); } public: @@ -270,10 +376,12 @@ struct DiscHloLegalizeToLhlo target.addLegalDialect(); + tensor::TensorDialect, lmhlo::LmhloDialect, + mhlo::MhloDialect>(); target.addIllegalDialect(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); bufferization::BufferizeTypeConverter converter; populateDiscHLOToLHLOConversionPattern(&context, &converter, &patterns); @@ -290,6 +398,7 @@ void populateDiscHLOToLHLOConversionPattern( RewritePatternSet* patterns) { // clang-format off patterns->insert< + CustomCallOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-hlo-legalize-to-lhlo.mlir 
b/tao_compiler/mlir/disc/transforms/tests/disc-hlo-legalize-to-lhlo.mlir index bcec830b8fa..56f2731008d 100644 --- a/tao_compiler/mlir/disc/transforms/tests/disc-hlo-legalize-to-lhlo.mlir +++ b/tao_compiler/mlir/disc/transforms/tests/disc-hlo-legalize-to-lhlo.mlir @@ -207,4 +207,15 @@ func.func @custom_call_v2_op( output_layouts = "AB" } : (tensor, tensor<2xi32>) -> tensor return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @mhlo_custom_call +func.func @mhlo_custom_call(%arg0: tensor<2048x32x128xf16>, %arg1: tensor<2048x32x128xf16>) -> tensor<2048x32x128xf16> { + // CHECK: %0:2 = "lmhlo_disc.custom_call_v2"(%arg0, %arg1) {backend_config = "test_config", call_target_name = "custom_fn", custom_attrs = {backend_config = "test_config"}, device = "x", expected_input_layouts = "*,*", expected_output_layouts = "*,*", has_side_effect = false, input_layouts = "*,*", input_placements = "d,d", output_layouts = "*,*", output_placements = "d,d"} : (memref<2048x32x128xf16>, memref<2048x32x128xf16>) -> (memref<1x32x2048xf32>, memref<2048x32x128xf16>) + // CHECK: return %0#1 : memref<2048x32x128xf16> + %1 = "mhlo.custom_call"(%arg0, %arg1) {call_target_name="custom_fn", backend_config = "test_config"} : (tensor<2048x32x128xf16>, tensor<2048x32x128xf16>) -> tuple, tensor<2048x32x128xf16>> + %2 = mhlo.get_tuple_element %1[1] : (tuple, tensor<2048x32x128xf16>>) -> tensor<2048x32x128xf16> + return %2: tensor<2048x32x128xf16> } \ No newline at end of file