diff --git a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
index a3de94b8677..c73a0a34b63 100755
--- a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
@@ -97,6 +97,18 @@ Value GetDefaultStreamHandle(Operation* op, PatternRewriter& rewriter) {
   return stream_idx;
 }
 
+// Returns the handle of the dedicated stream (index 1) on which async
+// collective ops are launched; GetDefaultStreamHandle above returns stream 0.
+Value GetAsyncCollectiveOpStreamHandle(Operation* op,
+                                       PatternRewriter& rewriter) {
+  Location loc = op->getLoc();
+  MLIRContext* ctx = rewriter.getContext();
+  Type llvm_int32_type = IntegerType::get(ctx, 32);
+  Value one = rewriter.create<LLVM::ConstantOp>(loc, llvm_int32_type,
+                                                rewriter.getI32IntegerAttr(1));
+  Type pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  Value stream_idx = rewriter.create<LLVM::IntToPtrOp>(loc, pointer_type, one);
+  return stream_idx;
+}
+
 // Insert a sync on stream call.
 void InsertSyncOnStream(Operation* op, Value ctx, Value stream_handle,
                         PatternRewriter& rewriter) {
@@ -847,6 +859,15 @@ LogicalResult emitAttr(Attribute attr, StrT& out) {
   return failure();
 }
 
+// Returns true if the custom call is a collective op that may be executed
+// asynchronously on its own stream.
+bool IsAsyncCollectiveOp(CustomCallV2Op op) {
+  auto target = op.getCallTargetName();
+  return target == "ral_all_reduce" || target == "ral_all_gather";
+}
+
 struct CustomCallV2OpConvertor : public OpRewritePattern<CustomCallV2Op> {
   CustomCallV2OpConvertor(MLIRContext* context, bool gpuEnabled)
       : OpRewritePattern<CustomCallV2Op>::OpRewritePattern(context) {
@@ -867,7 +888,13 @@ struct CustomCallV2OpConvertor : public OpRewritePattern<CustomCallV2Op> {
                << "fail to lower the custom_attrs of the custom call op.\n";
     }
 
-    Value streamHandle = GetDefaultStreamHandle(op, rewriter);
+    // Async collective ops are launched on a dedicated stream so that they
+    // can overlap with compute on the default stream.
+    Value streamHandle;
+    if (IsAsyncCollectiveOp(op)) {
+      streamHandle = GetAsyncCollectiveOpStreamHandle(op, rewriter);
+    } else {
+      streamHandle = GetDefaultStreamHandle(op, rewriter);
+    }
+
     SmallVector<Value> newOperands{streamHandle};
     for (Value operand : op->getOperands()) newOperands.push_back(operand);
diff --git a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
old mode 100644
new mode 100755
index e5f6b46ab21..9c78350b9ed
--- a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
+++ b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -134,6 +134,22 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
   }
 }  // namespace
 namespace {
+
+// Whether the given collective op should be lowered to its asynchronous
+// flavor; controlled by per-op environment variables.
+bool IsAsyncCollective(Operation* op) {
+  if (llvm::isa<mhlo::AllReduceOp>(op)) {
+    if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_REDUCE")) {
+      return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
+    }
+  } else if (llvm::isa<mhlo::AllGatherOp>(op)) {
+    if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_GATHER")) {
+      return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
+    }
+  }
+  return false;
+}
+
 enum ReductionKind {
   ALL_REDUCE_SUM,
   ALL_REDUCE_PRODUCT,
@@ -192,6 +208,9 @@ struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
     if (!reductionKind) {
       return failure();
     }
+
+    bool is_async = IsAsyncCollective(op.getOperation());
+
     for (int i = 0; i < op->getOperands().size(); ++i) {
       // no need call all_reduce op if no consumer
       if (op->getResult(i).getUsers().empty()) {
@@ -206,19 +225,43 @@
       op->setAttr("output_layouts", rewriter.getStringAttr("*"));
       op->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
       op->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));
-      SmallVector<NamedAttribute> newAttrs;
-      newAttrs.push_back(
-          NamedAttribute(rewriter.getStringAttr("reduction_kind"),
-                         rewriter.getStringAttr(reductionKind.value())));
-
-      auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs);
-      op->setAttr("custom_attrs", newCustomAttrs);
+      SmallVector<NamedAttribute> attrs;
+      attrs.push_back(
+          NamedAttribute(rewriter.getStringAttr("reduction_kind"),
+                         rewriter.getStringAttr(reductionKind.value())));
+      attrs.push_back(NamedAttribute(rewriter.getStringAttr("is_async"),
+                                     rewriter.getBoolAttr(is_async)));
+      auto customAttrs = DictionaryAttr::get(op->getContext(), attrs);
+      op->setAttr("custom_attrs", customAttrs);
 
-      auto newOutput = rewriter.create<mhlo_disc::CustomCallV2Op>(
+      auto reduce_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
           op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i],
           op->getAttrs());
-      newOutputs.push_back(newOutput.getResult(0));
+
+      if (is_async) {
+        // Use the op pointer as a process-unique token key so that the async
+        // collective can be paired with its completion op at runtime.
+        int64_t async_pair_token =
+            reinterpret_cast<int64_t>(reduce_op.getOperation());
+        attrs.push_back(NamedAttribute(
+            rewriter.getStringAttr("async_token_key"),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    async_pair_token)));
+        auto newCustomAttrs =
+            DictionaryAttr::get(reduce_op->getContext(), attrs);
+        reduce_op->setAttr("custom_attrs", newCustomAttrs);
+
+        // Insert the paired op that waits for the async collective to finish.
+        auto collective_done_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
+            reduce_op->getLoc(), reduce_op->getResults()[0].getType(),
+            reduce_op->getResults()[0], reduce_op->getAttrs());
+        collective_done_op->setAttr(
+            "call_target_name",
+            rewriter.getStringAttr("ral_async_collective_done"));
+        newOutputs.push_back(collective_done_op.getResult(0));
+      } else {
+        newOutputs.push_back(reduce_op.getResult(0));
+      }
     }
     rewriter.replaceOp(op, newOutputs);
     return success();
diff --git a/tao_compiler/mlir/ral/collective.cu.cc b/tao_compiler/mlir/ral/collective.cu.cc
old mode 100644
new mode 100755
index 35f92258aa0..d3e8870c5c6
--- a/tao_compiler/mlir/ral/collective.cu.cc
+++ b/tao_compiler/mlir/ral/collective.cu.cc
@@ -63,6 +63,9 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   auto& dictAttr = attr->as<DictPDLAttr>();
   std::string reductionKind =
       dictAttr.get("reduction_kind").template as<StrPDLAttr>().getValue();
+
+  bool isAsync = dictAttr.get("is_async").template as<BoolPDLAttr>().getValue();
+
   ncclDataType_t ncclDtype = ncclDataTypeMapper<T>::value;
   auto ncclReductionType = getNcclReductionType(reductionKind);
 
@@ -87,9 +90,54 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   if (ncclResult != ncclSuccess) {
     ctx->signalError(Context::FAILURE, "fail to call ncclAllReduce\n");
   }
+
+  if (isAsync && gpu_stream) {
+    int64_t token_key =
+        dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
+    cudaEvent_t event;
+
+    auto event_status = cudaEventCreate(&event);
+    if (event_status != cudaSuccess) {
+      ctx->signalError(Context::FAILURE,
+                       std::string("cudaEventCreate failed: ") +
+                           cudaGetErrorString(event_status) + "\n");
+    }
+
+    // Record an event on the collective stream; the paired
+    // ral_async_collective_done call synchronizes on it later.
+    auto record_status = cudaEventRecord(event, gpu_stream);
+    if (record_status != cudaSuccess) {
+      cudaEventDestroy(event);
+      ctx->signalError(Context::FAILURE,
+                       std::string("cudaEventRecord failed: ") +
+                           cudaGetErrorString(record_status) + "\n");
+    }
+
+    static_cast<gpu::BaseCudaExecutionContext*>(ctx)->addAsyncPairToken(
+        token_key, event);
+  }
+
   return output;
 }
 
+template <typename T, int N>
+MemRefType<T, N> ral_async_collective_done(ExecutionContext* ctx,
+                                           void* stream_handle,
+                                           MemRefType<T, N> input,
+                                           void* customAttrs) {
+  auto attr =
+      getOrParsePDLAttr(ctx, customAttrs, "simple_test_fused_add_mul_kernel");
+  if (!attr) {
+    ctx->signalError(Context::FAILURE,
+                     "fail to parse custom_attrs in ral_async_collective_done\n");
+  }
+  auto& dictAttr = attr->as<DictPDLAttr>();
+  int64_t token_key =
+      dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
+  auto event =
+      static_cast<gpu::BaseCudaExecutionContext*>(ctx)->getAsyncPairToken(
+          token_key);
+  if (event) {
+    // Block until the async collective recorded on the other stream is done.
+    auto sync_status = cudaEventSynchronize(event);
+    if (sync_status != cudaSuccess) {
+      ctx->signalError(Context::FAILURE,
+                       std::string("cudaEventSynchronize failed: ") +
+                           cudaGetErrorString(sync_status) + "\n");
+    }
+    static_cast<gpu::BaseCudaExecutionContext*>(ctx)->removeAsyncPairToken(
+        token_key);
+    cudaEventDestroy(event);
+  }
+
+  return input;
+}
+
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 1>);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 2>);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 3>);
@@ -98,5 +146,26 @@ TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 5>);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 6>);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 7>);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 8>);
+
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done<float, 1>);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done<float, 2>);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done<float, 3>);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done<float, 4>);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done<float, 5>);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done<float, 6>);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done<float, 7>);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done<float, 8>);
+
 }  // namespace ral
 }  // namespace tao
diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
old mode 100644
new mode 100755
index 0af72f978f2..39bd21125ca
--- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
+++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
@@ -123,6 +123,8 @@ struct BaseCudaContextState : public tao::ral::Context::Resource {
   std::map<void*, GpuModuleHandle> blobs;
   // map <blob ptr, kernel name> -> callable kernel
   std::map<std::pair<void*, std::string>, GpuFunctionHandle> kernels;
+  // map int64 token key -> cudaEvent_t recorded for that key
+  std::map<int64_t, cudaEvent_t> async_pair_tokens;
 
   std::shared_ptr<Allocator> gpu_allocator;
   bool cache_workspace_mem_across_execution;
@@ -206,6 +208,28 @@ ncclComm_t BaseCudaExecutionContext::getNcclComm() {
   return state->nccl_comm;
 }
 
+cudaEvent_t BaseCudaExecutionContext::getAsyncPairToken(int64_t key) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  auto it = state->async_pair_tokens.find(key);
+  if (it != state->async_pair_tokens.end()) {
+    return it->second;
+  }
+  return nullptr;
+}
+
+void BaseCudaExecutionContext::addAsyncPairToken(int64_t key,
+                                                 cudaEvent_t token) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  state->async_pair_tokens[key] = token;
+}
+
+void BaseCudaExecutionContext::removeAsyncPairToken(int64_t key) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  state->async_pair_tokens.erase(key);
+}
+
 void BaseCudaExecutionContext::setOutputDeleter(OutputBufferWrapper& output) {
   {
     if (synced) {
diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
old mode 100644
new mode 100755
index 314f5f7285a..2e625c977df
--- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
+++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
@@ -64,6 +64,10 @@ struct BaseCudaExecutionContext
   ~BaseCudaExecutionContext();
 
   ncclComm_t getNcclComm();
+
+  cudaEvent_t getAsyncPairToken(int64_t key);
+  void addAsyncPairToken(int64_t key, cudaEvent_t token);
+  void removeAsyncPairToken(int64_t key);
   // We need to sync on the gpu stream before we fetch the first output.
   bool synced = false;
   // all buffer allocated by the gpu_allocator