From fcbde5c43119b0791020bf5328aa46679fda6f08 Mon Sep 17 00:00:00 2001 From: eedalong Date: Thu, 29 Feb 2024 13:01:51 +0800 Subject: [PATCH] support mhlo.custom_call --- .../transforms/disc_hlo_legalize_to_lhlo.cc | 106 ++++++++++++++++-- 1 file changed, 97 insertions(+), 9 deletions(-) mode change 100644 => 100755 tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc diff --git a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc old mode 100644 new mode 100755 index 55138af799d..5b739f81402 --- a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc +++ b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc @@ -21,6 +21,7 @@ limitations under the License. #include "lhlo/IR/lhlo_ops.h" #include "llvm/Support/Debug.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" @@ -177,13 +178,14 @@ struct HloToLhloArgsMutationOpConverter } }; -struct HloToLhloCustomCallOpConverter : public BaseOpConversion { +struct HloToLhloCustomCallOpConverter : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - CustomCallOp hloOp, OpAdaptor adaptor, + mhlo_disc::CustomCallOp hloOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + Operation* op = hloOp.getOperation(); auto operands = adaptor.getOperands(); SmallVector buffer_args(operands.begin(), operands.end()); @@ -204,12 +206,12 @@ struct HloToLhloCustomCallOpConverter : public BaseOpConversion { }; struct HloToLhloCustomCallOpV2Converter - : public BaseOpConversion { + : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - CustomCallV2Op hloOp, OpAdaptor adaptor, + mhlo_disc::CustomCallV2Op hloOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { Location loc = hloOp->getLoc(); SmallVector resultTypes; @@ -218,12 +220,96 @@ struct HloToLhloCustomCallOpV2Converter resultTypes.push_back( MemRefType::get(ty.getShape(), ty.getElementType())); } + rewriter.replaceOpWithNewOp( - hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs()); + hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs()); + return success(); } }; +struct MhloToLhloCustomCallOpV2Converter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::CustomCallOp hloOp, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + + Operation* op = hloOp.getOperation(); + auto operands = adaptor.getOperands(); + + SmallVector resultTypes; + + std::string input_placements, output_placements; + for(int i=0; isetAttr("call_target_name", rewriter.getStringAttr(hloOp.getCallTargetName())); + hloOp->setAttr("device", rewriter.getStringAttr("x")); + hloOp->setAttr("input_layouts", rewriter.getStringAttr("*")); + hloOp->setAttr("output_layouts", rewriter.getStringAttr("*")); + hloOp->setAttr("expected_input_layouts", rewriter.getStringAttr("*")); + hloOp->setAttr("expected_output_layouts", rewriter.getStringAttr("*")); + + SmallVector newAttrs; + if(hloOp.getBackendConfig().has_value()) { + newAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("backend_config"), + hloOp.getBackendConfig().value())); + } + auto newCustomAttrs = DictionaryAttr::get(hloOp->getContext(), newAttrs); + hloOp->setAttr("custom_attrs", newCustomAttrs); + + if(hloOp->getNumResults() ==1 && hloOp->getResult(0).getType().dyn_cast()) { + auto tupleTy = hloOp->getResult(0).getType().dyn_cast(); + for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) { + output_placements += "d,"; + resultTypes.push_back(ty.cast()); + } + } else{ + output_placements="d,"; + for (Value v : hloOp->getResults()) { + auto ty = v.getType().cast(); + resultTypes.push_back(ty); + } + } + + if(!input_placements.empty()) { + input_placements.pop_back(); + } + if(!output_placements.empty()) { + output_placements.pop_back(); + } + + hloOp->setAttr("input_placements", rewriter.getStringAttr(input_placements)); + hloOp->setAttr("output_placements", rewriter.getStringAttr(output_placements)); + + auto custom_v2_op = rewriter.create(hloOp.getLoc(), resultTypes, operands, hloOp->getAttrs()); + + // TBD: Is this necessary? + if(hloOp->getNumResults() ==1 && hloOp->getResult(0).getType().dyn_cast()) { + auto tupleValue = hloOp->getResult(0); + for (int index=0; index(consumerOp)){ + if(getTupleElementOp.getIndex() == index) { + rewriter.replaceOp(consumerOp, {custom_v2_op.getResult(index)}); + break; + } + } + } + } + } + + rewriter.replaceOp(hloOp, custom_v2_op.getResults()); + return success(); + } +}; + struct TieShapeOpConverter : public BaseOpConversion { public: using BaseOpConversion::BaseOpConversion; @@ -257,7 +343,7 @@ struct DiscHloLegalizeToLhlo void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); + lmhlo::LmhloDialect, mhlo::MhloDialect, mhlo_disc::MhloDiscDialect>(); } public: @@ -270,10 +356,11 @@ struct DiscHloLegalizeToLhlo target.addLegalDialect(); + tensor::TensorDialect, lmhlo::LmhloDialect, mhlo::MhloDialect>(); target.addIllegalDialect(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); bufferization::BufferizeTypeConverter converter; populateDiscHLOToLHLOConversionPattern(&context, &converter, &patterns); @@ -299,6 +386,7 @@ void populateDiscHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + MhloToLhloCustomCallOpV2Converter, HloToLhloArgsMutationOpConverter, HloToLhloCustomCallOpConverter, HloToLhloCustomCallOpV2Converter,