diff --git a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
index e5f6b46ab21..f101c9a3d55 100644
--- a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
+++ b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -134,6 +134,21 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
   }
 }  // namespace
 namespace {
+
+bool IsAsyncCollective(Operation* op) {
+  if (llvm::isa<mhlo::AllReduceOp>(op)) {
+    if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_REDUCE")) {
+      return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
+    }
+  } else if (llvm::isa<mhlo::AllGatherOp>(op)) {
+    if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_GATHER")) {
+      return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
+    }
+  }
+
+  return false;
+}
+
 enum ReductionKind {
   ALL_REDUCE_SUM,
   ALL_REDUCE_PRODUCT,
@@ -192,6 +207,9 @@ struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
     if (!reductionKind) {
      return failure();
     }
+
+    bool is_async = IsAsyncCollective(op.getOperation());
+
     for (int i = 0; i < op->getOperands().size(); ++i) {
       // no need call all_reduce op if no consumer
       if (op->getResult(i).getUsers().empty()) {
@@ -206,19 +224,48 @@ struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
       op->setAttr("output_layouts", rewriter.getStringAttr("*"));
       op->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
       op->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));
-      SmallVector<NamedAttribute> newAttrs;
-      newAttrs.push_back(
+
+      SmallVector<NamedAttribute> attrs;
+      attrs.push_back(
           NamedAttribute(rewriter.getStringAttr("reduction_kind"),
                          rewriter.getStringAttr(reductionKind.value())));
+      attrs.push_back(NamedAttribute(rewriter.getStringAttr("is_async"),
+                                     rewriter.getBoolAttr(is_async)));
+      auto customAttrs = DictionaryAttr::get(op->getContext(), attrs);
+      op->setAttr("custom_attrs", customAttrs);
 
-      auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs);
-
-      op->setAttr("custom_attrs", newCustomAttrs);
-
-      auto newOutput = rewriter.create<mhlo_disc::CustomCallV2Op>(
+      auto reduce_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
           op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i],
           op->getAttrs());
-      newOutputs.push_back(newOutput.getResult(0));
+
+      if (is_async) {
+        int64_t async_pair_token =
+            reinterpret_cast<int64_t>(reduce_op.getOperation());
+        attrs.push_back(
+            NamedAttribute(rewriter.getStringAttr("async_token_key"),
+                           rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                                   async_pair_token)));
+        auto newCustomAttrs =
+            DictionaryAttr::get(reduce_op->getContext(), attrs);
+        reduce_op->setAttr("custom_attrs", newCustomAttrs);
+      }
+
+      if (is_async) {
+        // Insert CollectiveDoneOp
+        auto collective_done_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
+            reduce_op->getLoc(), reduce_op->getResults()[0].getType(),
+            reduce_op->getResults()[0], reduce_op->getAttrs());
+        collective_done_op->setAttr(
+            "call_target_name",
+            rewriter.getStringAttr("ral_async_collective_done"));
+
+        // Place collective_done_op right before the first consumer.
+        collective_done_op->moveBefore(*(op->getResult(i).user_begin()));
+
+        newOutputs.push_back(collective_done_op.getResult(0));
+      } else {
+        newOutputs.push_back(reduce_op.getResult(0));
+      }
     }
     rewriter.replaceOp(op, newOutputs);
     return success();
diff --git a/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter_with_async_collective_op.mlir b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter_with_async_collective_op.mlir
new file mode 100755
index 00000000000..55bdad058ba
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter_with_async_collective_op.mlir
@@ -0,0 +1,15 @@
+// RUN: ENABLE_ASYNC_ALL_REDUCE=true disc-opt -disc-mhlo-decomp-rewriter -split-input-file %s -o - | FileCheck %s
+
+func.func @main(%arg0: tensor<f32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<f32>) {
+  // CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg1) {call_target_name = "ral_all_reduce", custom_attrs = {async_token_key = 94869518500912 : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %2 = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_all_reduce", custom_attrs = {async_token_key = 94869518501936 : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
+  // CHECK: %1 = "mhlo_disc.custom_call_v2"(%0) {call_target_name = "ral_async_collective_done", custom_attrs = {async_token_key = 94869518500912 : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %3 = "mhlo_disc.custom_call_v2"(%2) {call_target_name = "ral_async_collective_done", custom_attrs = {async_token_key = 94869518501936 : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
+  %0:2 = "mhlo.all_reduce"(%arg1, %arg0) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    mhlo.return %1 : tensor<f32>
+  }) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>, tensor<f32>) -> (tensor<4xf32>, tensor<f32>)
+  // CHECK: return %1, %3 : tensor<4xf32>, tensor<f32>
+  return %0#0, %0#1 : tensor<4xf32>, tensor<f32>
+}
diff --git a/tao_compiler/mlir/ral/collective.cu.cc b/tao_compiler/mlir/ral/collective.cu.cc
index 35f92258aa0..431253d7301 100644
--- a/tao_compiler/mlir/ral/collective.cu.cc
+++ b/tao_compiler/mlir/ral/collective.cu.cc
@@ -63,6 +63,9 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   auto& dictAttr = attr->as<DictPDLAttr>();
   std::string reductionKind =
       dictAttr.get("reduction_kind").template as<StrPDLAttr>().getValue();
+
+  bool isAsync = dictAttr.get("is_async").template as<BoolPDLAttr>().getValue();
+
   ncclDataType_t ncclDtype = ncclDataTypeMapper<T>::value;
   auto ncclReductionType = getNcclReductionType(reductionKind);
@@ -74,7 +77,7 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   auto gpu_driver = ctx->getDriver<tao::ral::gpu::GPUDriver>(
      tao::ral::gpu::GPUDriver::name());
   auto gpu_stream =
-      static_cast<cudaStream_t>(gpu_driver->asCUStream(ctx, stream_handle));
+      static_cast<BaseCudaExecutionContext*>(ctx)->getCommStream();
   auto nccl_comm =
       static_cast<BaseCudaExecutionContext*>(ctx)->getNcclComm();
   auto ptr = static_cast<T*>(gpu_driver->alloc(ctx, element_count * sizeof(T)));
@@ -87,9 +90,69 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   if (ncclResult != ncclSuccess) {
     ctx->signalError(Context::FAILURE, "fail to call ncclAllReduce\n");
   }
+
+  if (isAsync && gpu_stream) {
+    int64_t token_key =
+        dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
+    cudaEvent_t event;
+
+    auto event_status = cudaEventCreate(&event);
+    if (event_status != cudaSuccess) {
+      ctx->signalError(Context::FAILURE, "cudaEventCreate failed\n");
+    }
+
+    auto record_status = cudaEventRecord(event, gpu_stream);
+    if (record_status != cudaSuccess) {
+      cudaEventDestroy(event);
+      ctx->signalError(Context::FAILURE, "cudaEventRecord failed\n");
+    }
+
+    static_cast<BaseCudaExecutionContext*>(ctx)->addAsyncPairToken(
+        token_key, event);
+  }
+
   return output;
 }
 
+template <typename T, int N>
+MemRefType<T, N> ral_async_collective_done(ExecutionContext* ctx,
+                                           void* stream_handle,
+                                           MemRefType<T, N> input,
+                                           void* customAttrs) {
+  auto attr =
+      getOrParsePDLAttr(ctx, customAttrs, "simple_test_fused_add_mul_kernel");
+  if (!attr) {
+    ctx->signalError(
+        Context::FAILURE,
+        "fail to parse custom_attrs in ral_async_collective_done\n");
+  }
+
+  auto& dictAttr = attr->as<DictPDLAttr>();
+  int64_t token_key =
+      dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
+  auto event =
+      static_cast<BaseCudaExecutionContext*>(ctx)->getAsyncPairToken(
+          token_key);
+  if (event) {
+    auto sync_status = cudaEventSynchronize(event);
+    if (sync_status != cudaSuccess) {
+      ctx->signalError(Context::FAILURE, "cudaEventSynchronize failed\n");
+    }
+    static_cast<BaseCudaExecutionContext*>(ctx)->removeAsyncPairToken(
+        token_key);
+    cudaEventDestroy(event);
+  }
+
+  // Increase ref count for input to prevent double free
+  auto it =
+      static_cast<BaseCudaExecutionContext*>(ctx)->device_ptr_map.find(
+          input.data);
+  ++it->second;
+
+  return input;
+}
+
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
@@ -98,5 +161,23 @@ TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
+
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+
 }  // namespace ral
 }  // namespace tao
diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
index 0af72f978f2..7c84d83639c 100644
--- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
+++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
@@ -119,10 +119,13 @@ struct BaseCudaContextState : public tao::ral::Context::Resource {
 
   ncclComm_t nccl_comm = nullptr;
   GpuStreamHandle stream = nullptr;
+  GpuStreamHandle comm_stream = nullptr;
   // map blob ptr -> loaded module
   std::map<void*, GpuModuleHandle> blobs;
   // map <blob ptr, kernel name> -> callable kernel
   std::map<std::pair<void*, std::string>, GpuFunctionHandle> kernels;
+  // map int64 -> cudaEvent_t
+  std::map<int64_t, cudaEvent_t> async_pair_tokens;
 
   std::shared_ptr<Allocator> gpu_allocator;
   bool cache_workspace_mem_across_execution;
@@ -146,6 +149,7 @@ struct BaseCudaContextState : public tao::ral::Context::Resource {
                      "StreamSync");
 #else
     reportErrorIfAny(cuStreamSynchronize(stream), ctx, "StreamSync");
+    reportErrorIfAny(cuStreamSynchronize(comm_stream), ctx, "StreamSync");
 #endif
     for (const_buffer_t buffer : device_persistent_buffers) {
       gpu_allocator->dealloc(const_cast<buffer_t>(buffer));
@@ -173,6 +177,7 @@ std::unique_ptr<BaseContext> MakeBaseCudaContext(
   auto state = new BaseCudaContextState;
   state->stream = gpu_opt.stream;
   state->nccl_comm = gpu_opt.nccl_comm;
+  state->comm_stream = gpu_opt.comm_stream;
   if (gpu_opt.gpu_allocator != nullptr) {
     state->gpu_allocator = gpu_opt.gpu_allocator;
   } else {
@@ -206,6 +211,34 @@ ncclComm_t BaseCudaExecutionContext::getNcclComm() {
   return state->nccl_comm;
 }
 
+GpuStreamHandle BaseCudaExecutionContext::getCommStream() {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  return state->comm_stream;
+}
+
+cudaEvent_t BaseCudaExecutionContext::getAsyncPairToken(int64_t key) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  if (state->async_pair_tokens.find(key) != state->async_pair_tokens.end()) {
+    return state->async_pair_tokens[key];
+  }
+  return nullptr;
+}
+
+void BaseCudaExecutionContext::addAsyncPairToken(int64_t key,
+                                                 cudaEvent_t token) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  state->async_pair_tokens[key] = token;
+  return;
+}
+
+void BaseCudaExecutionContext::removeAsyncPairToken(int64_t key) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  if (state->async_pair_tokens.find(key) != state->async_pair_tokens.end()) {
+    state->async_pair_tokens.erase(key);
+  }
+  return;
+}
+
 void BaseCudaExecutionContext::setOutputDeleter(OutputBufferWrapper& output) {
   {
     if (synced) {
diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
index 314f5f7285a..c15cf1ba56e 100644
--- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
+++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
@@ -48,6 +48,7 @@ using GpuStreamHandle = CUstream;
 struct BaseCudaContextOption {
   ncclComm_t nccl_comm = nullptr;
   GpuStreamHandle stream = nullptr;
+  GpuStreamHandle comm_stream = nullptr;
   int device_ordinal = 0;
   bool use_stream_executor = true;
   bool cache_workspace_mem_across_execution = false;
@@ -64,6 +65,12 @@ struct BaseCudaExecutionContext
   ~BaseCudaExecutionContext();
 
   ncclComm_t getNcclComm();
+
+  GpuStreamHandle getCommStream();
+
+  cudaEvent_t getAsyncPairToken(int64_t key);
+  void addAsyncPairToken(int64_t key, cudaEvent_t token);
+  void removeAsyncPairToken(int64_t key);
   // We need to sync on the gpu stream before we fetch the first output.
   bool synced = false;
   // all buffer allocated by the gpu_allocator