Skip to content

Commit

Permalink
support mhlo.custom_call
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Feb 29, 2024
1 parent 04b394b commit fcbde5c
Showing 1 changed file with 97 additions and 9 deletions.
106 changes: 97 additions & 9 deletions tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.

#include "lhlo/IR/lhlo_ops.h"
#include "llvm/Support/Debug.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
Expand Down Expand Up @@ -177,13 +178,14 @@ struct HloToLhloArgsMutationOpConverter
}
};

struct HloToLhloCustomCallOpConverter : public BaseOpConversion<CustomCallOp> {
struct HloToLhloCustomCallOpConverter : public BaseOpConversion<mhlo_disc::CustomCallOp> {
public:
using BaseOpConversion<CustomCallOp>::BaseOpConversion;
using BaseOpConversion<mhlo_disc::CustomCallOp>::BaseOpConversion;

LogicalResult matchAndRewrite(
CustomCallOp hloOp, OpAdaptor adaptor,
mhlo_disc::CustomCallOp hloOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {

Operation* op = hloOp.getOperation();
auto operands = adaptor.getOperands();
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
Expand All @@ -204,12 +206,12 @@ struct HloToLhloCustomCallOpConverter : public BaseOpConversion<CustomCallOp> {
};

struct HloToLhloCustomCallOpV2Converter
: public BaseOpConversion<CustomCallV2Op> {
: public BaseOpConversion<mhlo_disc::CustomCallV2Op> {
public:
using BaseOpConversion<CustomCallV2Op>::BaseOpConversion;
using BaseOpConversion<mhlo_disc::CustomCallV2Op>::BaseOpConversion;

LogicalResult matchAndRewrite(
CustomCallV2Op hloOp, OpAdaptor adaptor,
mhlo_disc::CustomCallV2Op hloOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
Location loc = hloOp->getLoc();
SmallVector<Type> resultTypes;
Expand All @@ -218,12 +220,96 @@ struct HloToLhloCustomCallOpV2Converter
resultTypes.push_back(
MemRefType::get(ty.getShape(), ty.getElementType()));
}

rewriter.replaceOpWithNewOp<lmhlo_disc::CustomCallV2Op>(
hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs());
hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs());

return success();
}
};

struct MhloToLhloCustomCallOpV2Converter
: public BaseOpConversion<mhlo::CustomCallOp> {
public:
using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;

LogicalResult matchAndRewrite(
mhlo::CustomCallOp hloOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {

Operation* op = hloOp.getOperation();
auto operands = adaptor.getOperands();

SmallVector<Type> resultTypes;

std::string input_placements, output_placements;
for(int i=0; i<operands.size(); i++){
input_placements += "d,";
}

hloOp->setAttr("call_target_name", rewriter.getStringAttr(hloOp.getCallTargetName()));
hloOp->setAttr("device", rewriter.getStringAttr("x"));
hloOp->setAttr("input_layouts", rewriter.getStringAttr("*"));
hloOp->setAttr("output_layouts", rewriter.getStringAttr("*"));
hloOp->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
hloOp->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));

SmallVector<NamedAttribute> newAttrs;
if(hloOp.getBackendConfig().has_value()) {
newAttrs.push_back(
NamedAttribute(rewriter.getStringAttr("backend_config"),
hloOp.getBackendConfig().value()));
}
auto newCustomAttrs = DictionaryAttr::get(hloOp->getContext(), newAttrs);
hloOp->setAttr("custom_attrs", newCustomAttrs);

if(hloOp->getNumResults() ==1 && hloOp->getResult(0).getType().dyn_cast<mlir::TupleType>()) {
auto tupleTy = hloOp->getResult(0).getType().dyn_cast<mlir::TupleType>();
for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) {
output_placements += "d,";
resultTypes.push_back(ty.cast<RankedTensorType>());
}
} else{
output_placements="d,";
for (Value v : hloOp->getResults()) {
auto ty = v.getType().cast<RankedTensorType>();
resultTypes.push_back(ty);
}
}

if(!input_placements.empty()) {
input_placements.pop_back();
}
if(!output_placements.empty()) {
output_placements.pop_back();
}

hloOp->setAttr("input_placements", rewriter.getStringAttr(input_placements));
hloOp->setAttr("output_placements", rewriter.getStringAttr(output_placements));

auto custom_v2_op = rewriter.create<mhlo_disc::CustomCallV2Op>(hloOp.getLoc(), resultTypes, operands, hloOp->getAttrs());

// TBD: Is this necessary?
if(hloOp->getNumResults() ==1 && hloOp->getResult(0).getType().dyn_cast<mlir::TupleType>()) {
auto tupleValue = hloOp->getResult(0);
for (int index=0; index<resultTypes.size(); index++) {
for (auto &use : tupleValue.getUses()) {
Operation *consumerOp = use.getOwner();
if(auto getTupleElementOp = llvm::dyn_cast<mhlo::GetTupleElementOp>(consumerOp)){
if(getTupleElementOp.getIndex() == index) {
rewriter.replaceOp(consumerOp, {custom_v2_op.getResult(index)});
break;
}
}
}
}
}

rewriter.replaceOp(hloOp, custom_v2_op.getResults());
return success();
}
};

struct TieShapeOpConverter : public BaseOpConversion<TieShapeOp> {
public:
using BaseOpConversion<TieShapeOp>::BaseOpConversion;
Expand Down Expand Up @@ -257,7 +343,7 @@ struct DiscHloLegalizeToLhlo
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect,
shape::ShapeDialect, bufferization::BufferizationDialect,
lmhlo::LmhloDialect>();
lmhlo::LmhloDialect, mhlo::MhloDialect, mhlo_disc::MhloDiscDialect>();
}

public:
Expand All @@ -270,10 +356,11 @@ struct DiscHloLegalizeToLhlo
target.addLegalDialect<arith::ArithDialect, lmhlo_disc::LmhloDiscDialect,
bufferization::BufferizationDialect,
memref::MemRefDialect, shape::ShapeDialect,
tensor::TensorDialect, lmhlo::LmhloDialect>();
tensor::TensorDialect, lmhlo::LmhloDialect, mhlo::MhloDialect>();
target.addIllegalDialect<mhlo_disc::MhloDiscDialect>();
target.addIllegalOp<disc_shape::TieShapeOp>();
target.addIllegalOp<mhlo_disc::ArgsMutationOp>();
target.addIllegalOp<mhlo::CustomCallOp>();

bufferization::BufferizeTypeConverter converter;
populateDiscHLOToLHLOConversionPattern(&context, &converter, &patterns);
Expand All @@ -299,6 +386,7 @@ void populateDiscHLOToLHLOConversionPattern(
HloToLhloOpConverter<mhlo_disc::SparseSegmentReductionOp>,
HloToLhloOpConverter<mhlo_disc::SparseSegmentReductionWithEmptyRowsOp>,
HloToLhloOpConverter<mhlo_disc::WhereOp>,
MhloToLhloCustomCallOpV2Converter,
HloToLhloArgsMutationOpConverter,
HloToLhloCustomCallOpConverter,
HloToLhloCustomCallOpV2Converter,
Expand Down

0 comments on commit fcbde5c

Please sign in to comment.