
support async collective op execution
eedalong committed Mar 18, 2024
1 parent f160eb2 commit ceadedf
Showing 5 changed files with 192 additions and 9 deletions.
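In short: the decomposition pass lowers mhlo.all_reduce to a mhlo_disc.custom_call_v2 tagged with is_async and an async_token_key, emits a paired ral_async_collective_done custom call right before the first consumer, and the runtime enqueues the NCCL call on a dedicated communication stream, recording a CUDA event under the token that the done op later waits on. A minimal standalone sketch of that event-pairing scheme (illustrative names, not the RAL API):

#include <cuda_runtime.h>

#include <cstdint>
#include <map>

// Stand-in for the per-context token map this commit adds
// (BaseCudaContextState::async_pair_tokens).
static std::map<int64_t, cudaEvent_t> g_tokens;

// Start side: enqueue the collective on the comm stream, then record an
// event so completion can be observed without blocking the compute stream.
void start_collective(int64_t token, cudaStream_t comm_stream) {
  // ... ncclAllReduce(..., comm_stream) would be enqueued here ...
  cudaEvent_t ev;
  cudaEventCreate(&ev);
  cudaEventRecord(ev, comm_stream);
  g_tokens[token] = ev;
}

// Done side: look the event up by token, block until the collective has
// finished, then drop the pairing.
void collective_done(int64_t token) {
  auto it = g_tokens.find(token);
  if (it == g_tokens.end()) return;  // nothing pending under this token
  cudaEventSynchronize(it->second);
  cudaEventDestroy(it->second);
  g_tokens.erase(it);
}

The runtime below blocks the host with cudaEventSynchronize in the done op; a cudaStreamWaitEvent on the compute stream would be the non-blocking alternative.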
63 changes: 55 additions & 8 deletions tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -134,6 +134,21 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
}
} // namespace
namespace {

bool IsAsyncCollective(Operation* op) {
if (llvm::isa<mhlo::AllReduceOp>(op)) {
if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_REDUCE")) {
return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
}
} else if (llvm::isa<mhlo::AllGatherOp>(op)) {
if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_GATHER")) {
return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
}
}

return false;
}
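// Note: the async path is opt-in per collective kind via the environment,
// e.g. ENABLE_ASYNC_ALL_REDUCE=true (or True); any other value, or an unset
// variable, keeps the synchronous lowering.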

enum ReductionKind {
ALL_REDUCE_SUM,
ALL_REDUCE_PRODUCT,
@@ -192,6 +207,9 @@ struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
if (!reductionKind) {
return failure();
}

bool is_async = IsAsyncCollective(op.getOperation());

for (int i = 0; i < op->getOperands().size(); ++i) {
// No need to emit an all_reduce if the result has no consumers.
if (op->getResult(i).getUsers().empty()) {
@@ -206,19 +224,48 @@ struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
op->setAttr("output_layouts", rewriter.getStringAttr("*"));
op->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
op->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));
      SmallVector<NamedAttribute> attrs;
      attrs.push_back(
          NamedAttribute(rewriter.getStringAttr("reduction_kind"),
                         rewriter.getStringAttr(reductionKind.value())));
      attrs.push_back(NamedAttribute(rewriter.getStringAttr("is_async"),
                                     rewriter.getBoolAttr(is_async)));
      auto customAttrs = DictionaryAttr::get(op->getContext(), attrs);
      op->setAttr("custom_attrs", customAttrs);

      auto reduce_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
          op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i],
          op->getAttrs());

      if (is_async) {
        // Key the async pair by the producing op's address and re-attach
        // the enlarged attribute dictionary so the runtime can match the
        // start op with its completion op.
        int64_t async_pair_token =
            reinterpret_cast<int64_t>(reduce_op.getOperation());
        attrs.push_back(
            NamedAttribute(rewriter.getStringAttr("async_token_key"),
                           rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                                   async_pair_token)));
        auto newCustomAttrs =
            DictionaryAttr::get(reduce_op->getContext(), attrs);
        reduce_op->setAttr("custom_attrs", newCustomAttrs);

        // Insert the CollectiveDoneOp and place it right before the first
        // consumer so the wait happens as late as possible.
        auto collective_done_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
            reduce_op->getLoc(), reduce_op->getResults()[0].getType(),
            reduce_op->getResults()[0], reduce_op->getAttrs());
        collective_done_op->setAttr(
            "call_target_name",
            rewriter.getStringAttr("ral_async_collective_done"));
        collective_done_op->moveBefore(*(op->getResult(i).user_begin()));

        newOutputs.push_back(collective_done_op.getResult(0));
      } else {
        newOutputs.push_back(reduce_op.getResult(0));
      }
}
rewriter.replaceOp(op, newOutputs);
return success();
15 changes: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
// RUN: ENABLE_ASYNC_ALL_REDUCE=true disc-opt -disc-mhlo-decomp-rewriter -split-input-file %s -o - | FileCheck %s

func.func @main(%arg0: tensor<f32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<f32>) {
// CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg1) {call_target_name = "ral_all_reduce", custom_attrs = {async_token_key = 94869518500912 : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: %2 = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_all_reduce", custom_attrs = {async_token_key = 94869518501936 : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
// CHECK: %1 = "mhlo_disc.custom_call_v2"(%0) {call_target_name = "ral_async_collective_done", custom_attrs = {async_token_key = 94869518500912 : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: %3 = "mhlo_disc.custom_call_v2"(%2) {call_target_name = "ral_async_collective_done", custom_attrs = {async_token_key = 94869518501936 : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
%0:2 = "mhlo.all_reduce"(%arg1, %arg0) ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
mhlo.return %1 : tensor<f32>
}) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>, tensor<f32>) -> (tensor<4xf32>, tensor<f32>)
// CHECK: return %1, %3 : tensor<4xf32>, tensor<f32>
return %0#0, %0#1 : tensor<4xf32>, tensor<f32>
}
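The RUN line above enables the async path through the environment. A host program could do the equivalent before running the pipeline; a sketch (setenv is POSIX, the helper name is illustrative):

#include <cstdlib>  // setenv (POSIX)

// Enable async lowering for both collective kinds checked by
// IsAsyncCollective; equivalent to exporting the variables in the shell.
void enable_async_collectives() {
  setenv("ENABLE_ASYNC_ALL_REDUCE", "true", /*overwrite=*/1);
  setenv("ENABLE_ASYNC_ALL_GATHER", "true", /*overwrite=*/1);
}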
83 changes: 82 additions & 1 deletion tao_compiler/mlir/ral/collective.cu.cc
@@ -63,6 +63,9 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
auto& dictAttr = attr->as<DictPDLAttr>();
std::string reductionKind =
dictAttr.get("reduction_kind").template as<StrPDLAttr>().getValue();

bool isAsync = dictAttr.get("is_async").template as<BoolPDLAttr>().getValue();

ncclDataType_t ncclDtype = ncclDataTypeMapper<T>::value;
auto ncclReductionType = getNcclReductionType(reductionKind);

@@ -74,7 +77,7 @@
auto gpu_driver = ctx->getDriver<tao::ral::gpu::GPUDriver>(
tao::ral::gpu::GPUDriver::name());
auto gpu_stream =
      static_cast<gpu::BaseCudaExecutionContext*>(ctx)->getCommStream();
auto nccl_comm =
static_cast<gpu::BaseCudaExecutionContext*>(ctx)->getNcclComm();
auto ptr = static_cast<T*>(gpu_driver->alloc(ctx, element_count * sizeof(T)));
@@ -87,9 +90,69 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
if (ncclResult != ncclSuccess) {
ctx->signalError(Context::FAILURE, "fail to call ncclAllReduce\n");
}

if (isAsync && gpu_stream) {
int64_t token_key =
dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
cudaEvent_t event;

auto event_status = cudaEventCreate(&event);
if (event_status != cudaSuccess) {
ctx->signalError(Context::FAILURE, "cudaEventCreate failed\n");
}

auto record_status = cudaEventRecord(event, gpu_stream);
if (record_status != cudaSuccess) {
cudaEventDestroy(event);
ctx->signalError(Context::FAILURE, "cudaEventRecord failed\n");
}

static_cast<gpu::BaseCudaExecutionContext*>(ctx)->addAsyncPairToken(
token_key, event);
}

return output;
}

template <typename T, int N>
MemRefType<T, N> ral_async_collective_done(ExecutionContext* ctx,
void* stream_handle,
MemRefType<T, N> input,
void* customAttrs) {
  auto attr =
      getOrParsePDLAttr(ctx, customAttrs, "ral_async_collective_done");
if (!attr) {
ctx->signalError(
Context::FAILURE,
"fail to parse custom_attrs in ral_async_collective_done\n");
}

auto& dictAttr = attr->as<DictPDLAttr>();
int64_t token_key =
dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
auto event =
static_cast<gpu::BaseCudaExecutionContext*>(ctx)->getAsyncPairToken(
token_key);
if (event) {
auto sync_status = cudaEventSynchronize(event);
if (sync_status != cudaSuccess) {
ctx->signalError(Context::FAILURE, "cudaEventSynchronize failed\n");
}
static_cast<gpu::BaseCudaExecutionContext*>(ctx)->removeAsyncPairToken(
token_key);
cudaEventDestroy(event);
}

  // The returned memref aliases the input buffer, so bump its ref count
  // to prevent a double free when both outputs are released.
  auto it =
      static_cast<gpu::BaseCudaExecutionContext*>(ctx)->device_ptr_map.find(
          input.data);
  ++it->second;

return input;
}

TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 1>);
TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 2>);
TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 3>);
@@ -98,5 +161,23 @@ TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float16, 1>);
TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float16, 2>);
TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float16, 3>);
TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float16, 4>);

TAO_RAL_API("ral_async_collective_done", "gpu",
ral_async_collective_done<float, 1>);
TAO_RAL_API("ral_async_collective_done", "gpu",
ral_async_collective_done<float, 2>);
TAO_RAL_API("ral_async_collective_done", "gpu",
ral_async_collective_done<float, 3>);
TAO_RAL_API("ral_async_collective_done", "gpu",
ral_async_collective_done<float, 4>);
TAO_RAL_API("ral_async_collective_done", "gpu",
ral_async_collective_done<float16, 1>);
TAO_RAL_API("ral_async_collective_done", "gpu",
ral_async_collective_done<float16, 2>);
TAO_RAL_API("ral_async_collective_done", "gpu",
ral_async_collective_done<float16, 3>);
TAO_RAL_API("ral_async_collective_done", "gpu",
ral_async_collective_done<float16, 4>);

} // namespace ral
} // namespace tao
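One caveat in ral_async_collective_done above: the ref-count bump dereferences the iterator returned by find without checking it, so an untracked input.data would be undefined behavior. A guarded drop-in variant (sketch; it assumes the surrounding function's ctx and input):

// Guarded ref-count bump: fail loudly instead of dereferencing end().
auto* cuda_ctx = static_cast<gpu::BaseCudaExecutionContext*>(ctx);
auto it = cuda_ctx->device_ptr_map.find(input.data);
if (it == cuda_ctx->device_ptr_map.end()) {
  ctx->signalError(Context::FAILURE,
                   "ral_async_collective_done: untracked input buffer\n");
} else {
  ++it->second;  // the returned memref aliases the input buffer
}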
33 changes: 33 additions & 0 deletions tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
@@ -119,10 +119,13 @@ struct BaseCudaContextState : public tao::ral::Context::Resource {

ncclComm_t nccl_comm = nullptr;
GpuStreamHandle stream = nullptr;
GpuStreamHandle comm_stream = nullptr;
// map blob ptr -> loaded module
std::map<void*, GpuModuleHandle> blobs;
// map <blob ptr, kernel name> -> callable kernel
std::map<std::pair<void*, std::string>, GpuFunctionHandle> kernels;
// map int64 -> cudaEvent_t
std::map<int64_t, cudaEvent_t> async_pair_tokens;

std::shared_ptr<Allocator> gpu_allocator;
bool cache_workspace_mem_across_execution;
@@ -146,6 +149,7 @@
"StreamSync");
#else
reportErrorIfAny(cuStreamSynchronize(stream), ctx, "StreamSync");
reportErrorIfAny(cuStreamSynchronize(comm_stream), ctx, "StreamSync");
#endif
for (const_buffer_t buffer : device_persistent_buffers) {
gpu_allocator->dealloc(const_cast<buffer_t>(buffer));
@@ -173,6 +177,7 @@ std::unique_ptr<BaseContext> MakeBaseCudaContext(
auto state = new BaseCudaContextState;
state->stream = gpu_opt.stream;
state->nccl_comm = gpu_opt.nccl_comm;
state->comm_stream = gpu_opt.comm_stream;
if (gpu_opt.gpu_allocator != nullptr) {
state->gpu_allocator = gpu_opt.gpu_allocator;
} else {
@@ -206,6 +211,34 @@ ncclComm_t BaseCudaExecutionContext::getNcclComm() {
return state->nccl_comm;
}

GpuStreamHandle BaseCudaExecutionContext::getCommStream() {
auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
return state->comm_stream;
}

cudaEvent_t BaseCudaExecutionContext::getAsyncPairToken(int64_t key) {
  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
  auto it = state->async_pair_tokens.find(key);
  return it != state->async_pair_tokens.end() ? it->second : nullptr;
}

void BaseCudaExecutionContext::addAsyncPairToken(int64_t key,
                                                 cudaEvent_t token) {
  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
  state->async_pair_tokens[key] = token;
}

void BaseCudaExecutionContext::removeAsyncPairToken(int64_t key) {
  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
  state->async_pair_tokens.erase(key);  // no-op if the key is absent
}
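
// Taken together, these accessors give the start and done ops a rendezvous
// point keyed by an int64. A usage sketch follows (hypothetical driver code;
// the pass derives the key from the producing op's address, and namespace
// qualifiers are abbreviated):
//
// #include <cuda_runtime.h>
//
// void demo(gpu::BaseCudaExecutionContext* exec_ctx,
//           cudaStream_t comm_stream) {
//   const int64_t key = 42;  // stand-in for the op-address token
//   cudaEvent_t ev;
//   cudaEventCreate(&ev);
//   cudaEventRecord(ev, comm_stream);
//   exec_ctx->addAsyncPairToken(key, ev);
//
//   // ... later, right before the first consumer of the result ...
//   if (cudaEvent_t done = exec_ctx->getAsyncPairToken(key)) {
//     cudaEventSynchronize(done);
//     exec_ctx->removeAsyncPairToken(key);
//     cudaEventDestroy(done);
//   }
// }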

void BaseCudaExecutionContext::setOutputDeleter(OutputBufferWrapper& output) {
{
if (synced) {
7 changes: 7 additions & 0 deletions tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
@@ -48,6 +48,7 @@ using GpuStreamHandle = CUstream;
struct BaseCudaContextOption {
ncclComm_t nccl_comm = nullptr;
GpuStreamHandle stream = nullptr;
GpuStreamHandle comm_stream = nullptr;
int device_ordinal = 0;
bool use_stream_executor = true;
bool cache_workspace_mem_across_execution = false;
@@ -64,6 +65,12 @@ struct BaseCudaExecutionContext
~BaseCudaExecutionContext();

ncclComm_t getNcclComm();

GpuStreamHandle getCommStream();

cudaEvent_t getAsyncPairToken(int64_t key);
void addAsyncPairToken(int64_t key, cudaEvent_t token);
void removeAsyncPairToken(int64_t key);
// We need to sync on the gpu stream before we fetch the first output.
bool synced = false;
// all buffer allocated by the gpu_allocator
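// Callers opt in by handing the context a dedicated communication stream
// via BaseCudaContextOption. A wiring sketch (illustrative; stream creation
// and the NCCL communicator are assumptions, and MakeBaseCudaContext's full
// signature is not shown in this diff):
//
// #include <cuda.h>   // cuStreamCreate (driver API)
// #include <nccl.h>
//
// BaseCudaContextOption MakeGpuOption(ncclComm_t comm) {
//   BaseCudaContextOption gpu_opt;
//   cuStreamCreate(&gpu_opt.stream, CU_STREAM_DEFAULT);       // compute
//   cuStreamCreate(&gpu_opt.comm_stream, CU_STREAM_DEFAULT);  // collectives
//   gpu_opt.nccl_comm = comm;  // e.g. created via ncclCommInitRank elsewhere
//   return gpu_opt;
// }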
