diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index 8fdf954aae7..8c77f2bfb20 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -2357,6 +2357,34 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "disc_argsmutation_expand", + srcs = ["transforms/disc_argsmutation_expand.cc"], + hdrs = [ + "transforms/passes.h", + "transforms/rewriters.h", + ], + deps = [ + ":lmhlo_disc", + ":pass_details", + ":placement_utils", + ":shape_utils", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:ShapeTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:SCFDialect", + ], + alwayslink = 1, +) + cc_library( name = "all_passes", hdrs = [ @@ -2370,6 +2398,7 @@ cc_library( ":disc_dot_merge", ":disc_quantized_dot_merge", ":disc_algebraic_simplifier", + ":disc_argsmutation_expand", ":disc_assign_kernel_name", ":disc_assign_memory_space", ":disc_bf16_expansion", diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index cc2b535adfc..b55b07366fa 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -624,6 +624,9 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass(disc_ral::createLhloFusionInlinerPass()); + // Expand ArgsMutationOp to redirect memory writing target + pm.addPass(mhlo_disc::createDiscArgsMutationExpandPass()); + if (gpu_enabled) { // Lower dot fusion to CUDA. 
pm.addPass(disc_ral::createDiscCompIntensFusionToCUDASourcePass( diff --git a/tao_compiler/mlir/disc/transforms/disc_argsmutation_expand.cc b/tao_compiler/mlir/disc/transforms/disc_argsmutation_expand.cc new file mode 100644 index 00000000000..ee0b8532327 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_argsmutation_expand.cc @@ -0,0 +1,108 @@ +// Copyright 2021 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// This file implements the pass that expands lmhlo_disc::ArgsMutationOp by +// replacing uses of its first operand with its second and erasing the op. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "lhlo/IR/lhlo_ops.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/disc/IR/disc_shape_ops.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +#include "mlir/disc/transforms/PassDetail.h" +#include "mlir/disc/transforms/placement_utils.h" +#include "mlir/disc/transforms/rewriters.h" +#include "mlir/disc/transforms/shape_utils.h" + +namespace mlir { +using placement_utils::kDiscPlaceAssignment; +using placement_utils::kGpu; + +namespace mhlo_disc { +namespace { + +template +using BaseOpConversion = OpConversionPattern; + +struct LhloDISCArgsMutationOpConverter + : public OpRewritePattern { + explicit LhloDISCArgsMutationOpConverter(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(lmhlo_disc::ArgsMutationOp lhloOp, + PatternRewriter& rewriter) const override { + auto op = lhloOp.getOperation(); + auto operands = op->getOperands(); + operands[0].replaceAllUsesWith(operands[1]); + rewriter.eraseOp(op); + return success(); + } +}; + +struct DiscArgsMutationExpandPass + : public DiscArgsMutationExpandPassBase { + using 
DiscArgsMutationExpandPassBase< + DiscArgsMutationExpandPass>::DiscArgsMutationExpandPassBase; + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + public: + DiscArgsMutationExpandPass() = default; + + void runOnOperation() override { + auto& context = getContext(); + RewritePatternSet patterns(&context); + ConversionTarget target(context); + target.addLegalDialect(); + target.addIllegalOp(); + patterns.insert(&context); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> createDiscArgsMutationExpandPass() { + return std::make_unique(); +} +} // namespace mhlo_disc +} // namespace mlir \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc b/tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc old mode 100755 new mode 100644 index 39d8dcadea2..5b3c36ec7cc --- a/tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc +++ b/tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc @@ -143,15 +143,14 @@ struct DiscInputOutputAliasPass } // DISC now only support one-hop buffer sharing. 
auto defineOp = outputs[outputs_index[i]].getDefiningOp(); - for (const auto& value : defineOp->getOperands()) { - if (params[params_index[i]] == value) { - builder.setInsertionPointAfterValue(outputs[outputs_index[i]]); - builder.create(main_func.getLoc(), - outputs[outputs_index[i]], - params[params_index[i]]); - break; - } + if (llvm::isa(defineOp)) { + continue; } + + builder.setInsertionPointAfter(defineOp); + builder.create(main_func.getLoc(), + outputs[outputs_index[i]], + params[params_index[i]]); } } }; diff --git a/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc b/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc old mode 100644 new mode 100755 index bdc298bf945..5a08f138708 --- a/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc +++ b/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc @@ -66,21 +66,6 @@ Value backtraceOperand(Value operand) { return operand; } -struct LhloArgsMutationOpRewriter - : public OpRewritePattern { - explicit LhloArgsMutationOpRewriter(MLIRContext* context) - : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(lmhlo_disc::ArgsMutationOp lhloOp, - PatternRewriter& rewriter) const override { - auto op = lhloOp.getOperation(); - auto operands = op->getOperands(); - Value value = backtraceOperand(operands[0]); - value.replaceAllUsesWith(operands[1]); - rewriter.eraseOp(op); - return success(); - } -}; - struct LhloConcatenateOpConverter : public OpRewritePattern { explicit LhloConcatenateOpConverter(MLIRContext* context) @@ -195,7 +180,6 @@ struct DiscLhloRewriterPass patterns.insert(&context); patterns.insert(&context); - patterns.insert(&context); if (failed( applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); diff --git a/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td b/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td index c0e1c254bbd..8856db9fbf7 100755 --- a/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td +++ 
b/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td @@ -29,3 +29,8 @@ def DiscOptimizationBarrierExpandPass : Pass<"disc-optimization-barrier-expand", let summary = "Expand OptimizationBarrierOp"; let constructor = "createDiscOptimizationBarrierExpandPass()"; } + +def DiscArgsMutationExpandPass : Pass<"disc-argsmutation-expand", "ModuleOp"> { + let summary = "Expand ArgsMutationOp"; + let constructor = "createDiscArgsMutationExpandPass()"; +} diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h index 95ad522a62e..c2decbe13a0 100644 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -349,6 +349,8 @@ std::unique_ptr> createDiscLhloRewriterPass(); std::unique_ptr> createDiscOptimizationBarrierExpandPass(); +std::unique_ptr> createDiscArgsMutationExpandPass(); + } // namespace mhlo_disc } // namespace mlir diff --git a/tao_compiler/mlir/disc/transforms/tests/input-mutation.mlir b/tao_compiler/mlir/disc/transforms/tests/input-mutation.mlir old mode 100644 new mode 100755 index b7266f1c563..92cbc9cecfd --- a/tao_compiler/mlir/disc/transforms/tests/input-mutation.mlir +++ b/tao_compiler/mlir/disc/transforms/tests/input-mutation.mlir @@ -1,4 +1,4 @@ -// RUN: disc-opt %s -disc-hlo-legalize-to-lhlo -hlo-legalize-to-lhlo -canonicalize -disc-lhlo-rewriter -split-input-file | FileCheck %s +// RUN: disc-opt %s -disc-hlo-legalize-to-lhlo -hlo-legalize-to-lhlo -canonicalize -disc-lhlo-rewriter -disc-argsmutation-expand -split-input-file | FileCheck %s func.func @input_mutation(%arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>) -> tensor<8x32xf32> { // CHECK: "lmhlo.add"(%arg0, %arg1, %arg0) : (memref<8x32xf32>, memref<8x32xf32>, memref<8x32xf32>) -> ()