From 9577b9dc595272c6788bac1d3e551886d91a458c Mon Sep 17 00:00:00 2001 From: eedalong Date: Thu, 29 Feb 2024 13:01:51 +0800 Subject: [PATCH] support mhlo.custom_call --- .pre-commit-config.yaml | 14 --- .../transforms/disc_hlo_legalize_to_lhlo.cc | 110 ++++++++++++++++-- 2 files changed, 102 insertions(+), 22 deletions(-) delete mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 2e69969b709..00000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,14 +0,0 @@ -repos: -- repo: local - hooks: - - id: copyright_checker - name: copyright_checker - entry: python ./scripts/pre-commit/copyright.py - language: system - files: \.(cc|cpp|h|py|sh)$ -- repo: https://github.com/pre-commit/mirrors-clang-format - rev: v10.0.1 - hooks: - - id: clang-format - name: clang-format - diff --git a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc index 55138af799d..1cb4c695d21 100644 --- a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc +++ b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc @@ -21,6 +21,7 @@ limitations under the License. #include "lhlo/IR/lhlo_ops.h" #include "llvm/Support/Debug.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" @@ -177,12 +178,13 @@ struct HloToLhloArgsMutationOpConverter } }; -struct HloToLhloCustomCallOpConverter : public BaseOpConversion { +struct HloToLhloCustomCallOpConverter + : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - CustomCallOp hloOp, OpAdaptor adaptor, + mhlo_disc::CustomCallOp hloOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { Operation* op = hloOp.getOperation(); auto operands = adaptor.getOperands(); @@ -204,12 +206,12 @@ struct HloToLhloCustomCallOpConverter : public BaseOpConversion { }; struct HloToLhloCustomCallOpV2Converter - : public BaseOpConversion { + : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - CustomCallV2Op hloOp, OpAdaptor adaptor, + mhlo_disc::CustomCallV2Op hloOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { Location loc = hloOp->getLoc(); SmallVector resultTypes; @@ -218,8 +220,96 @@ struct HloToLhloCustomCallOpV2Converter resultTypes.push_back( MemRefType::get(ty.getShape(), ty.getElementType())); } + rewriter.replaceOpWithNewOp( hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs()); + + return success(); + } +}; + +struct CustomCallOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::CustomCallOp hloOp, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Operation* op = hloOp.getOperation(); + auto operands = adaptor.getOperands(); + + SmallVector resultTypes; + + std::string input_placements, output_placements; + for (int i = 0; i < operands.size(); i++) { + input_placements += "d,"; + } + + hloOp->setAttr("call_target_name", + rewriter.getStringAttr(hloOp.getCallTargetName())); + hloOp->setAttr("device", rewriter.getStringAttr("x")); + hloOp->setAttr("input_layouts", rewriter.getStringAttr("*")); + hloOp->setAttr("output_layouts", rewriter.getStringAttr("*")); + hloOp->setAttr("expected_input_layouts", rewriter.getStringAttr("*")); + hloOp->setAttr("expected_output_layouts", rewriter.getStringAttr("*")); + + SmallVector newAttrs; + if (hloOp.getBackendConfig().has_value()) { + newAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("backend_config"), + hloOp.getBackendConfig().value())); + } + auto newCustomAttrs = DictionaryAttr::get(hloOp->getContext(), newAttrs); + hloOp->setAttr("custom_attrs", newCustomAttrs); + + if (hloOp->getNumResults() == 1 && + hloOp->getResult(0).getType().dyn_cast()) { + auto tupleTy = hloOp->getResult(0).getType().dyn_cast(); + for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) { + output_placements += "d,"; + resultTypes.push_back(ty.cast()); + } + } else { + output_placements = "d,"; + for (Value v : hloOp->getResults()) { + auto ty = v.getType().cast(); + resultTypes.push_back(ty); + } + } + + if (!input_placements.empty()) { + input_placements.pop_back(); + } + if (!output_placements.empty()) { + output_placements.pop_back(); + } + + hloOp->setAttr("input_placements", + rewriter.getStringAttr(input_placements)); + hloOp->setAttr("output_placements", + rewriter.getStringAttr(output_placements)); + + auto custom_v2_op = rewriter.create( + hloOp.getLoc(), resultTypes, operands, hloOp->getAttrs()); + + if (hloOp->getNumResults() == 1 && + hloOp->getResult(0).getType().dyn_cast()) { + auto tupleValue = hloOp->getResult(0); + for (int index = 0; index < resultTypes.size(); index++) { + for (auto& use : tupleValue.getUses()) { + Operation* consumerOp = use.getOwner(); + if (auto getTupleElementOp = + llvm::dyn_cast(consumerOp)) { + if (getTupleElementOp.getIndex() == index) { + rewriter.replaceOp(consumerOp, {custom_v2_op.getResult(index)}); + break; + } + } + } + } + } + + rewriter.replaceOp(hloOp, custom_v2_op.getResults()); return success(); } }; @@ -257,7 +347,8 @@ struct DiscHloLegalizeToLhlo void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); + lmhlo::LmhloDialect, mhlo::MhloDialect, + mhlo_disc::MhloDiscDialect>(); } public: @@ -270,10 +361,12 @@ struct DiscHloLegalizeToLhlo target.addLegalDialect(); + tensor::TensorDialect, lmhlo::LmhloDialect, + mhlo::MhloDialect>(); target.addIllegalDialect(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); bufferization::BufferizeTypeConverter converter; populateDiscHLOToLHLOConversionPattern(&context, &converter, &patterns); @@ -290,6 +383,7 @@ void populateDiscHLOToLHLOConversionPattern( RewritePatternSet* patterns) { // clang-format off patterns->insert< + CustomCallOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter,