Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DiscCollectiveOpsPass and related collective ops #1288

Merged
merged 5 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tao_compiler/mlir/disc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,27 @@ cc_library(
alwayslink = 1,
)

cc_library(
    name = "disc_collective_ops_rewriter",
    srcs = ["transforms/disc_collective_ops_rewriter.cc"],
    hdrs = ["transforms/passes.h"],
    includes = [
        ".",
        "tensorflow/compiler/xla/mlir_hlo/include",
    ],
    deps = [
        ":disc_util",
        ":mhlo_disc",
        ":pass_details",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Transforms",
        "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo",
    ],
    alwayslink = 1,
)

cc_library(
name = "mhlo_decomp_rewriters",
srcs = ["transforms/mhlo_decomp_rewriters.cc"],
Expand Down Expand Up @@ -2407,6 +2428,7 @@ cc_library(
":input_inline_fusion",
":lhlo_fusion_inliner",
":mhlo_decomp_rewriters",
":disc_collective_ops_rewriter",
":mhlo_mark_shape_calc",
":mhlo_placer",
":ral_inject_execution_context",
Expand Down
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/disc_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addPass(disc_ral::createDiscInputOutputAliasPass());
pm.addPass(mlir::createInlinerPass());
// TODO(disc): Lower HLO shape constraints instead of eliding them here.
pm.addNestedPass<FuncOp>(disc_ral::createDiscCollectiveOpsRewriterPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscMhloDecompositionRewriterPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscRemoveShapeConstraintsPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
Expand Down
259 changes: 259 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
// Copyright 2021 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/disc/IR/hlo_disc_ops.h"
#include "mlir/disc/disc_util.h"
#include "mlir/disc/transforms/PassDetail.h"

namespace mlir {
namespace disc_ral {

namespace {
// Reduction flavors for collective ops (all_reduce-style reductions).
// NOTE(review): nothing in this file constructs a ReductionKind —
// MatchReductionComputation returns the string names directly, so this
// enum (and ReductionKindToString) looks unused; confirm before removing.
enum ReductionKind {
  ALL_REDUCE_SUM,
  ALL_REDUCE_PRODUCT,
  ALL_REDUCE_MIN,
  ALL_REDUCE_MAX,
};

// Maps a ReductionKind to the string name used elsewhere in this file
// ("sum" / "product" / "min" / "max"); std::nullopt for any value outside
// the enum.
std::optional<std::string> ReductionKindToString(ReductionKind kind) {
  const char* name = nullptr;
  switch (kind) {
    case ReductionKind::ALL_REDUCE_SUM:
      name = "sum";
      break;
    case ReductionKind::ALL_REDUCE_PRODUCT:
      name = "product";
      break;
    case ReductionKind::ALL_REDUCE_MIN:
      name = "min";
      break;
    case ReductionKind::ALL_REDUCE_MAX:
      name = "max";
      break;
  }
  if (name == nullptr) {
    return std::nullopt;
  }
  return std::string(name);
}

// Recognizes the reduction computed by `region` (the body of an mhlo
// collective op such as all_reduce / reduce_scatter).
// Returns "sum" / "product" / "min" / "max" when the region is a single
// block returning the result of one supported binary op; std::nullopt
// otherwise.
std::optional<std::string> MatchReductionComputation(Region& region) {
  if (!region.hasOneBlock()) {
    return std::nullopt;
  }

  auto ret = dyn_cast<mhlo::ReturnOp>(region.front().getTerminator());
  if (!ret || ret->getNumOperands() != 1) {
    return std::nullopt;
  }

  // The returned value may be a block argument (an identity-style body),
  // in which case getDefiningOp() is null — bail out instead of passing a
  // null pointer to isa<> below, which would crash.
  Operation* computation = ret.getOperand(0).getDefiningOp();
  if (!computation) {
    return std::nullopt;
  }

  if (isa<mhlo::AddOp>(computation)) {
    return "sum";
  }
  if (isa<mhlo::MulOp>(computation)) {
    return "product";
  }
  if (isa<mhlo::MinOp>(computation)) {
    return "min";
  }
  if (isa<mhlo::MaxOp>(computation)) {
    return "max";
  }
  return std::nullopt;
}

// Rewrites mhlo.all_reduce into one mhlo_disc.custom_call_v2 per operand,
// targeting the "ral_all_reduce" DISC runtime call. Fails to match when the
// reduction body is not a recognized sum/product/min/max computation.
struct AllReduceOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
  using OpRewritePattern<mhlo::AllReduceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::AllReduceOp op,
                                PatternRewriter& rewriter) const override {
    auto reductionKind = MatchReductionComputation(op.getRegion());
    if (!reductionKind) {
      return failure();
    }

    // Build the attribute dictionary once — it is identical for every
    // operand (the original code rebuilt it per iteration). Start from the
    // op's existing attributes so information such as replica_groups is
    // carried over to the custom call, then set the custom-call specific
    // entries. Building a fresh NamedAttrList instead of mutating `op` with
    // setAttr keeps the matched op untouched, as the pattern-rewriter
    // contract requires (in-place mutation must go through the rewriter).
    NamedAttrList attrs(op->getAttrs());
    attrs.set("call_target_name", rewriter.getStringAttr("ral_all_reduce"));
    attrs.set("device", rewriter.getStringAttr("d"));
    attrs.set("input_placements", rewriter.getStringAttr("d"));
    attrs.set("output_placements", rewriter.getStringAttr("d"));
    attrs.set("input_layouts", rewriter.getStringAttr("*"));
    attrs.set("output_layouts", rewriter.getStringAttr("*"));
    attrs.set("expected_input_layouts", rewriter.getStringAttr("*"));
    attrs.set("expected_output_layouts", rewriter.getStringAttr("*"));

    SmallVector<NamedAttribute> customAttrs;
    customAttrs.push_back(
        NamedAttribute(rewriter.getStringAttr("reduction_kind"),
                       rewriter.getStringAttr(reductionKind.value())));
    attrs.set("custom_attrs",
              DictionaryAttr::get(op->getContext(), customAttrs));

    SmallVector<Value, 4> newOutputs;
    for (unsigned i = 0; i < op->getNumResults(); ++i) {
      // No need to call all_reduce for a result with no consumers. A null
      // placeholder keeps newOutputs aligned with the op's results:
      // replaceOp requires one replacement value per result, so skipping
      // an entry (as the original code did) would trip its count assertion
      // whenever any result is unused. The null value is never dereferenced
      // because this result has no uses to rewrite.
      if (op->getResult(i).getUsers().empty()) {
        newOutputs.push_back(Value());
        continue;
      }
      auto newOutput = rewriter.create<mhlo_disc::CustomCallV2Op>(
          op->getLoc(), op->getResult(i).getType(), op->getOperand(i),
          attrs.getAttrs());
      newOutputs.push_back(newOutput.getResult(0));
    }
    rewriter.replaceOp(op, newOutputs);
    return success();
  }
};

// Rewrites mhlo.all_gather into an mhlo_disc.custom_call_v2 targeting the
// "ral_all_gather" DISC runtime call; the gather configuration travels in
// the nested custom_attrs dictionary.
struct AllGatherOpConverter : public OpRewritePattern<mhlo::AllGatherOp> {
  using OpRewritePattern<mhlo::AllGatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::AllGatherOp op,
                                PatternRewriter& rewriter) const override {
    // Small helper: a NamedAttribute whose name and value are both strings.
    auto strAttr = [&](StringRef name, StringRef value) {
      return NamedAttribute(rewriter.getStringAttr(name),
                            rewriter.getStringAttr(value));
    };

    SmallVector<NamedAttribute, 4> callAttrs{
        strAttr("call_target_name", "ral_all_gather"),
        strAttr("device", "d"),
        strAttr("input_placements", "d"),
        strAttr("output_placements", "d"),
        strAttr("input_layouts", "*"),
        strAttr("output_layouts", "*"),
        strAttr("expected_input_layouts", "*"),
        strAttr("expected_output_layouts", "*"),
    };

    SmallVector<NamedAttribute, 4> customAttrs{
        NamedAttribute(rewriter.getStringAttr("all_gather_dim"),
                       op->getAttr("all_gather_dim")),
        NamedAttribute(rewriter.getStringAttr("replica_groups"),
                       op->getAttr("replica_groups")),
    };
    callAttrs.push_back(NamedAttribute(rewriter.getStringAttr("custom_attrs"),
                                       rewriter.getDictionaryAttr(customAttrs)));

    auto newOp = rewriter.create<mhlo_disc::CustomCallV2Op>(
        op->getLoc(), op->getResultTypes(), op->getOperands(),
        ArrayRef<NamedAttribute>(callAttrs));
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};

// Rewrites mhlo.reduce_scatter into an mhlo_disc.custom_call_v2 targeting
// the "ral_reduce_scatter" DISC runtime call. Only reductions that
// MatchReductionComputation can name (sum/product/min/max) are supported.
struct ReduceScatterOpConverter
    : public OpRewritePattern<mhlo::ReduceScatterOp> {
  using OpRewritePattern<mhlo::ReduceScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ReduceScatterOp op,
                                PatternRewriter& rewriter) const override {
    auto kind = MatchReductionComputation(op.getRegion());
    if (!kind) {
      return failure();
    }

    SmallVector<NamedAttribute, 4> callAttrs;
    // Appends a string-valued attribute to the custom call's attr list.
    auto addStr = [&](StringRef name, StringRef value) {
      callAttrs.emplace_back(rewriter.getStringAttr(name),
                             rewriter.getStringAttr(value));
    };
    addStr("call_target_name", "ral_reduce_scatter");
    addStr("device", "d");
    addStr("input_placements", "d");
    addStr("output_placements", "d");
    addStr("input_layouts", "*");
    addStr("output_layouts", "*");
    addStr("expected_input_layouts", "*");
    addStr("expected_output_layouts", "*");

    // The scatter configuration rides along in the custom_attrs dictionary.
    SmallVector<NamedAttribute, 4> customAttrs;
    customAttrs.emplace_back(rewriter.getStringAttr("reduction_kind"),
                             rewriter.getStringAttr(kind.value()));
    customAttrs.emplace_back(rewriter.getStringAttr("scatter_dimension"),
                             op->getAttr("scatter_dimension"));
    customAttrs.emplace_back(rewriter.getStringAttr("replica_groups"),
                             op->getAttr("replica_groups"));
    callAttrs.emplace_back(rewriter.getStringAttr("custom_attrs"),
                           rewriter.getDictionaryAttr(customAttrs));

    auto replacement = rewriter.create<mhlo_disc::CustomCallV2Op>(
        op->getLoc(), op->getResultTypes(), op->getOperands(),
        ArrayRef<NamedAttribute>(callAttrs));
    rewriter.replaceOp(op, replacement.getResults());
    return success();
  }
};
} // namespace

// Function pass that rewrites mhlo collective ops (all_reduce, all_gather,
// reduce_scatter) into mhlo_disc.custom_call_v2 runtime library calls.
struct DiscCollectiveOpsRewriterPass
    : public DiscCollectiveOpsRewriterPassBase<DiscCollectiveOpsRewriterPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    // The rewrite patterns create mhlo_disc.custom_call_v2 ops.
    registry.insert<mlir::mhlo_disc::MhloDiscDialect>();
  }
  void runOnOperation() override {
    func::FuncOp func = getOperation();

    RewritePatternSet patterns(&getContext());
    patterns.insert<AllReduceOpConverter, AllGatherOpConverter,
                    ReduceScatterOpConverter>(&getContext());
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
      func.emitError("applyPatternsAndFoldGreedily does not converge");
      signalPassFailure();
    }
  }
};

// Factory for the collective-ops rewriter pass; registered via the
// constructor declared in disc_passes.td.
std::unique_ptr<OperationPass<func::FuncOp>>
createDiscCollectiveOpsRewriterPass() {
  return std::make_unique<DiscCollectiveOpsRewriterPass>();
}

} // namespace disc_ral
} // namespace mlir
5 changes: 5 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def ConvRewriterPass : Pass<"disc-conv-rewriter", "mlir::func::FuncOp"> {
];
}

def DiscCollectiveOpsRewriterPass : Pass<"disc-collective-ops-rewriter", "mlir::func::FuncOp"> {
  // Trailing period matches the style of the sibling pass summaries in
  // this file (e.g. MhloDecompositionRewriterPass).
  let summary = "Rewrite mhlo collective ops to DISC custom library calls.";
  let description = [{
    Converts mhlo.all_reduce, mhlo.all_gather and mhlo.reduce_scatter into
    mhlo_disc.custom_call_v2 ops targeting the DISC RAL collective runtime
    calls (ral_all_reduce / ral_all_gather / ral_reduce_scatter).
  }];
  let constructor = "createDiscCollectiveOpsRewriterPass()";
}

def MhloDecompositionRewriterPass : Pass<"disc-mhlo-decomp-rewriter", "mlir::func::FuncOp"> {
let summary = "Rewrite and decompose mhlo ops.";
let constructor = "createDiscMhloDecompositionRewriterPass()";
Expand Down
Loading
Loading