support async collective op execution
eedalong committed Mar 14, 2024
1 parent f160eb2 commit 2c50126
Showing 2 changed files with 121 additions and 9 deletions.
61 changes: 52 additions & 9 deletions tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
100644 → 100755
@@ -134,6 +134,22 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
 }
 } // namespace
 namespace {
+
+bool IsAsyncCollective(Operation* op) {
+  if (llvm::isa<mhlo::AllReduceOp>(op)) {
+    if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_REDUCE")) {
+      return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
+    }
+  } else if (llvm::isa<mhlo::AllGatherOp>(op)) {
+    if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_GATHER")) {
+      return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
+    }
+  }
+  return false;
+}
+
 enum ReductionKind {
   ALL_REDUCE_SUM,
   ALL_REDUCE_PRODUCT,
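Note: the IsAsyncCollective gate above is driven purely by environment variables, and only the exact strings "true" and "True" enable it. A minimal standalone sketch of the same check (EnvSetToTrue is an illustrative helper, not part of the patch):

#include <cstdlib>
#include <cstring>
#include <iostream>

static bool EnvSetToTrue(const char* name) {
  const char* v = std::getenv(name);
  return v != nullptr &&
         (std::strcmp(v, "true") == 0 || std::strcmp(v, "True") == 0);
}

int main() {
  // e.g. launched as: ENABLE_ASYNC_ALL_REDUCE=true ./app
  std::cout << "async all-reduce: " << EnvSetToTrue("ENABLE_ASYNC_ALL_REDUCE")
            << "\nasync all-gather: " << EnvSetToTrue("ENABLE_ASYNC_ALL_GATHER")
            << "\n";
}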
@@ -192,6 +208,9 @@ struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
     if (!reductionKind) {
       return failure();
     }
+
+    bool is_async = IsAsyncCollective(op.getOperation());
+
     for (int i = 0; i < op->getOperands().size(); ++i) {
       // No need to emit an all_reduce for results that have no consumers.
       if (op->getResult(i).getUsers().empty()) {
@@ -206,19 +225,43 @@ struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
       op->setAttr("output_layouts", rewriter.getStringAttr("*"));
       op->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
       op->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));
-      SmallVector<NamedAttribute> newAttrs;
-      newAttrs.push_back(
-          NamedAttribute(rewriter.getStringAttr("reduction_kind"),
-                         rewriter.getStringAttr(reductionKind.value())));
-
-      auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs);
-
-      op->setAttr("custom_attrs", newCustomAttrs);
-      auto newOutput = rewriter.create<mhlo_disc::CustomCallV2Op>(
+      SmallVector<NamedAttribute> attrs;
+      attrs.push_back(
+          NamedAttribute(rewriter.getStringAttr("reduction_kind"),
+                         rewriter.getStringAttr(reductionKind.value())));
+      attrs.push_back(NamedAttribute(rewriter.getStringAttr("is_async"),
+                                     rewriter.getBoolAttr(is_async)));
+      op->setAttr("custom_attrs", DictionaryAttr::get(op->getContext(), attrs));
+
+      auto reduce_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
           op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i],
           op->getAttrs());
-      newOutputs.push_back(newOutput.getResult(0));
+
+      if (is_async) {
+        // Use the producer op's address as a process-unique token that ties
+        // the async collective to its completion op.
+        int64_t async_pair_token =
+            reinterpret_cast<int64_t>(reduce_op.getOperation());
+        attrs.push_back(NamedAttribute(
+            rewriter.getStringAttr("async_token_key"),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    async_pair_token)));
+        reduce_op->setAttr(
+            "custom_attrs", DictionaryAttr::get(reduce_op->getContext(), attrs));
+
+        // Insert the completion op; it inherits the collective's attributes
+        // (including async_token_key) and waits on the paired event.
+        auto collective_done_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
+            reduce_op->getLoc(), reduce_op->getResults()[0].getType(),
+            reduce_op->getResults()[0], reduce_op->getAttrs());
+        collective_done_op->setAttr(
+            "call_target_name",
+            rewriter.getStringAttr("ral_async_collective_done"));
+        newOutputs.push_back(collective_done_op.getResult(0));
+      } else {
+        newOutputs.push_back(reduce_op.getResult(0));
+      }
     }
     rewriter.replaceOp(op, newOutputs);
     return success();
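Schematically (abbreviated attribute lists, not exact MLIR syntax), the pattern above rewrites an async mhlo.all_reduce into a pair of custom calls that share an async_token_key:

  %0 = "mhlo_disc.custom_call_v2"(%arg) {call_target_name = "ral_all_reduce",
         custom_attrs = {reduction_kind = "sum", is_async = true, async_token_key = <key>}}
  %1 = "mhlo_disc.custom_call_v2"(%0) {call_target_name = "ral_async_collective_done",
         custom_attrs = {..., async_token_key = <key>}}

In the synchronous case only the first call is emitted and its result is used directly. The token is the producer op's pointer value, which is unique per rewritten op within a single compilation.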
69 changes: 69 additions & 0 deletions tao_compiler/mlir/ral/collective.cu.cc
100644 → 100755
@@ -63,6 +63,9 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   auto& dictAttr = attr->as<DictPDLAttr>();
   std::string reductionKind =
       dictAttr.get("reduction_kind").template as<StrPDLAttr>().getValue();
+
+  bool isAsync = dictAttr.get("is_async").template as<BoolPDLAttr>().getValue();
+
   ncclDataType_t ncclDtype = ncclDataTypeMapper<T>::value;
   auto ncclReductionType = getNcclReductionType(reductionKind);
@@ -87,9 +90,54 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   if (ncclResult != ncclSuccess) {
     ctx->signalError(Context::FAILURE, "fail to call ncclAllReduce\n");
   }
+
+  if (isAsync && gpu_stream) {
+    int64_t token_key =
+        dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
+
+    cudaEvent_t event;
+    auto event_status = cudaEventCreate(&event);
+    if (event_status != cudaSuccess) {
+      ctx->signalError(Context::FAILURE,
+                       std::string("cudaEventCreate failed: ") +
+                           cudaGetErrorString(event_status) + "\n");
+    }
+
+    // The event completes once all work queued on gpu_stream so far
+    // (including the ncclAllReduce above) has finished.
+    auto record_status = cudaEventRecord(event, gpu_stream);
+    if (record_status != cudaSuccess) {
+      cudaEventDestroy(event);
+      ctx->signalError(Context::FAILURE,
+                       std::string("cudaEventRecord failed: ") +
+                           cudaGetErrorString(record_status) + "\n");
+    }
+
+    // Publish the event under the token so the paired
+    // ral_async_collective_done call can find and wait on it.
+    static_cast<gpu::BaseCudaExecutionContext*>(ctx)->addAsyncPairToken(
+        token_key, event);
+  }
+
   return output;
 }
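The async path above follows the standard CUDA event pattern: record an event on the stream right after the collective is enqueued, stash it under a token, and let a later call wait on it. A self-contained sketch of that pattern outside the RAL runtime (g_token_map, producer, and consumer are illustrative stand-ins for the execution context's token table):

#include <cuda_runtime.h>
#include <cstdint>
#include <map>

static std::map<int64_t, cudaEvent_t> g_token_map;

void producer(cudaStream_t stream, int64_t token) {
  cudaEvent_t event;
  cudaEventCreate(&event);
  // ... enqueue async work (e.g. a collective) on `stream` here ...
  cudaEventRecord(event, stream);  // fires once prior work on `stream` is done
  g_token_map[token] = event;      // publish under the pair token
}

void consumer(int64_t token) {
  auto it = g_token_map.find(token);
  if (it == g_token_map.end()) return;  // no pending async work
  cudaEventSynchronize(it->second);     // block host until the event fires
  cudaEventDestroy(it->second);
  g_token_map.erase(it);
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  producer(stream, /*token=*/42);
  consumer(/*token=*/42);
  cudaStreamDestroy(stream);
}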

+template <typename T, int N>
+MemRefType<T, N> ral_async_collective_done(ExecutionContext* ctx,
+                                           void* stream_handle,
+                                           MemRefType<T, N> input,
+                                           void* customAttrs) {
+  auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_async_collective_done");
+  if (!attr) {
+    ctx->signalError(Context::FAILURE,
+                     "fail to parse custom_attrs in ral_async_collective_done\n");
+  }
+  auto& dictAttr = attr->as<DictPDLAttr>();
+  int64_t token_key =
+      dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
+  auto event =
+      static_cast<gpu::BaseCudaExecutionContext*>(ctx)->getAsyncPairToken(
+          token_key);
+  if (event) {
+    // Block the host until the paired async collective has finished,
+    // then drop the token and release the event.
+    auto sync_status = cudaEventSynchronize(event);
+    if (sync_status != cudaSuccess) {
+      ctx->signalError(Context::FAILURE,
+                       std::string("cudaEventSynchronize failed: ") +
+                           cudaGetErrorString(sync_status) + "\n");
+    }
+    static_cast<gpu::BaseCudaExecutionContext*>(ctx)->removeAsyncPairToken(
+        token_key);
+    cudaEventDestroy(event);
+  }
+
+  return input;
+}
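Design note: cudaEventSynchronize blocks the calling host thread until the event fires, after which the token is dropped and the event destroyed so the pair table does not leak. If host-side blocking ever became a concern, a stream-ordered wait (cudaStreamWaitEvent on the consumer's stream) would be the usual alternative, but the patch as written synchronizes on the host.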

TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 1>);
TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 2>);
TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float, 3>);
@@ -98,5 +146,26 @@ TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float16, 1>);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float16, 2>);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float16, 3>);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce<float16, 4>);
+
+TAO_RAL_API("ral_async_collective_done", "gpu", ral_async_collective_done<float, 1>);
+TAO_RAL_API("ral_async_collective_done", "gpu", ral_async_collective_done<float, 2>);
+TAO_RAL_API("ral_async_collective_done", "gpu", ral_async_collective_done<float, 3>);
+TAO_RAL_API("ral_async_collective_done", "gpu", ral_async_collective_done<float, 4>);
+TAO_RAL_API("ral_async_collective_done", "gpu", ral_async_collective_done<float16, 1>);
+TAO_RAL_API("ral_async_collective_done", "gpu", ral_async_collective_done<float16, 2>);
+TAO_RAL_API("ral_async_collective_done", "gpu", ral_async_collective_done<float16, 3>);
+TAO_RAL_API("ral_async_collective_done", "gpu", ral_async_collective_done<float16, 4>);
+
} // namespace ral
} // namespace tao
