diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD
index bccadfd8565..98101783738 100755
--- a/tao_compiler/mlir/disc/BUILD
+++ b/tao_compiler/mlir/disc/BUILD
@@ -2336,6 +2336,34 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "disc_argsmutation_expand",
+    srcs = ["transforms/disc_argmutation_expand.cc"],
+    hdrs = [
+        "transforms/passes.h",
+        "transforms/rewriters.h",
+    ],
+    deps = [
+        ":lmhlo_disc",
+        ":pass_details",
+        ":placement_utils",
+        ":shape_utils",
+        "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:ShapeDialect",
+        "@llvm-project//mlir:ShapeTransforms",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:SCFDialect",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "all_passes",
     hdrs = [
@@ -2349,6 +2377,7 @@ cc_library(
         ":disc_dot_merge",
         ":disc_quantized_dot_merge",
         ":disc_algebraic_simplifier",
+        ":disc_argsmutation_expand",
         ":disc_assign_kernel_name",
         ":disc_assign_memory_space",
         ":disc_bf16_expansion",
diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc
index 9e681c8b738..df3c7e1c1c5 100644
--- a/tao_compiler/mlir/disc/disc_compiler.cc
+++ b/tao_compiler/mlir/disc/disc_compiler.cc
@@ -623,6 +623,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
 
   pm.addNestedPass(disc_ral::createLhloFusionInlinerPass());
 
+  pm.addPass(mhlo_disc::createDiscArgsMutationExpandPass());
+
   if (gpu_enabled) {
     // Lower dot fusion to CUDA.
pm.addPass(disc_ral::createDiscCompIntensFusionToCUDASourcePass( diff --git a/tao_compiler/mlir/disc/transforms/disc_argmutation_expand.cc b/tao_compiler/mlir/disc/transforms/disc_argmutation_expand.cc new file mode 100644 index 00000000000..154d77d72ef --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_argmutation_expand.cc @@ -0,0 +1,109 @@ +// Copyright 2021 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements logic for lowering HLO DISC dialect to LHLO DISC +// dialect. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "lhlo/IR/lhlo_ops.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/disc/IR/disc_shape_ops.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +#include "mlir/disc/transforms/PassDetail.h" +#include "mlir/disc/transforms/placement_utils.h" +#include "mlir/disc/transforms/rewriters.h" +#include "mlir/disc/transforms/shape_utils.h" + +namespace mlir { +using placement_utils::kDiscPlaceAssignment; +using placement_utils::kGpu; + +namespace mhlo_disc { +namespace { + +template +using BaseOpConversion = OpConversionPattern; + +struct LhloDISCArgsMutationOpConverter + : public OpRewritePattern { + explicit LhloDISCArgsMutationOpConverter(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(lmhlo_disc::ArgsMutationOp lhloOp, + PatternRewriter& rewriter) const override { + auto op = lhloOp.getOperation(); + auto operands = op->getOperands(); + // Value value = backtraceOperand(operands[0]); + operands[0].replaceAllUsesWith(operands[1]); + rewriter.eraseOp(op); + return success(); + } +}; + +struct DiscArgsMutationExpandPass + 
: public DiscArgsMutationExpandPassBase { + using DiscArgsMutationExpandPassBase::DiscArgsMutationExpandPassBase; + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + public: + DiscArgsMutationExpandPass() = default; + + void runOnOperation() override { + auto& context = getContext(); + RewritePatternSet patterns(&context); + ConversionTarget target(context); + target.addLegalDialect(); + target.addIllegalOp(); + patterns.insert(&context); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +createDiscArgsMutationExpandPass() { + return std::make_unique(); +} +} // namespace mhlo_disc +} // namespace mlir \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc b/tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc old mode 100755 new mode 100644 index 39d8dcadea2..5b3c36ec7cc --- a/tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc +++ b/tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc @@ -143,15 +143,14 @@ struct DiscInputOutputAliasPass } // DISC now only support one-hop buffer sharing. 
auto defineOp = outputs[outputs_index[i]].getDefiningOp(); - for (const auto& value : defineOp->getOperands()) { - if (params[params_index[i]] == value) { - builder.setInsertionPointAfterValue(outputs[outputs_index[i]]); - builder.create(main_func.getLoc(), - outputs[outputs_index[i]], - params[params_index[i]]); - break; - } + if (llvm::isa(defineOp)) { + continue; } + + builder.setInsertionPointAfter(defineOp); + builder.create(main_func.getLoc(), + outputs[outputs_index[i]], + params[params_index[i]]); } } }; diff --git a/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td b/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td index c0e1c254bbd..8856db9fbf7 100755 --- a/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td +++ b/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td @@ -29,3 +29,8 @@ def DiscOptimizationBarrierExpandPass : Pass<"disc-optimization-barrier-expand", let summary = "Expand OptimizationBarrierOp"; let constructor = "createDiscOptimizationBarrierExpandPass()"; } + +def DiscArgsMutationExpandPass : Pass<"disc-argsmutation-expand", "ModuleOp"> { + let summary = "Expand ArgsMutationOp"; + let constructor = "createDiscArgsMutationExpandPass()"; +} diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h index e533d442ac2..003648d149e 100644 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -343,6 +343,8 @@ std::unique_ptr> createDiscLhloRewriterPass(); std::unique_ptr> createDiscOptimizationBarrierExpandPass(); +std::unique_ptr> createDiscArgsMutationExpandPass(); + } // namespace mhlo_disc } // namespace mlir