Merge remote-tracking branch 'origin/main' into support_transpose_custom_call
eedalong committed Feb 29, 2024
2 parents 99cc4f4 + 5d0bafd commit 04b394b
Showing 12 changed files with 449 additions and 162 deletions.
1 change: 1 addition & 0 deletions pytorch_blade/pytorch_blade/compiler/mlir/runtime/BUILD
@@ -17,6 +17,7 @@ cc_library(
"@local_org_torch//:libtorch",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"@local_config_nccl//:nccl",
]),
copts = select({
"//:enable_rocm": ["-DTORCH_BLADE_USE_ROCM -DTORCH_BLADE_BUILD_WITH_CUDA "],
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/BUILD
@@ -503,6 +503,7 @@ cc_library(
deps = [
":disc_util",
":pass_details",
":mhlo_disc",
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/tests/BUILD
@@ -51,6 +51,7 @@ disc_cc_library(
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cuda_driver",
"@local_config_cuda//cuda:cudart",
"@local_config_nccl//:nccl",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
"//tensorflow/compiler/xla/stream_executor/rocm:rocm_driver",
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/tools/disc-replay/BUILD
@@ -34,6 +34,7 @@ disc_cc_library(
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cuda_driver",
"@local_config_cuda//cuda:cudart",
"@local_config_nccl//:nccl",
])
)

101 changes: 100 additions & 1 deletion tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -26,6 +26,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/disc/IR/hlo_disc_ops.h"
#include "mlir/disc/disc_util.h"
#include "mlir/disc/transforms/PassDetail.h"

@@ -132,16 +133,114 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
return success();
}
} // namespace
namespace {
enum ReductionKind {
ALL_REDUCE_SUM,
ALL_REDUCE_PRODUCT,
ALL_REDUCE_MIN,
ALL_REDUCE_MAX,
};

std::optional<std::string> ReductionKindToString(ReductionKind kind) {
switch (kind) {
case ReductionKind::ALL_REDUCE_SUM:
return "sum";
case ReductionKind::ALL_REDUCE_PRODUCT:
return "product";
case ReductionKind::ALL_REDUCE_MIN:
return "min";
case ReductionKind::ALL_REDUCE_MAX:
return "max";
}
return std::nullopt;
}

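// Maps the single-op reduction body of an all_reduce region to its string
// name ("sum", "product", "min", "max"); returns std::nullopt if the region
// does not match a known reduction.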
std::optional<std::string> MatchReductionComputation(Region& region) {
if (!region.hasOneBlock()) {
return std::nullopt;
}

auto ret = dyn_cast<mhlo::ReturnOp>(region.front().getTerminator());
if (!ret || ret->getNumOperands() != 1) {
return std::nullopt;
}

auto computation = ret.getOperand(0).getDefiningOp();

if (isa<mhlo::AddOp>(computation)) {
return "sum";
}
if (isa<mhlo::MulOp>(computation)) {
return "product";
}
if (isa<mhlo::MinOp>(computation)) {
return "min";
}
if (isa<mhlo::MaxOp>(computation)) {
return "max";
}
return std::nullopt;
}

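// Rewrites mhlo.all_reduce into mhlo_disc.custom_call_v2 ops targeting the
// "ral_all_reduce" runtime call, carrying the matched reduction kind in the
// custom_attrs dictionary.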
struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
using OpRewritePattern<mhlo::AllReduceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::AllReduceOp op,
PatternRewriter& rewriter) const override {
SmallVector<Value, 4> newOutputs;
auto reductionKind = MatchReductionComputation(op.getRegion());
if (!reductionKind) {
return failure();
}
for (int i = 0; i < op->getOperands().size(); ++i) {
// No need to emit an all_reduce call if this result has no consumers.
if (op->getResult(i).getUsers().empty()) {
continue;
}

op->setAttr("call_target_name", rewriter.getStringAttr("ral_all_reduce"));
op->setAttr("device", rewriter.getStringAttr("d"));
op->setAttr("input_placements", rewriter.getStringAttr("d"));
op->setAttr("output_placements", rewriter.getStringAttr("d"));
op->setAttr("input_layouts", rewriter.getStringAttr("*"));
op->setAttr("output_layouts", rewriter.getStringAttr("*"));
op->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
op->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));
SmallVector<NamedAttribute> newAttrs;
newAttrs.push_back(
NamedAttribute(rewriter.getStringAttr("reduction_kind"),
rewriter.getStringAttr(reductionKind.value())));

auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs);

op->setAttr("custom_attrs", newCustomAttrs);

auto newOutput = rewriter.create<mhlo_disc::CustomCallV2Op>(
op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i],
op->getAttrs());
newOutputs.push_back(newOutput.getResult(0));
}
rewriter.replaceOp(op, newOutputs);
return success();
}
};
} // namespace

struct MhloDecompositionRewriterPass
: public MhloDecompositionRewriterPassBase<MhloDecompositionRewriterPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<mlir::mhlo_disc::MhloDiscDialect>();
}
void runOnOperation() override {
func::FuncOp func = getOperation();
// MLIRContext* ctx = func.getContext();
MLIRContext* ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.insert<BatchNormInferenceOpConvert>(ctx);
patterns.insert<PadOpConvert>(ctx);
patterns.insert<SliceOpConvert>(ctx);
patterns.insert<CollectiveOpConverter>(ctx);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
func.emitError("applyPatternsAndFoldGreedily does not converge");
signalPassFailure();
11 changes: 11 additions & 0 deletions tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
@@ -35,3 +35,14 @@ func.func @batch_norm_inference(%arg0: tensor<?x128x?x?xf32>, %arg1: tensor<128x
%0 = "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x128x?x?xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<?x128x?x?xf32>
return %0: tensor<?x128x?x?xf32>
}

func.func @main(%arg0: tensor<f32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<f32>) {
// CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg1) {call_target_name = "ral_all_reduce", custom_attrs = {reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: %1 = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_all_reduce", custom_attrs = {reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
%0:2 = "mhlo.all_reduce"(%arg1, %arg0) ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
mhlo.return %1 : tensor<f32>
}) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>, tensor<f32>) -> (tensor<4xf32>, tensor<f32>)
return %0#0, %0#1 : tensor<4xf32>, tensor<f32>
}
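
For reference, lit tests like this one are driven by a RUN line at the top of the file (above the hunk shown here). A minimal sketch of such an invocation, assuming the pass is exposed through disc-opt under a flag along the lines of -disc-mhlo-decomp-rewriter (the flag name is an assumption; the existing RUN line in mhlo_decomp_rewriter.mlir is authoritative):

// RUN: disc-opt -disc-mhlo-decomp-rewriter %s | FileCheck %s  (hypothetical flag name)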