Skip to content

Commit

Permalink
support mhlo.custom_call
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Feb 29, 2024
1 parent 04b394b commit 70ae32f
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 8 deletions.
Empty file modified .pre-commit-config.yaml
100644 → 100755
Empty file.
110 changes: 102 additions & 8 deletions tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.

#include "lhlo/IR/lhlo_ops.h"
#include "llvm/Support/Debug.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
Expand Down Expand Up @@ -177,12 +178,13 @@ struct HloToLhloArgsMutationOpConverter
}
};

struct HloToLhloCustomCallOpConverter : public BaseOpConversion<CustomCallOp> {
struct HloToLhloCustomCallOpConverter
: public BaseOpConversion<mhlo_disc::CustomCallOp> {
public:
using BaseOpConversion<CustomCallOp>::BaseOpConversion;
using BaseOpConversion<mhlo_disc::CustomCallOp>::BaseOpConversion;

LogicalResult matchAndRewrite(
CustomCallOp hloOp, OpAdaptor adaptor,
mhlo_disc::CustomCallOp hloOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
Operation* op = hloOp.getOperation();
auto operands = adaptor.getOperands();
Expand All @@ -204,12 +206,12 @@ struct HloToLhloCustomCallOpConverter : public BaseOpConversion<CustomCallOp> {
};

struct HloToLhloCustomCallOpV2Converter
: public BaseOpConversion<CustomCallV2Op> {
: public BaseOpConversion<mhlo_disc::CustomCallV2Op> {
public:
using BaseOpConversion<CustomCallV2Op>::BaseOpConversion;
using BaseOpConversion<mhlo_disc::CustomCallV2Op>::BaseOpConversion;

LogicalResult matchAndRewrite(
CustomCallV2Op hloOp, OpAdaptor adaptor,
mhlo_disc::CustomCallV2Op hloOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
Location loc = hloOp->getLoc();
SmallVector<Type> resultTypes;
Expand All @@ -218,8 +220,96 @@ struct HloToLhloCustomCallOpV2Converter
resultTypes.push_back(
MemRefType::get(ty.getShape(), ty.getElementType()));
}

rewriter.replaceOpWithNewOp<lmhlo_disc::CustomCallV2Op>(
hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs());

return success();
}
};

struct CustomCallOpConverter : public BaseOpConversion<mhlo::CustomCallOp> {
public:
using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;

LogicalResult matchAndRewrite(
mhlo::CustomCallOp hloOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
Operation* op = hloOp.getOperation();
auto operands = adaptor.getOperands();

SmallVector<Type> resultTypes;

std::string input_placements, output_placements;
for (int i = 0; i < operands.size(); i++) {
input_placements += "d,";
}

hloOp->setAttr("call_target_name",
rewriter.getStringAttr(hloOp.getCallTargetName()));
hloOp->setAttr("device", rewriter.getStringAttr("x"));
hloOp->setAttr("input_layouts", rewriter.getStringAttr("*"));
hloOp->setAttr("output_layouts", rewriter.getStringAttr("*"));
hloOp->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
hloOp->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));

SmallVector<NamedAttribute> newAttrs;
if (hloOp.getBackendConfig().has_value()) {
newAttrs.push_back(
NamedAttribute(rewriter.getStringAttr("backend_config"),
hloOp.getBackendConfig().value()));
}
auto newCustomAttrs = DictionaryAttr::get(hloOp->getContext(), newAttrs);
hloOp->setAttr("custom_attrs", newCustomAttrs);

if (hloOp->getNumResults() == 1 &&
hloOp->getResult(0).getType().dyn_cast<mlir::TupleType>()) {
auto tupleTy = hloOp->getResult(0).getType().dyn_cast<mlir::TupleType>();
for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) {
output_placements += "d,";
resultTypes.push_back(ty.cast<RankedTensorType>());
}
} else {
output_placements = "d,";
for (Value v : hloOp->getResults()) {
auto ty = v.getType().cast<RankedTensorType>();
resultTypes.push_back(ty);
}
}

if (!input_placements.empty()) {
input_placements.pop_back();
}
if (!output_placements.empty()) {
output_placements.pop_back();
}

hloOp->setAttr("input_placements",
rewriter.getStringAttr(input_placements));
hloOp->setAttr("output_placements",
rewriter.getStringAttr(output_placements));

auto custom_v2_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
hloOp.getLoc(), resultTypes, operands, hloOp->getAttrs());

if (hloOp->getNumResults() == 1 &&
hloOp->getResult(0).getType().dyn_cast<mlir::TupleType>()) {
auto tupleValue = hloOp->getResult(0);
for (int index = 0; index < resultTypes.size(); index++) {
for (auto& use : tupleValue.getUses()) {
Operation* consumerOp = use.getOwner();
if (auto getTupleElementOp =
llvm::dyn_cast<mhlo::GetTupleElementOp>(consumerOp)) {
if (getTupleElementOp.getIndex() == index) {
rewriter.replaceOp(consumerOp, {custom_v2_op.getResult(index)});
break;
}
}
}
}
}

rewriter.replaceOp(hloOp, custom_v2_op.getResults());
return success();
}
};
Expand Down Expand Up @@ -257,7 +347,8 @@ struct DiscHloLegalizeToLhlo
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect,
shape::ShapeDialect, bufferization::BufferizationDialect,
lmhlo::LmhloDialect>();
lmhlo::LmhloDialect, mhlo::MhloDialect,
mhlo_disc::MhloDiscDialect>();
}

public:
Expand All @@ -270,10 +361,12 @@ struct DiscHloLegalizeToLhlo
target.addLegalDialect<arith::ArithDialect, lmhlo_disc::LmhloDiscDialect,
bufferization::BufferizationDialect,
memref::MemRefDialect, shape::ShapeDialect,
tensor::TensorDialect, lmhlo::LmhloDialect>();
tensor::TensorDialect, lmhlo::LmhloDialect,
mhlo::MhloDialect>();
target.addIllegalDialect<mhlo_disc::MhloDiscDialect>();
target.addIllegalOp<disc_shape::TieShapeOp>();
target.addIllegalOp<mhlo_disc::ArgsMutationOp>();
target.addIllegalOp<mhlo::CustomCallOp>();

bufferization::BufferizeTypeConverter converter;
populateDiscHLOToLHLOConversionPattern(&context, &converter, &patterns);
Expand All @@ -290,6 +383,7 @@ void populateDiscHLOToLHLOConversionPattern(
RewritePatternSet* patterns) {
// clang-format off
patterns->insert<
CustomCallOpConverter,
HloToLhloOpConverter<mhlo_disc::H2DOp>,
HloToLhloOpConverter<mhlo_disc::D2HOp>,
HloToLhloOpConverter<mhlo_disc::QuantizedDotGeneralOp>,
Expand Down

0 comments on commit 70ae32f

Please sign in to comment.