From f8457519c79bca6a074e79892309451df2ccc093 Mon Sep 17 00:00:00 2001
From: yancey
Date: Fri, 15 Mar 2024 15:27:02 +0800
Subject: [PATCH] add collective ops rewriter pass

---
 tao_compiler/mlir/disc/BUILD                  |  22 ++
 tao_compiler/mlir/disc/disc_compiler.cc       |   1 +
 .../disc_collective_ops_rewriter.cc           | 259 ++++++++++++++++++
 .../mlir/disc/transforms/disc_passes.td       |   5 +
 .../disc/transforms/mhlo_decomp_rewriters.cc  |  93 ------
 tao_compiler/mlir/disc/transforms/passes.h    |   6 +
 tao_compiler/mlir/ral/collective.cu.cc        | 123 ++++++++-
 7 files changed, 409 insertions(+), 100 deletions(-)
 create mode 100644 tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc

diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD
index bccadfd8565..8fdf954aae7 100755
--- a/tao_compiler/mlir/disc/BUILD
+++ b/tao_compiler/mlir/disc/BUILD
@@ -492,6 +492,27 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "disc_collective_ops_rewriter",
+    srcs = ["transforms/disc_collective_ops_rewriter.cc"],
+    hdrs = ["transforms/passes.h"],
+    includes = [
+        "tensorflow/compiler/xla/mlir_hlo/include",
+        "."
+    ],
+    deps = [
+        ":disc_util",
+        ":pass_details",
+        ":mhlo_disc",
+        "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "mhlo_decomp_rewriters",
     srcs = ["transforms/mhlo_decomp_rewriters.cc"],
@@ -2407,6 +2428,7 @@ cc_library(
         ":input_inline_fusion",
         ":lhlo_fusion_inliner",
         ":mhlo_decomp_rewriters",
+        ":disc_collective_ops_rewriter",
        ":mhlo_mark_shape_calc",
        ":mhlo_placer",
        ":ral_inject_execution_context",
diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc
index 9e681c8b738..cc2b535adfc 100644
--- a/tao_compiler/mlir/disc/disc_compiler.cc
+++ b/tao_compiler/mlir/disc/disc_compiler.cc
@@ -245,6 +245,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
   pm.addPass(disc_ral::createDiscInputOutputAliasPass());
   pm.addPass(mlir::createInlinerPass());
   // TODO(disc): Lower HLO shape constraints instead of eliding them here.
+  pm.addNestedPass<func::FuncOp>(disc_ral::createDiscCollectiveOpsRewriterPass());
   pm.addNestedPass<func::FuncOp>(disc_ral::createDiscMhloDecompositionRewriterPass());
   pm.addNestedPass<func::FuncOp>(disc_ral::createDiscRemoveShapeConstraintsPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
diff --git a/tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc b/tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc
new file mode 100644
index 00000000000..1577852c6e7
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc
@@ -0,0 +1,259 @@
+// Copyright 2021 The BladeDISC Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/disc/IR/hlo_disc_ops.h"
+#include "mlir/disc/disc_util.h"
+#include "mlir/disc/transforms/PassDetail.h"
+
+namespace mlir {
+namespace disc_ral {
+
+namespace {
+enum ReductionKind {
+  ALL_REDUCE_SUM,
+  ALL_REDUCE_PRODUCT,
+  ALL_REDUCE_MIN,
+  ALL_REDUCE_MAX,
+};
+
+std::optional<std::string> ReductionKindToString(ReductionKind kind) {
+  switch (kind) {
+    case ReductionKind::ALL_REDUCE_SUM:
+      return "sum";
+    case ReductionKind::ALL_REDUCE_PRODUCT:
+      return "product";
+    case ReductionKind::ALL_REDUCE_MIN:
+      return "min";
+    case ReductionKind::ALL_REDUCE_MAX:
+      return "max";
+  }
+  return std::nullopt;
+}
+
+std::optional<std::string> MatchReductionComputation(Region& region) {
+  if (!region.hasOneBlock()) {
+    return std::nullopt;
+  }
+
+  auto ret = dyn_cast<mhlo::ReturnOp>(region.front().getTerminator());
+  if (!ret || ret->getNumOperands() != 1) {
+    return std::nullopt;
+  }
+
+  auto computation = ret.getOperand(0).getDefiningOp();
+
+  if (isa<mhlo::AddOp>(computation)) {
+    return "sum";
+  }
+  if (isa<mhlo::MulOp>(computation)) {
+    return "product";
+  }
+  if (isa<mhlo::MinOp>(computation)) {
+    return "min";
+  }
+  if (isa<mhlo::MaxOp>(computation)) {
+    return "max";
+  }
+  return std::nullopt;
+}
+
+struct AllReduceOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
+  using OpRewritePattern<mhlo::AllReduceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::AllReduceOp op,
+                                PatternRewriter& rewriter) const override {
+    SmallVector<Value> newOutputs;
+    auto reductionKind = MatchReductionComputation(op.getRegion());
+    if (!reductionKind) {
+      return failure();
+    }
+    for (int i = 0; i < op->getOperands().size(); ++i) {
+      // No need to emit an all_reduce for a result without consumers.
+      if (op->getResult(i).getUsers().empty()) {
+        continue;
+      }
+
+      op->setAttr("call_target_name", rewriter.getStringAttr("ral_all_reduce"));
+      op->setAttr("device", rewriter.getStringAttr("d"));
+      op->setAttr("input_placements", rewriter.getStringAttr("d"));
+      op->setAttr("output_placements", rewriter.getStringAttr("d"));
+      op->setAttr("input_layouts", rewriter.getStringAttr("*"));
+      op->setAttr("output_layouts", rewriter.getStringAttr("*"));
+      op->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
+      op->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));
+      SmallVector<NamedAttribute> newAttrs;
+      newAttrs.push_back(
+          NamedAttribute(rewriter.getStringAttr("reduction_kind"),
+                         rewriter.getStringAttr(reductionKind.value())));
+
+      auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs);
+
+      op->setAttr("custom_attrs", newCustomAttrs);
+
+      auto newOutput = rewriter.create<mhlo_disc::CustomCallV2Op>(
+          op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i],
+          op->getAttrs());
+      newOutputs.push_back(newOutput.getResult(0));
+    }
+    rewriter.replaceOp(op, newOutputs);
+    return success();
+  }
+};
+
+struct AllGatherOpConverter : public OpRewritePattern<mhlo::AllGatherOp> {
+  using OpRewritePattern<mhlo::AllGatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::AllGatherOp op,
+                                PatternRewriter& rewriter) const override {
+    SmallVector<NamedAttribute> newAttrsVec;
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("call_target_name"),
+                       rewriter.getStringAttr("ral_all_gather")));
+    newAttrsVec.push_back(NamedAttribute(rewriter.getStringAttr("device"),
+                                         rewriter.getStringAttr("d")));
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("input_placements"),
+                       rewriter.getStringAttr("d")));
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("output_placements"),
+                       rewriter.getStringAttr("d")));
+    newAttrsVec.push_back(NamedAttribute(
+        rewriter.getStringAttr("input_layouts"), rewriter.getStringAttr("*")));
+    newAttrsVec.push_back(NamedAttribute(
+        rewriter.getStringAttr("output_layouts"), rewriter.getStringAttr("*")));
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("expected_input_layouts"),
+                       rewriter.getStringAttr("*")));
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("expected_output_layouts"),
+                       rewriter.getStringAttr("*")));
+
+    SmallVector<NamedAttribute> customAttrs;
+    customAttrs.push_back(
+        NamedAttribute(rewriter.getStringAttr("all_gather_dim"),
+                       op->getAttr("all_gather_dim")));
+    customAttrs.push_back(
+        NamedAttribute(rewriter.getStringAttr("replica_groups"),
+                       op->getAttr("replica_groups")));
+
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("custom_attrs"),
+                       rewriter.getDictionaryAttr(customAttrs)));
+
+    ArrayRef<NamedAttribute> newCustomAttrs(newAttrsVec);
+    auto newOp = rewriter.create<mhlo_disc::CustomCallV2Op>(
+        op->getLoc(), op->getResultTypes(), op->getOperands(), newCustomAttrs);
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+
+struct ReduceScatterOpConverter
+    : public OpRewritePattern<mhlo::ReduceScatterOp> {
+  using OpRewritePattern<mhlo::ReduceScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::ReduceScatterOp op,
+                                PatternRewriter& rewriter) const override {
+    auto reductionKind = MatchReductionComputation(op.getRegion());
+    if (!reductionKind) {
+      return failure();
+    }
+    SmallVector<NamedAttribute> newAttrsVec;
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("call_target_name"),
+                       rewriter.getStringAttr("ral_reduce_scatter")));
+    newAttrsVec.push_back(NamedAttribute(rewriter.getStringAttr("device"),
+                                         rewriter.getStringAttr("d")));
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("input_placements"),
+                       rewriter.getStringAttr("d")));
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("output_placements"),
+                       rewriter.getStringAttr("d")));
+    newAttrsVec.push_back(NamedAttribute(
+        rewriter.getStringAttr("input_layouts"), rewriter.getStringAttr("*")));
+    newAttrsVec.push_back(NamedAttribute(
+        rewriter.getStringAttr("output_layouts"), rewriter.getStringAttr("*")));
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("expected_input_layouts"),
+                       rewriter.getStringAttr("*")));
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("expected_output_layouts"),
+                       rewriter.getStringAttr("*")));
+
+    SmallVector<NamedAttribute> customAttrs;
+    customAttrs.push_back(
+        NamedAttribute(rewriter.getStringAttr("reduction_kind"),
+                       rewriter.getStringAttr(reductionKind.value())));
+    customAttrs.push_back(
+        NamedAttribute(rewriter.getStringAttr("scatter_dimension"),
+                       op->getAttr("scatter_dimension")));
+    customAttrs.push_back(
+        NamedAttribute(rewriter.getStringAttr("replica_groups"),
+                       op->getAttr("replica_groups")));
+
+    newAttrsVec.push_back(
+        NamedAttribute(rewriter.getStringAttr("custom_attrs"),
+                       rewriter.getDictionaryAttr(customAttrs)));
+
+    ArrayRef<NamedAttribute> newCustomAttrs(newAttrsVec);
+    auto newOp = rewriter.create<mhlo_disc::CustomCallV2Op>(
+        op->getLoc(), op->getResultTypes(), op->getOperands(), newCustomAttrs);
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+struct DiscCollectiveOpsRewriterPass
+    : public DiscCollectiveOpsRewriterPassBase<DiscCollectiveOpsRewriterPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<mhlo_disc::MhloDiscDialect>();
+  }
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    MLIRContext* ctx = &getContext();
+
+    RewritePatternSet patterns(ctx);
+    patterns.insert<AllReduceOpConverter>(ctx);
+    patterns.insert<AllGatherOpConverter>(ctx);
+    patterns.insert<ReduceScatterOpConverter>(ctx);
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+      func.emitError("applyPatternsAndFoldGreedily does not converge");
+      signalPassFailure();
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDiscCollectiveOpsRewriterPass() {
+  return std::make_unique<DiscCollectiveOpsRewriterPass>();
+}
+
+} // namespace disc_ral
+} // namespace mlir
diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td
index 7b97ed610df..c2d3108c9d0 100755
--- a/tao_compiler/mlir/disc/transforms/disc_passes.td
+++ b/tao_compiler/mlir/disc/transforms/disc_passes.td
@@ -66,6 +66,11 @@ def ConvRewriterPass : Pass<"disc-conv-rewriter", "mlir::func::FuncOp"> {
   ];
 }
 
+def DiscCollectiveOpsRewriterPass : Pass<"disc-collective-ops-rewriter", "mlir::func::FuncOp"> {
+  let summary = "Rewrite mhlo collective ops to DISC custom library calls";
+  let constructor = "createDiscCollectiveOpsRewriterPass()";
+}
+
 def MhloDecompositionRewriterPass : Pass<"disc-mhlo-decomp-rewriter", "mlir::func::FuncOp"> {
   let summary = "Rewrite and decompose mhlo ops.";
   let constructor = "createDiscMhloDecompositionRewriterPass()";
diff --git a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
index e5f6b46ab21..6c717f611ec 100644
--- a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
+++ b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -133,98 +133,6 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
   return success();
 }
 } // namespace
-namespace {
-enum ReductionKind {
-  ALL_REDUCE_SUM,
-  ALL_REDUCE_PRODUCT,
-  ALL_REDUCE_MIN,
-  ALL_REDUCE_MAX,
-};
-
-std::optional<std::string> ReductionKindToString(ReductionKind kind) {
-  switch (kind) {
-    case ReductionKind::ALL_REDUCE_SUM:
-      return "sum";
-    case ReductionKind::ALL_REDUCE_PRODUCT:
-      return "product";
-    case ReductionKind::ALL_REDUCE_MIN:
-      return "min";
-    case ReductionKind::ALL_REDUCE_MAX:
-      return "max";
-  }
-  return std::nullopt;
-}
-
-std::optional<std::string> MatchReductionComputation(Region& region) {
-  if (!region.hasOneBlock()) {
-    return std::nullopt;
-  }
-
-  auto ret = dyn_cast<mhlo::ReturnOp>(region.front().getTerminator());
-  if (!ret || ret->getNumOperands() != 1) {
-    return std::nullopt;
-  }
-
-  auto computation = ret.getOperand(0).getDefiningOp();
-
-  if (isa<mhlo::AddOp>(computation)) {
-    return "sum";
-  }
-  if (isa<mhlo::MulOp>(computation)) {
-    return "product";
-  }
-  if (isa<mhlo::MinOp>(computation)) {
-    return "min";
-  }
-  if (isa<mhlo::MaxOp>(computation)) {
-    return "max";
-  }
-  return std::nullopt;
-}
-
-struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
-  using OpRewritePattern<mhlo::AllReduceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::AllReduceOp op,
-                                PatternRewriter& rewriter) const override {
-    SmallVector<Value> newOutputs;
-    auto reductionKind = MatchReductionComputation(op.getRegion());
-    if (!reductionKind) {
-      return failure();
-    }
-    for (int i = 0; i < op->getOperands().size(); ++i) {
-      // no need call all_reduce op if no consumer
-      if (op->getResult(i).getUsers().empty()) {
-        continue;
-      }
-
-      op->setAttr("call_target_name", rewriter.getStringAttr("ral_all_reduce"));
-      op->setAttr("device", rewriter.getStringAttr("d"));
-      op->setAttr("input_placements", rewriter.getStringAttr("d"));
-      op->setAttr("output_placements", rewriter.getStringAttr("d"));
-      op->setAttr("input_layouts", rewriter.getStringAttr("*"));
-      op->setAttr("output_layouts", rewriter.getStringAttr("*"));
-      op->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
-      op->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));
-      SmallVector<NamedAttribute> newAttrs;
-      newAttrs.push_back(
-          NamedAttribute(rewriter.getStringAttr("reduction_kind"),
-                         rewriter.getStringAttr(reductionKind.value())));
-
-      auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs);
-
-      op->setAttr("custom_attrs", newCustomAttrs);
-
-      auto newOutput = rewriter.create<mhlo_disc::CustomCallV2Op>(
-          op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i],
-          op->getAttrs());
-      newOutputs.push_back(newOutput.getResult(0));
-    }
-    rewriter.replaceOp(op, newOutputs);
-    return success();
-  }
-};
-} // namespace
 
 struct MhloDecompositionRewriterPass
     : public MhloDecompositionRewriterPassBase<MhloDecompositionRewriterPass> {
@@ -240,7 +148,6 @@ struct MhloDecompositionRewriterPass
     patterns.insert(ctx);
     patterns.insert(ctx);
     patterns.insert(ctx);
-    patterns.insert<CollectiveOpConverter>(ctx);
     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
       func.emitError("applyPatternsAndFoldGreedily does not converge");
       signalPassFailure();
diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h
index e533d442ac2..95ad522a62e 100644
--- a/tao_compiler/mlir/disc/transforms/passes.h
+++ b/tao_compiler/mlir/disc/transforms/passes.h
@@ -328,7 +328,13 @@ createDiscEraseBufferDeallocationPass();
 // Insert ArgsMutationOp for buffer reuse
 std::unique_ptr<OperationPass<ModuleOp>>
 createDiscInputOutputAliasPass();
+
+// Modify buffer allocation instructions to reduce buffer live ranges
 std::unique_ptr<OperationPass<ModuleOp>>
 createDiscReduceBufferLiveRangePass();
+
+// Rewrite mhlo collective ops to DISC custom library calls
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDiscCollectiveOpsRewriterPass();
 } // namespace disc_ral
 } // namespace mlir
diff --git a/tao_compiler/mlir/ral/collective.cu.cc b/tao_compiler/mlir/ral/collective.cu.cc
index 35f92258aa0..a9ac445b47c 100644
--- a/tao_compiler/mlir/ral/collective.cu.cc
+++ b/tao_compiler/mlir/ral/collective.cu.cc
@@ -55,21 +55,21 @@ ncclRedOp_t getNcclReductionType(const std::string& kind) {
 template <typename T, int N>
 MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
                                 MemRefType<T, N> input, void* customAttrs) {
-  auto attr =
-      getOrParsePDLAttr(ctx, customAttrs, "simple_test_fused_add_mul_kernel");
+  auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_all_reduce");
   if (!attr) {
     ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n");
   }
   auto& dictAttr = attr->as<DictPDLAttr>();
+
   std::string reductionKind =
       dictAttr.get("reduction_kind").template as<StrPDLAttr>().getValue();
   ncclDataType_t ncclDtype = ncclDataTypeMapper<T>::value;
   auto ncclReductionType = getNcclReductionType(reductionKind);
   auto send_buffer = input.data;
-  int element_count = 1;
+  int input_elements = 1;
   for (int i = 0; i < N; ++i) {
-    element_count *= input.sizes[i];
+    input_elements *= input.sizes[i];
   }
   auto gpu_driver = ctx->getDriver<tao::ral::gpu::GPUDriver>(
       tao::ral::gpu::GPUDriver::name());
@@ -77,12 +77,12 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
       static_cast<cudaStream_t>(gpu_driver->asCUStream(ctx, stream_handle));
   auto nccl_comm =
       static_cast<gpu::BaseCudaExecutionContext*>(ctx)->getNcclComm();
-  auto ptr = static_cast<T*>(gpu_driver->alloc(ctx, element_count * sizeof(T)));
+  auto ptr =
+      static_cast<T*>(gpu_driver->alloc(ctx, input_elements * sizeof(T)));
   auto output = assignMemRef<T, N>(ptr, input.sizes);
   auto recv_buffer = output.data;
-  // TODO(yancey): support more nccl operations
   auto ncclResult =
-      ncclAllReduce(send_buffer, recv_buffer, element_count, ncclDtype,
+      ncclAllReduce(send_buffer, recv_buffer, input_elements, ncclDtype,
                     ncclReductionType, nccl_comm, gpu_stream);
   if (ncclResult != ncclSuccess) {
     ctx->signalError(Context::FAILURE, "fail to call ncclAllReduce\n");
@@ -90,6 +90,96 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   return output;
 }
 
+template <typename T, int N>
+MemRefType<T, N> ral_all_gather(ExecutionContext* ctx, void* stream_handle,
+                                MemRefType<T, N> input, void* customAttrs) {
+  T* send_buffer = input.data;
+  auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_all_gather");
+  if (!attr) {
+    ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n");
+  }
+  auto& dictAttr = attr->as<DictPDLAttr>();
+  int all_gather_dim =
+      dictAttr.get("all_gather_dim").template as<IntPDLAttr>().getValue();
+  auto replica_groups =
+      dictAttr.get("replica_groups").template as<DenseElementsPDLAttr>();
+  int output_sizes[N];
+  for (int i = 0; i < N; ++i) output_sizes[i] = input.sizes[i];
+  output_sizes[all_gather_dim] =
+      input.sizes[all_gather_dim] * replica_groups.getShape()[1];
+
+  auto gpu_driver = ctx->getDriver<tao::ral::gpu::GPUDriver>(
+      tao::ral::gpu::GPUDriver::name());
+  auto gpu_stream =
+      static_cast<cudaStream_t>(gpu_driver->asCUStream(ctx, stream_handle));
+  auto nccl_comm =
+      static_cast<gpu::BaseCudaExecutionContext*>(ctx)->getNcclComm();
+  int input_elements = 1;
+  for (int i = 0; i < N; ++i) {
+    input_elements *= input.sizes[i];
+  }
+  int output_elements = input_elements * replica_groups.getShape()[1];
+  auto ptr =
+      static_cast<T*>(gpu_driver->alloc(ctx, output_elements * sizeof(T)));
+  auto output = assignMemRef<T, N>(ptr, output_sizes);
+  auto recv_buffer = output.data;
+
+  ncclDataType_t ncclDtype = ncclDataTypeMapper<T>::value;
+
+  if (ncclSuccess != ncclAllGather(send_buffer, recv_buffer, input_elements,
+                                   ncclDtype, nccl_comm, gpu_stream)) {
+    ctx->signalError(Context::FAILURE, "fail to call ncclAllGather\n");
+  }
+  return output;
+}
+
+template <typename T, int N>
+MemRefType<T, N> ral_reduce_scatter(ExecutionContext* ctx, void* stream_handle,
+                                    MemRefType<T, N> input, void* customAttrs) {
+  T* send_buffer = input.data;
+  auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_reduce_scatter");
+  if (!attr) {
+    ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n");
+  }
+  auto& dictAttr = attr->as<DictPDLAttr>();
+  int scatter_dimension =
+      dictAttr.get("scatter_dimension").template as<IntPDLAttr>().getValue();
+  auto replica_groups =
+      dictAttr.get("replica_groups").template as<DenseElementsPDLAttr>();
+  std::string reductionKind =
+      dictAttr.get("reduction_kind").template as<StrPDLAttr>().getValue();
+  auto ncclReductionType = getNcclReductionType(reductionKind);
+
+  int output_sizes[N];
+  for (int i = 0; i < N; ++i) output_sizes[i] = input.sizes[i];
+  output_sizes[scatter_dimension] =
+      input.sizes[scatter_dimension] / replica_groups.getShape()[1];
+
+  auto gpu_driver = ctx->getDriver<tao::ral::gpu::GPUDriver>(
+      tao::ral::gpu::GPUDriver::name());
+  auto gpu_stream =
+      static_cast<cudaStream_t>(gpu_driver->asCUStream(ctx, stream_handle));
+  auto nccl_comm =
+      static_cast<gpu::BaseCudaExecutionContext*>(ctx)->getNcclComm();
+  int output_elements = 1;
+  for (int i = 0; i < N; ++i) {
+    output_elements *= output_sizes[i];
+  }
+  auto ptr =
+      static_cast<T*>(gpu_driver->alloc(ctx, output_elements * sizeof(T)));
+  auto output = assignMemRef<T, N>(ptr, output_sizes);
+  auto recv_buffer = output.data;
+
+  ncclDataType_t ncclDtype = ncclDataTypeMapper<T>::value;
+
+  if (ncclSuccess !=
+      ncclReduceScatter(send_buffer, recv_buffer, output_elements, ncclDtype,
+                        ncclReductionType, nccl_comm, gpu_stream)) {
+    ctx->signalError(Context::FAILURE, "fail to call ncclReduceScatter\n");
+  }
+  return output;
+}
+
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
@@ -98,5 +188,24 @@ TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
+
+TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather);
+TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather);
+TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather);
+TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather);
+TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather);
+TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather);
+TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather);
+TAO_RAL_API("ral_all_gather", "gpu", ral_all_gather);
+
+TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter);
+TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter);
+TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter);
+TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter);
+TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter);
+TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter);
+TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter);
+TAO_RAL_API("ral_reduce_scatter", "gpu", ral_reduce_scatter);
+
 } // namespace ral
 } // namespace tao
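
Reviewer note: a schematic before/after of what disc-collective-ops-rewriter does to an
all-reduce, for readers who don't want to trace the converters above. The IR below is
hand-written for illustration, not printed from the pass: the tensor shapes, the
replica_groups value, and the exact textual form of mhlo_disc.custom_call_v2 are
assumptions; the attribute names match the C++ in this patch.

  // Before: mhlo.all_reduce whose region reduces with mhlo.add;
  // MatchReductionComputation maps the region to reduction_kind = "sum".
  %0 = "mhlo.all_reduce"(%arg0) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = mhlo.add %lhs, %rhs : tensor<f32>
    mhlo.return %sum : tensor<f32>
  }) {replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>}
      : (tensor<128xf32>) -> tensor<128xf32>

  // After: the region is gone and the reduction kind travels in custom_attrs,
  // so the RAL runtime can dispatch straight to ncclAllReduce.
  %0 = "mhlo_disc.custom_call_v2"(%arg0) {
    call_target_name = "ral_all_reduce",
    custom_attrs = {reduction_kind = "sum"},
    device = "d", input_placements = "d", output_placements = "d",
    input_layouts = "*", output_layouts = "*",
    expected_input_layouts = "*", expected_output_layouts = "*"
  } : (tensor<128xf32>) -> tensor<128xf32>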