diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index 2ed6ac2b133..bccadfd8565 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -941,10 +941,7 @@ cc_library( "transforms/rewriters.h", ], deps = [ - ":mhlo_disc", ":lmhlo_disc", - ":disc_ral", - ":disc_map_hlo_to_lhlo_op", ":pass_details", ":placement_utils", ":shape_utils", @@ -2121,6 +2118,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "disc_reduce_buffer_live_range", + srcs = ["transforms/disc_reduce_buffer_live_range.cc"], + deps = [ + ":lmhlo_disc", + ":disc_util", + ":pass_details", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:BufferizationTransforms", + ], + alwayslink = 1, +) + cc_library( name = "disc_bf16_expansion", srcs = ["transforms/disc_bf16_expansion.cc"], @@ -2337,6 +2353,7 @@ cc_library( ":disc_assign_memory_space", ":disc_bf16_expansion", ":disc_buffer_deallocation", + ":disc_reduce_buffer_live_range", ":disc_canonicalizer", ":disc_comp_intens_fusion_to_cuda_source", ":disc_comp_intens_fusion_to_func", diff --git a/tao_compiler/mlir/disc/IR/lhlo_disc_ops.td b/tao_compiler/mlir/disc/IR/lhlo_disc_ops.td old mode 100644 new mode 100755 index 48f8afe5ce0..3cb14895865 --- a/tao_compiler/mlir/disc/IR/lhlo_disc_ops.td +++ b/tao_compiler/mlir/disc/IR/lhlo_disc_ops.td @@ -295,5 +295,30 @@ def LHLO_ArgsMutationOp : LHLODISC_Op<"args_mutation", []> { ); } +def LHLODISC_OptimizationBarrierOp : LHLODISC_Op<"optimization_barrier", []> { + let summary = "OptimizationBarrier operation"; + let description = [{ + Ensures that the operations that produce the `operand` are executed before any + operations that depend on the `result` and prevents compiler transformations + from moving operations across the barrier. 
Other than that, the operation is + an identity, i.e. `result` = `operand`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier + + Example: + ```mlir + %result0, %result1 = mhlo.optimization_barrier %operand0, %operand1 : tensor, tensor + ``` + }]; + + let arguments = (ins + Arg, "", [MemRead]>:$args + ); + + let results = (outs Variadic); + +} + #endif // LMHLO_DISC_OPS diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc old mode 100755 new mode 100644 index 0ae4c0eb461..9e681c8b738 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -528,6 +528,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(disc_ral::createDiscReduceBufferLiveRangePass()); pm.addNestedPass(bufferization::createBufferDeallocationPass()); pm.addNestedPass(disc_ral::createDiscBufferDeallocationPass()); diff --git a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc index 90ed18fc3a1..21ad976ae22 100755 --- a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc +++ b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc @@ -185,9 +185,6 @@ struct HloToLhloOptimizationBarrierOpConverter LogicalResult matchAndRewrite( mhlo::OptimizationBarrierOp hloOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - - llvm::dbgs() << "Converting mhlo::OptimizationBarrierOp \n"; - Operation* op = hloOp.getOperation(); auto operands = adaptor.getOperands(); @@ -197,9 +194,9 @@ struct HloToLhloOptimizationBarrierOpConverter resultTypes.push_back( MemRefType::get(ty.getShape(), ty.getElementType())); } - - llvm::dbgs() << "Replace Op With lmhlo_disc::OptimizationBarrierOp\n"; - 
rewriter.replaceOpWithNewOp(hloOp, resultTypes, operands, op->getAttrs()); + + rewriter.replaceOpWithNewOp( + hloOp, resultTypes, operands, op->getAttrs()); return success(); } @@ -247,6 +244,7 @@ struct HloToLhloCustomCallOpV2Converter resultTypes.push_back( MemRefType::get(ty.getShape(), ty.getElementType())); } + rewriter.replaceOpWithNewOp( hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs()); diff --git a/tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc b/tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc index 70c78da7ec3..eb986bc3cb2 100755 --- a/tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc +++ b/tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc @@ -37,12 +37,9 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "mlir/disc/IR/disc_ral_ops.h" #include "mlir/disc/IR/disc_shape_ops.h" -#include "mlir/disc/IR/hlo_disc_ops.h" #include "mlir/disc/IR/lhlo_disc_ops.h" #include "mlir/disc/transforms/PassDetail.h" -#include "mlir/disc/transforms/disc_map_hlo_to_lhlo_op.h" #include "mlir/disc/transforms/placement_utils.h" #include "mlir/disc/transforms/rewriters.h" #include "mlir/disc/transforms/shape_utils.h" @@ -57,19 +54,18 @@ namespace { template using BaseOpConversion = OpConversionPattern; -struct LhloDISCOptimizationBarrierOpConverter : public OpRewritePattern { +struct LhloDISCOptimizationBarrierOpConverter + : public OpRewritePattern { explicit LhloDISCOptimizationBarrierOpConverter(MLIRContext* context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(lmhlo_disc::OptimizationBarrierOp lhloOp, PatternRewriter& rewriter) const override { - - llvm::dbgs() << "Expand OptimizationBarrierPass \n"; Operation* op = lhloOp.getOperation(); auto operands = op->getOperands(); auto results = op->getResults(); - for(int i=0; i { + : public 
DiscOptimizationBarrierExpandPassBase< + DiscOptimizationBarrierExpandPass> { using DiscOptimizationBarrierExpandPassBase< DiscOptimizationBarrierExpandPass>::DiscOptimizationBarrierExpandPassBase; void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry.insert(); } public: @@ -108,7 +104,8 @@ struct DiscOptimizationBarrierExpandPass }; } // namespace -std::unique_ptr> createDiscOptimizationBarrierExpandPass() { +std::unique_ptr> +createDiscOptimizationBarrierExpandPass() { return std::make_unique(); } diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td index d96ae4e63e6..7b97ed610df 100755 --- a/tao_compiler/mlir/disc/transforms/disc_passes.td +++ b/tao_compiler/mlir/disc/transforms/disc_passes.td @@ -662,4 +662,9 @@ def DiscEraseBufferDeallocationPass : Pass<"disc-erase-buffer-deallocation", "ml def DiscInputOutputAliasPass : Pass<"disc-input-output-alias", "ModuleOp"> { let summary = "Input and output alias information for buffer reuse"; let constructor = "createDiscInputOutputAliasPass()"; +} + +def DiscReduceBufferLiveRangePass : Pass<"disc-reduce-buffer-live-range", "mlir::func::FuncOp"> { + let summary = "reduce buffer live range"; + let constructor = "createDiscReduceBufferLiveRangePass()"; } \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc b/tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc new file mode 100644 index 00000000000..c153ca25a76 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc @@ -0,0 +1,90 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "lhlo/IR/lhlo_ops.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/disc/disc_util.h"
+#include "mlir/disc/transforms/PassDetail.h"
+
+namespace mlir {
+namespace disc_ral {
+
+using lmhlo::FusionOp;
+using memref::AllocOp;
+
+namespace {
+
+// Sinks `allocOp` to right before its first use inside the placement block
+// returned by findCommonDominator, so the buffer is allocated later.
+// Best effort: the alloc stays put when no safe anchor op is found.
+LogicalResult moveBufferAllocator(AllocOp allocOp) {
+  Value alloc = allocOp.getResult();
+  BufferViewFlowAnalysis aliasAnalysis(allocOp);
+  PostDominanceInfo postDominators(allocOp);
+  auto aliasesSet = aliasAnalysis.resolve(alloc);
+  Block* placementBlock =
+      bufferization::BufferPlacementTransformationBase::findCommonDominator(
+          alloc, aliasesSet, postDominators);
+  if (placementBlock == nullptr) return success();
+
+  Operation* toMoveBefore = nullptr;
+  for (Operation* user : alloc.getUsers()) {
+    // NOTE(review): isa<> argument lost in extraction; DeallocOp assumed.
+    if (isa<memref::DeallocOp>(user)) continue;
+    // A user may sit in a region nested under placementBlock, so anchor on
+    // its ancestor op inside the block; without one no safe spot is known.
+    Operation* anchor = placementBlock->findAncestorOpInBlock(*user);
+    if (anchor == nullptr) return success();
+    if (toMoveBefore == nullptr || anchor->isBeforeInBlock(toMoveBefore)) {
+      toMoveBefore = anchor;
+    }
+  }
+  // No anchor (e.g. an unused alloc): moveBefore(nullptr) must not be called.
+  if (toMoveBefore != nullptr) allocOp->moveBefore(toMoveBefore);
+  return success();
+}
+
+// Function pass: snapshots every memref.alloc, then sinks each one in turn.
+struct DiscReduceBufferLiveRangePass
+    : public DiscReduceBufferLiveRangePassBase<DiscReduceBufferLiveRangePass> {
+  void runOnOperation() override {
+    SmallVector<AllocOp> candidateBuffers;
+    getOperation().walk([&](AllocOp op) { candidateBuffers.push_back(op); });
+    for (AllocOp buffer : candidateBuffers) {
+      if (failed(moveBufferAllocator(buffer))) return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+// Creates the pass registered as "disc-reduce-buffer-live-range".
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDiscReduceBufferLiveRangePass() {
+  return std::make_unique<DiscReduceBufferLiveRangePass>();
+}
+
+}  // namespace disc_ral
+}  // namespace mlir
diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h old mode 100755 new mode 100644 index 7f1e8c71a40..e533d442ac2 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -328,6 +328,7 @@ createDiscEraseBufferDeallocationPass(); // Insert ArgsMutationOp for buffer reuse std::unique_ptr> createDiscInputOutputAliasPass(); +std::unique_ptr> createDiscReduceBufferLiveRangePass(); } // namespace disc_ral } // namespace mlir @@ -339,7 +340,8 @@ std::unique_ptr> createDiscLegalizeToLhloPass(); std::unique_ptr> createDiscLhloRewriterPass(); -std::unique_ptr> createDiscOptimizationBarrierExpandPass(); +std::unique_ptr> +createDiscOptimizationBarrierExpandPass(); } // namespace mhlo_disc } // namespace mlir diff --git
a/tao_compiler/mlir/disc/transforms/tests/disc-optimization-barrier-expand.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-optimization-barrier-expand.mlir index d059472f4c0..60447bab9ec 100644 --- a/tao_compiler/mlir/disc/transforms/tests/disc-optimization-barrier-expand.mlir +++ b/tao_compiler/mlir/disc/transforms/tests/disc-optimization-barrier-expand.mlir @@ -4,10 +4,19 @@ // CHECK-LABEL: @optimization_barrier_expand func.func @optimization_barrier_expand(%arg0 : tensor<1x2048x4096xf32>, %arg1: tensor<1x2048x4096xf32>) -> tensor<2048x4096xf16> { + // CHECK: %alloc = memref.alloc() : memref<1x2048x4096xf32> + // CHECK: "lmhlo.add"(%arg0, %arg1, %alloc) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf32>, memref<1x2048x4096xf32>) -> () %1 = "mhlo.add"(%arg0, %arg1): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf32> + // CHECK: %alloc_0 = memref.alloc() : memref<1x2048x4096xf32> + // CHECK: "lmhlo.add"(%arg0, %arg1, %alloc_0) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf32>, memref<1x2048x4096xf32>) -> () %2 = "mhlo.add"(%arg0, %arg1): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf32> + // CHECK: %alloc_1 = memref.alloc() : memref<1x2048x4096xf16> + // CHECK: "lmhlo.convert"(%alloc_0, %alloc_1) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf16>) -> () %3:2 = "mhlo.optimization_barrier"(%1, %2): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) %4 = "mhlo.convert"(%3#1): (tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf16> + // CHECK: %alloc_2 = memref.alloc() : memref<2048x4096xf16> + // CHECK: "lmhlo.reshape"(%alloc_1, %alloc_2) : (memref<1x2048x4096xf16>, memref<2048x4096xf16>) -> () %5 = "mhlo.reshape"(%4) : (tensor<1x2048x4096xf16>) -> tensor<2048x4096xf16> + // CHECK: return %alloc_2 : memref<2048x4096xf16> return %5: tensor<2048x4096xf16> } \ No newline at end of file