From 1ade6fd7640873d6fdac73b09033d740d7136035 Mon Sep 17 00:00:00 2001 From: yancey Date: Tue, 19 Mar 2024 14:19:06 +0800 Subject: [PATCH 1/5] add disc-collective-ops-rewriter pass to support collective ops --- tao_compiler/mlir/disc/disc_compiler.cc | 1 + .../disc_collective_ops_rewriter.cc | 259 ++++++++++++++++++ .../mlir/disc/transforms/disc_passes.td | 5 + .../disc/transforms/mhlo_decomp_rewriters.cc | 93 ------- tao_compiler/mlir/disc/transforms/passes.h | 6 + .../tests/disc-collective-ops-rewriter.mlir | 11 + tao_compiler/mlir/ral/collective.cu.cc | 160 ++++++++++- 7 files changed, 435 insertions(+), 100 deletions(-) create mode 100644 tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc create mode 100644 tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index 9e681c8b738..cc2b535adfc 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -245,6 +245,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addPass(disc_ral::createDiscInputOutputAliasPass()); pm.addPass(mlir::createInlinerPass()); // TODO(disc): Lower HLO shape constraints instead of eliding them here. + pm.addNestedPass(disc_ral::createDiscCollectiveOpsRewriterPass()); pm.addNestedPass(disc_ral::createDiscMhloDecompositionRewriterPass()); pm.addNestedPass(disc_ral::createDiscRemoveShapeConstraintsPass()); pm.addNestedPass(createCanonicalizerPass()); diff --git a/tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc b/tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc new file mode 100644 index 00000000000..1577852c6e7 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc @@ -0,0 +1,259 @@ +// Copyright 2021 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/disc/IR/hlo_disc_ops.h" +#include "mlir/disc/disc_util.h" +#include "mlir/disc/transforms/PassDetail.h" + +namespace mlir { +namespace disc_ral { + +namespace { +enum ReductionKind { + ALL_REDUCE_SUM, + ALL_REDUCE_PRODUCT, + ALL_REDUCE_MIN, + ALL_REDUCE_MAX, +}; + +std::optional ReductionKindToString(ReductionKind kind) { + switch (kind) { + case ReductionKind::ALL_REDUCE_SUM: + return "sum"; + case ReductionKind::ALL_REDUCE_PRODUCT: + return "product"; + case ReductionKind::ALL_REDUCE_MIN: + return "min"; + case ReductionKind::ALL_REDUCE_MAX: + return "max"; + } + return std::nullopt; +} + +std::optional MatchReductionComputation(Region& region) { + if (!region.hasOneBlock()) { + return std::nullopt; + } + + auto ret = dyn_cast(region.front().getTerminator()); + if (!ret || ret->getNumOperands() != 1) { + return std::nullopt; + } + + auto computation = ret.getOperand(0).getDefiningOp(); + + if (isa(computation)) { + return "sum"; + } + if (isa(computation)) { + return "product"; + } + if (isa(computation)) { + return "min"; + } + if (isa(computation)) { + return "max"; + } + return std::nullopt; +} + +struct AllReduceOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::AllReduceOp op, + PatternRewriter& rewriter) const override { + SmallVector newOutputs; + auto reductionKind = MatchReductionComputation(op.getRegion()); + if (!reductionKind) { + return failure(); + } + for (int i = 0; i < op->getOperands().size(); ++i) { + // no need call all_reduce op if no consumer + if (op->getResult(i).getUsers().empty()) { + continue; + } + + op->setAttr("call_target_name", rewriter.getStringAttr("ral_all_reduce")); + op->setAttr("device", rewriter.getStringAttr("d")); + op->setAttr("input_placements", rewriter.getStringAttr("d")); + op->setAttr("output_placements", rewriter.getStringAttr("d")); + op->setAttr("input_layouts", rewriter.getStringAttr("*")); + op->setAttr("output_layouts", rewriter.getStringAttr("*")); + op->setAttr("expected_input_layouts", rewriter.getStringAttr("*")); + op->setAttr("expected_output_layouts", rewriter.getStringAttr("*")); + SmallVector newAttrs; + newAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("reduction_kind"), + rewriter.getStringAttr(reductionKind.value()))); + + auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs); + + op->setAttr("custom_attrs", newCustomAttrs); + + auto newOutput = rewriter.create( + op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i], + op->getAttrs()); + newOutputs.push_back(newOutput.getResult(0)); + } + rewriter.replaceOp(op, newOutputs); + return success(); + } +}; + +struct AllGatherOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::AllGatherOp op, + PatternRewriter& rewriter) const override { + SmallVector newAttrsVec; + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("call_target_name"), + rewriter.getStringAttr("ral_all_gather"))); + 
newAttrsVec.push_back(NamedAttribute(rewriter.getStringAttr("device"), + rewriter.getStringAttr("d"))); + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("input_placements"), + rewriter.getStringAttr("d"))); + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("output_placements"), + rewriter.getStringAttr("d"))); + newAttrsVec.push_back(NamedAttribute( + rewriter.getStringAttr("input_layouts"), rewriter.getStringAttr("*"))); + newAttrsVec.push_back(NamedAttribute( + rewriter.getStringAttr("output_layouts"), rewriter.getStringAttr("*"))); + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("expected_input_layouts"), + rewriter.getStringAttr("*"))); + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("expected_output_layouts"), + rewriter.getStringAttr("*"))); + + SmallVector customAttrs; + customAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("all_gather_dim"), + op->getAttr("all_gather_dim"))); + customAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("replica_groups"), + op->getAttr("replica_groups"))); + + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("custom_attrs"), + rewriter.getDictionaryAttr(customAttrs))); + + ArrayRef newCustomAttrs(newAttrsVec); + auto newOp = rewriter.create( + op->getLoc(), op->getResultTypes(), op->getOperands(), newCustomAttrs); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +struct ReduceScatterOpConverter + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::ReduceScatterOp op, + PatternRewriter& rewriter) const override { + auto reductionKind = MatchReductionComputation(op.getRegion()); + if (!reductionKind) { + return failure(); + } + SmallVector newAttrsVec; + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("call_target_name"), + rewriter.getStringAttr("ral_reduce_scatter"))); + newAttrsVec.push_back(NamedAttribute(rewriter.getStringAttr("device"), + rewriter.getStringAttr("d"))); + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("input_placements"), + rewriter.getStringAttr("d"))); + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("output_placements"), + rewriter.getStringAttr("d"))); + newAttrsVec.push_back(NamedAttribute( + rewriter.getStringAttr("input_layouts"), rewriter.getStringAttr("*"))); + newAttrsVec.push_back(NamedAttribute( + rewriter.getStringAttr("output_layouts"), rewriter.getStringAttr("*"))); + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("expected_input_layouts"), + rewriter.getStringAttr("*"))); + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("expected_output_layouts"), + rewriter.getStringAttr("*"))); + + SmallVector customAttrs; + customAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("reduction_kind"), + rewriter.getStringAttr(reductionKind.value()))); + customAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("scatter_dimension"), + op->getAttr("scatter_dimension"))); + customAttrs.push_back( + NamedAttribute(rewriter.getStringAttr("replica_groups"), + op->getAttr("replica_groups"))); + + newAttrsVec.push_back( + NamedAttribute(rewriter.getStringAttr("custom_attrs"), + rewriter.getDictionaryAttr(customAttrs))); + + ArrayRef newCustomAttrs(newAttrsVec); + auto newOp = rewriter.create( + op->getLoc(), op->getResultTypes(), op->getOperands(), newCustomAttrs); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; +} // namespace + +struct 
DiscCollectiveOpsRewriterPass + : public DiscCollectiveOpsRewriterPassBase { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnOperation() override { + func::FuncOp func = getOperation(); + MLIRContext* ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + patterns.insert(ctx); + patterns.insert(ctx); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + func.emitError("applyPatternsAndFoldGreedily does not converge"); + signalPassFailure(); + } + } +}; + +std::unique_ptr> +createDiscCollectiveOpsRewriterPass() { + return std::make_unique(); +} + +} // namespace disc_ral +} // namespace mlir diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td index 7b97ed610df..c2d3108c9d0 100755 --- a/tao_compiler/mlir/disc/transforms/disc_passes.td +++ b/tao_compiler/mlir/disc/transforms/disc_passes.td @@ -66,6 +66,11 @@ def ConvRewriterPass : Pass<"disc-conv-rewriter", "mlir::func::FuncOp"> { ]; } +def DiscCollectiveOpsRewriterPass : Pass<"disc-collective-ops-rewriter", "mlir::func::FuncOp"> { + let summary = "Rewrite mhlo collective ops to DISC custom library call"; + let constructor = "createDiscCollectiveOpsRewriterPass()"; +} + def MhloDecompositionRewriterPass : Pass<"disc-mhlo-decomp-rewriter", "mlir::func::FuncOp"> { let summary = "Rewrite and decompose mhlo ops."; let constructor = "createDiscMhloDecompositionRewriterPass()"; diff --git a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc index e5f6b46ab21..6c717f611ec 100644 --- a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc +++ b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc @@ -133,98 +133,6 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op, return success(); } } // namespace -namespace { -enum ReductionKind { - ALL_REDUCE_SUM, - ALL_REDUCE_PRODUCT, - ALL_REDUCE_MIN, - ALL_REDUCE_MAX, -}; - -std::optional ReductionKindToString(ReductionKind kind) { - switch (kind) { - case ReductionKind::ALL_REDUCE_SUM: - return "sum"; - case ReductionKind::ALL_REDUCE_PRODUCT: - return "product"; - case ReductionKind::ALL_REDUCE_MIN: - return "min"; - case ReductionKind::ALL_REDUCE_MAX: - return "max"; - } - return std::nullopt; -} - -std::optional MatchReductionComputation(Region& region) { - if (!region.hasOneBlock()) { - return std::nullopt; - } - - auto ret = dyn_cast(region.front().getTerminator()); - if (!ret || ret->getNumOperands() != 1) { - return std::nullopt; - } - - auto computation = ret.getOperand(0).getDefiningOp(); - - if (isa(computation)) { - return "sum"; - } - if (isa(computation)) { - return "product"; - } - if (isa(computation)) { - return "min"; - } - if (isa(computation)) { - return "max"; - } - return std::nullopt; -} - -struct CollectiveOpConverter : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mhlo::AllReduceOp op, - PatternRewriter& rewriter) const override { - SmallVector newOutputs; - auto reductionKind = MatchReductionComputation(op.getRegion()); - if (!reductionKind) { - return failure(); - } - for (int i = 0; i < op->getOperands().size(); ++i) { - // no need call all_reduce op if no consumer - if (op->getResult(i).getUsers().empty()) { - continue; - } - - op->setAttr("call_target_name", rewriter.getStringAttr("ral_all_reduce")); - op->setAttr("device", rewriter.getStringAttr("d")); - 
op->setAttr("input_placements", rewriter.getStringAttr("d")); - op->setAttr("output_placements", rewriter.getStringAttr("d")); - op->setAttr("input_layouts", rewriter.getStringAttr("*")); - op->setAttr("output_layouts", rewriter.getStringAttr("*")); - op->setAttr("expected_input_layouts", rewriter.getStringAttr("*")); - op->setAttr("expected_output_layouts", rewriter.getStringAttr("*")); - SmallVector newAttrs; - newAttrs.push_back( - NamedAttribute(rewriter.getStringAttr("reduction_kind"), - rewriter.getStringAttr(reductionKind.value()))); - - auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs); - - op->setAttr("custom_attrs", newCustomAttrs); - - auto newOutput = rewriter.create( - op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i], - op->getAttrs()); - newOutputs.push_back(newOutput.getResult(0)); - } - rewriter.replaceOp(op, newOutputs); - return success(); - } -}; -} // namespace struct MhloDecompositionRewriterPass : public MhloDecompositionRewriterPassBase { @@ -240,7 +148,6 @@ struct MhloDecompositionRewriterPass patterns.insert(ctx); patterns.insert(ctx); patterns.insert(ctx); - patterns.insert(ctx); if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { func.emitError("applyPatternsAndFoldGreedily does not converge"); signalPassFailure(); diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h index e533d442ac2..95ad522a62e 100644 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -328,7 +328,13 @@ createDiscEraseBufferDeallocationPass(); // Insert ArgsMutationOp for buffer reuse std::unique_ptr> createDiscInputOutputAliasPass(); + +// Modifty buffer allocation inst to reduce buffer live range std::unique_ptr> createDiscReduceBufferLiveRangePass(); + +// rewrite mhlo collective ops to disc custom library call +std::unique_ptr> +createDiscCollectiveOpsRewriterPass(); } // namespace disc_ral } // namespace mlir diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir new file mode 100644 index 00000000000..8021cd6a609 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir @@ -0,0 +1,11 @@ +// RUN: disc-opt -disc-collecitve-ops-rewriter %s | FileCheck %s + +func.func @main(%arg0: tensor<8x3xf32>) -> (tensor<2x3xf32>) attributes {tf.entry_function = {input_output_alias_outputs = "", input_output_alias_params = "", input_placements = "gpu", output_placements = "gpu"}} { + // CHECK: %[[T0:.*]] = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_reduce_scatter", custom_attrs = {reduction_kind = "sum", replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 0 : i64}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d"} : (tensor<8x3xf32>) -> tensor<2x3xf32> + %3 = "mhlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %5 = mhlo.add %arg2, %arg3 : tensor + mhlo.return %5 : tensor + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 0 : i64} : (tensor<8x3xf32>) -> tensor<2x3xf32> + return %3 : tensor<2x3xf32> +} \ No newline at end of file diff --git a/tao_compiler/mlir/ral/collective.cu.cc b/tao_compiler/mlir/ral/collective.cu.cc index 35f92258aa0..2ed2edcf080 100644 --- 
a/tao_compiler/mlir/ral/collective.cu.cc +++ b/tao_compiler/mlir/ral/collective.cu.cc @@ -52,24 +52,59 @@ ncclRedOp_t getNcclReductionType(const std::string& kind) { return it->second; } +template +MemRefType ral_all_reduce_0d(ExecutionContext* ctx, void* stream_handle, + MemRefType input, void* customAttrs) { + auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_all_reduce"); + if (!attr) { + ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n"); + } + auto& dictAttr = attr->as(); + + std::string reductionKind = + dictAttr.get("reduction_kind").template as().getValue(); + ncclDataType_t ncclDtype = ncclDataTypeMapper::value; + auto ncclReductionType = getNcclReductionType(reductionKind); + + auto send_buffer = input.data; + int input_elements = 1; + auto gpu_driver = ctx->getDriver( + tao::ral::gpu::GPUDriver::name()); + auto gpu_stream = + static_cast(gpu_driver->asCUStream(ctx, stream_handle)); + auto nccl_comm = + static_cast(ctx)->getNcclComm(); + auto ptr = + static_cast(gpu_driver->alloc(ctx, input_elements * sizeof(T))); + auto output = assignMemRef_0d(ptr); + auto recv_buffer = output.data; + auto ncclResult = + ncclAllReduce(send_buffer, recv_buffer, input_elements, ncclDtype, + ncclReductionType, nccl_comm, gpu_stream); + if (ncclResult != ncclSuccess) { + ctx->signalError(Context::FAILURE, "fail to call ncclAllReduce\n"); + } + return output; +} + template MemRefType ral_all_reduce(ExecutionContext* ctx, void* stream_handle, MemRefType input, void* customAttrs) { - auto attr = - getOrParsePDLAttr(ctx, customAttrs, "simple_test_fused_add_mul_kernel"); + auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_all_reduce"); if (!attr) { ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n"); } auto& dictAttr = attr->as(); + std::string reductionKind = dictAttr.get("reduction_kind").template as().getValue(); ncclDataType_t ncclDtype = ncclDataTypeMapper::value; auto ncclReductionType = getNcclReductionType(reductionKind); auto send_buffer = input.data; - int element_count = 1; + int input_elements = 1; for (int i = 0; i < N; ++i) { - element_count *= input.sizes[i]; + input_elements *= input.sizes[i]; } auto gpu_driver = ctx->getDriver( tao::ral::gpu::GPUDriver::name()); @@ -77,12 +112,12 @@ MemRefType ral_all_reduce(ExecutionContext* ctx, void* stream_handle, static_cast(gpu_driver->asCUStream(ctx, stream_handle)); auto nccl_comm = static_cast(ctx)->getNcclComm(); - auto ptr = static_cast(gpu_driver->alloc(ctx, element_count * sizeof(T))); + auto ptr = + static_cast(gpu_driver->alloc(ctx, input_elements * sizeof(T))); auto output = assignMemRef(ptr, input.sizes); auto recv_buffer = output.data; - // TODO(yancey): support more nccl operations auto ncclResult = - ncclAllReduce(send_buffer, recv_buffer, element_count, ncclDtype, + ncclAllReduce(send_buffer, recv_buffer, input_elements, ncclDtype, ncclReductionType, nccl_comm, gpu_stream); if (ncclResult != ncclSuccess) { ctx->signalError(Context::FAILURE, "fail to call ncclAllReduce\n"); @@ -90,13 +125,124 @@ MemRefType ral_all_reduce(ExecutionContext* ctx, void* stream_handle, return output; } +template +MemRefType ral_all_gather(ExecutionContext* ctx, void* stream_handle, + MemRefType input, void* customAttrs) { + T* send_buffer = input.data; + auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_all_reduce"); + if (!attr) { + ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n"); + } + auto& dictAttr = attr->as(); + int all_gather_dim = + dictAttr.get("all_gather_dim").template as().getValue(); 
+ auto replic_groups = + dictAttr.get("replica_groups").template as(); + int output_sizes[N]; + for (int i = 0; i < N; ++i) output_sizes[i] = input.sizes[i]; + output_sizes[all_gather_dim] = + input.sizes[all_gather_dim] * replic_groups.getShape()[1]; + + auto gpu_driver = ctx->getDriver( + tao::ral::gpu::GPUDriver::name()); + auto gpu_stream = + static_cast(gpu_driver->asCUStream(ctx, stream_handle)); + auto nccl_comm = + static_cast(ctx)->getNcclComm(); + int input_elements = 1; + for (int i = 0; i < N; ++i) { + input_elements *= input.sizes[i]; + } + int output_elements = input_elements * replic_groups.getShape()[1]; + auto ptr = + static_cast(gpu_driver->alloc(ctx, output_elements * sizeof(T))); + auto output = assignMemRef(ptr, output_sizes); + auto recv_buffer = output.data; + + ncclDataType_t ncclDtype = ncclDataTypeMapper::value; + + if (ncclSuccess != ncclAllGather(send_buffer, recv_buffer, input_elements, + ncclDtype, nccl_comm, gpu_stream)) { + ctx->signalError(Context::FAILURE, "fail to call ncclAllGather\n"); + } + return output; +} + +template +MemRefType ral_reduce_scatter(ExecutionContext* ctx, void* stream_handle, + MemRefType input, void* customAttrs) { + T* send_buffer = input.data; + auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_reduce_scatter"); + if (!attr) { + ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n"); + } + auto& dictAttr = attr->as(); + int scatter_dimension = + dictAttr.get("scatter_dimension").template as().getValue(); + auto replic_groups = + dictAttr.get("replica_groups").template as(); + std::string reductionKind = + dictAttr.get("reduction_kind").template as().getValue(); + auto ncclReductionType = getNcclReductionType(reductionKind); + + int output_sizes[N]; + for (int i = 0; i < N; ++i) output_sizes[i] = input.sizes[i]; + output_sizes[scatter_dimension] = + input.sizes[scatter_dimension] / replic_groups.getShape()[1]; + + auto gpu_driver = ctx->getDriver( + tao::ral::gpu::GPUDriver::name()); + auto gpu_stream = + static_cast(gpu_driver->asCUStream(ctx, stream_handle)); + auto nccl_comm = + static_cast(ctx)->getNcclComm(); + int output_elements = 1; + for (int i = 0; i < N; ++i) { + output_elements *= output_sizes[i]; + } + auto ptr = + static_cast(gpu_driver->alloc(ctx, output_elements * sizeof(T))); + auto output = assignMemRef(ptr, output_sizes); + auto recv_buffer = output.data; + + ncclDataType_t ncclDtype = ncclDataTypeMapper::value; + + if (ncclSuccess != + ncclReduceScatter(send_buffer, recv_buffer, output_elements, ncclDtype, + ncclReductionType, nccl_comm, gpu_stream)) { + ctx->signalError(Context::FAILURE, "fail to call ncclReduceScatter\n"); + } + return output; +} + +TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce_0d); TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce); TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce); TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce); TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce); +TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce_0d); TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce); TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce); TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce); TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce); + +TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather); +TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather); +TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather); +TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather); +TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather); +TAO_RAL_API("ral_all_gather", "gpu", 
ral_all_gather); +TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather); +TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather); + +TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter); +TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter); +TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter); +TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter); +TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter); +TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter); +TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter); +TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter); + } // namespace ral } // namespace tao From fdfb385b881a5b18c5a5efb804745a5d758495ad Mon Sep 17 00:00:00 2001 From: yancey Date: Tue, 19 Mar 2024 14:20:18 +0800 Subject: [PATCH 2/5] update --- tao_compiler/mlir/disc/BUILD | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index bccadfd8565..8fdf954aae7 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -492,6 +492,27 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "disc_collective_ops_rewriter", + srcs = ["transforms/disc_collective_ops_rewriter.cc"], + hdrs = ["transforms/passes.h"], + includes = [ + "tensorflow/compiler/xla/mlir_hlo/include", + "." + ], + deps = [ + ":disc_util", + ":pass_details", + ":mhlo_disc", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "mhlo_decomp_rewriters", srcs = ["transforms/mhlo_decomp_rewriters.cc"], @@ -2407,6 +2428,7 @@ cc_library( ":input_inline_fusion", ":lhlo_fusion_inliner", ":mhlo_decomp_rewriters", + ":disc_collective_ops_rewriter", ":mhlo_mark_shape_calc", ":mhlo_placer", ":ral_inject_execution_context", From 365719719d63110414fa6e571767282646b7accb Mon Sep 17 00:00:00 2001 From: YanXu Date: Wed, 20 Mar 2024 21:12:47 +0800 Subject: [PATCH 3/5] fix ut --- .../disc/transforms/tests/disc-collective-ops-rewriter.mlir | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir index 8021cd6a609..e173206056e 100644 --- a/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir +++ b/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir @@ -1,7 +1,8 @@ // RUN: disc-opt -disc-collecitve-ops-rewriter %s | FileCheck %s -func.func @main(%arg0: tensor<8x3xf32>) -> (tensor<2x3xf32>) attributes {tf.entry_function = {input_output_alias_outputs = "", input_output_alias_params = "", input_placements = "gpu", output_placements = "gpu"}} { - // CHECK: %[[T0:.*]] = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_reduce_scatter", custom_attrs = {reduction_kind = "sum", replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 0 : i64}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d"} : (tensor<8x3xf32>) -> tensor<2x3xf32> +// CHECK-LABEL: @reduce_scatter +func.func @reduce_scatter(%arg0: tensor<8x3xf32>) -> (tensor<2x3xf32>) attributes {tf.entry_function = {input_output_alias_outputs = "", input_output_alias_params = "", 
input_placements = "gpu", output_placements = "gpu"}} { + // CHECK: "mhlo_disc.custom_call_v2"(%arg0) %3 = "mhlo.reduce_scatter"(%arg0) ({ ^bb0(%arg2: tensor, %arg3: tensor): %5 = mhlo.add %arg2, %arg3 : tensor From a4f8c775cc7302fa0012b8710d3fcfbb8d34ded3 Mon Sep 17 00:00:00 2001 From: YanXu Date: Thu, 21 Mar 2024 00:24:43 +0800 Subject: [PATCH 4/5] fix ut --- .../disc/transforms/tests/disc-collective-ops-rewriter.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir index e173206056e..bf669f24491 100644 --- a/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir +++ b/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir @@ -1,4 +1,4 @@ -// RUN: disc-opt -disc-collecitve-ops-rewriter %s | FileCheck %s +// RUN: disc-opt --disc-collective-ops-rewriter %s | FileCheck %s // CHECK-LABEL: @reduce_scatter func.func @reduce_scatter(%arg0: tensor<8x3xf32>) -> (tensor<2x3xf32>) attributes {tf.entry_function = {input_output_alias_outputs = "", input_output_alias_params = "", input_placements = "gpu", output_placements = "gpu"}} { From 722bded36eef146af43dd55e78217c7a522c9bd0 Mon Sep 17 00:00:00 2001 From: YanXu Date: Thu, 21 Mar 2024 12:06:29 +0800 Subject: [PATCH 5/5] update --- .../tests/disc-collective-ops-rewriter.mlir | 15 +++++++++++++-- .../transforms/tests/mhlo_decomp_rewriter.mlir | 11 ----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir index bf669f24491..988e664e22f 100644 --- a/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir +++ b/tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir @@ -2,11 +2,22 @@ // CHECK-LABEL: @reduce_scatter func.func @reduce_scatter(%arg0: tensor<8x3xf32>) -> (tensor<2x3xf32>) attributes {tf.entry_function = {input_output_alias_outputs = "", input_output_alias_params = "", input_placements = "gpu", output_placements = "gpu"}} { - // CHECK: "mhlo_disc.custom_call_v2"(%arg0) + // CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg0) %3 = "mhlo.reduce_scatter"(%arg0) ({ ^bb0(%arg2: tensor, %arg3: tensor): %5 = mhlo.add %arg2, %arg3 : tensor mhlo.return %5 : tensor }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 0 : i64} : (tensor<8x3xf32>) -> tensor<2x3xf32> return %3 : tensor<2x3xf32> -} \ No newline at end of file +} + +func.func @all_reduce(%arg0: tensor, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor) { + // CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg1) {call_target_name = "ral_all_reduce", custom_attrs = {reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %1 = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_all_reduce", custom_attrs = {reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor) -> tensor + %0:2 = "mhlo.all_reduce"(%arg1, %arg0) 
({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + mhlo.return %1 : tensor + }) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>, tensor) -> (tensor<4xf32>, tensor) + return %0#0, %0#1 : tensor<4xf32>, tensor +} diff --git a/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir index 10d3c5b9df6..5768a2f224b 100644 --- a/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir +++ b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir @@ -35,14 +35,3 @@ func.func @batch_norm_inference(%arg0: tensor, %arg1: tensor<128x %0 = "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor return %0: tensor } - -func.func @main(%arg0: tensor, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor) { - // CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg1) {call_target_name = "ral_all_reduce", custom_attrs = {reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %1 = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_all_reduce", custom_attrs = {reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor) -> tensor - %0:2 = "mhlo.all_reduce"(%arg1, %arg0) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.add %arg2, %arg3 : tensor - mhlo.return %1 : tensor - }) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>, tensor) -> (tensor<4xf32>, tensor) - return %0#0, %0#1 : tensor<4xf32>, tensor -}
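
Reviewer note, not part of the patch series: the new `--disc-collective-ops-rewriter` tests above exercise `mhlo.reduce_scatter` and `mhlo.all_reduce`, but `mhlo.all_gather`, which the same pass rewrites via `AllGatherOpConverter`, has no FileCheck case. Below is a minimal sketch of what such a test might look like, assuming the same `mhlo_disc.custom_call_v2` attribute conventions set by the converter (the function name, tensor shapes, and `replica_groups` values here are illustrative, not taken from the patch):

// RUN: disc-opt --disc-collective-ops-rewriter %s | FileCheck %s

// CHECK-LABEL: @all_gather
func.func @all_gather(%arg0: tensor<2x3xf32>) -> (tensor<8x3xf32>) {
  // CHECK: "mhlo_disc.custom_call_v2"(%arg0)
  // CHECK-SAME: call_target_name = "ral_all_gather"
  // CHECK-SAME: custom_attrs = {all_gather_dim = 0 : i64, replica_groups = dense<{{\[\[}}0, 1, 2, 3{{\]\]}}> : tensor<1x4xi64>}
  %0 = "mhlo.all_gather"(%arg0) {all_gather_dim = 0 : i64, replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>} : (tensor<2x3xf32>) -> tensor<8x3xf32>
  return %0 : tensor<8x3xf32>
}

With a group size of 4 and `all_gather_dim = 0`, the result shape grows from 2x3 to 8x3, matching the output-size computation in `ral_all_gather` (input size along the gather dimension multiplied by the second dimension of `replica_groups`).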