refactor MPI backend: move request polling to progress

JiakunYan committed Jan 28, 2024
1 parent 7741fb5 commit 347ce31
Showing 1 changed file with 165 additions and 103 deletions.

268 changes: 165 additions & 103 deletions src/backend/mpi/backend_mpi.cpp
@@ -7,6 +7,8 @@ namespace lcw
{
int g_rank = -1;
int g_nranks = -1;
LCT_queue_type_t cq_type;
int LCW_MPI_DEFAULT_QUEUE_LENGTH = 65536;

void backend_mpi_t::initialize()
{
@@ -17,6 +19,31 @@ void backend_mpi_t::initialize()
"Cannot enable multithreaded MPI!\n");
MPI_SAFECALL(MPI_Comm_rank(MPI_COMM_WORLD, &g_rank));
MPI_SAFECALL(MPI_Comm_size(MPI_COMM_WORLD, &g_nranks));

// Completion queue type
const LCT_queue_type_t cq_type_default = LCT_QUEUE_ARRAY_ATOMIC_FAA;
LCT_dict_str_int_t dict[] = {
{NULL, cq_type_default},
{"array_atomic_faa", LCT_QUEUE_ARRAY_ATOMIC_FAA},
{"array_atomic_cas", LCT_QUEUE_ARRAY_ATOMIC_CAS},
{"array_atomic_basic", LCT_QUEUE_ARRAY_ATOMIC_BASIC},
{"array_mutex", LCT_QUEUE_ARRAY_MUTEX},
{"std_mutex", LCT_QUEUE_STD_MUTEX},
};
bool succeed = LCT_str_int_search(dict, sizeof(dict) / sizeof(dict[0]),
getenv("LCW_MPI_CQ_TYPE"), cq_type_default,
(int*)&cq_type);
if (!succeed) {
LCW_Warn("Unknown LCI_CQ_TYPE %s. Use the default type: array_atomic_faa\n",
getenv("LCI_CQ_TYPE"));
}
LCW_Log(LCW_LOG_INFO, "comp", "Set LCW_MPI_CQ_TYPE to %d\n", cq_type);

// Completion queue length
{
char* p = getenv("LCW_MPI_DEFAULT_QUEUE_LENGTH");
if (p) LCW_MPI_DEFAULT_QUEUE_LENGTH = atoi(p);
}
}

void backend_mpi_t::finalize() { MPI_SAFECALL(MPI_Finalize()); }
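
As a usage sketch (caller-side code assumed for illustration, not part of this commit), both environment variables read in initialize() above can be set programmatically before the backend starts, or exported in the launch environment; the dictionary above lists the accepted LCW_MPI_CQ_TYPE values.

#include <cstdlib>

// Hypothetical helper, not in this commit: choose the completion-queue
// implementation and capacity before backend_mpi_t::initialize() calls
// getenv() on these variables. setenv() is POSIX.
void configure_lcw_mpi_backend()
{
  setenv("LCW_MPI_CQ_TYPE", "array_mutex", /*overwrite=*/1);
  setenv("LCW_MPI_DEFAULT_QUEUE_LENGTH", "131072", /*overwrite=*/1);
}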
@@ -27,29 +54,32 @@ int64_t backend_mpi_t::get_nranks() { return g_nranks; }

namespace mpi
{
struct cq_entry_t {
MPI_Request request;
request_t context;
struct progress_entry_t {
MPI_Request mpi_req;
comp_t completion;
request_t* request;
};

struct cq_t {
std::deque<cq_entry_t> entries;
spinlock_t lock;
struct progress_engine_t {
progress_entry_t put_entry;
spinlock_t put_entry_lock;
std::deque<progress_entry_t> entries;
spinlock_t entries_lock;
};

struct device_t {
MPI_Comm comm_2sided;
MPI_Comm comm_1sided;
std::vector<char> put_rbuf;
tag_t max_tag;
progress_engine_t pengine;
};

void push_cq(comp_t completion, mpi::cq_entry_t entry)
void add_to_progress_engine(mpi::device_t* device, mpi::progress_entry_t entry)
{
auto* cq = reinterpret_cast<mpi::cq_t*>(completion);
cq->lock.lock();
cq->entries.push_back(entry);
cq->lock.unlock();
device->pengine.entries_lock.lock();
device->pengine.entries.push_back(entry);
device->pengine.entries_lock.unlock();
}

const int PUT_SIGNAL_TAG = 0;
@@ -58,9 +88,10 @@ const int PUT_SIGNAL_TAG = 0;
void post_put_recv(device_t device, comp_t completion)
{
auto* device_p = reinterpret_cast<mpi::device_t*>(device);
mpi::cq_entry_t entry = {
.request = MPI_REQUEST_NULL,
.context = {
mpi::progress_entry_t entry = {
.mpi_req = MPI_REQUEST_NULL,
.completion = completion,
.request = new request_t{
.op = op_t::PUT_SIGNAL,
.device = device,
.rank = -1,
@@ -69,10 +100,10 @@ void post_put_recv(device_t device, comp_t completion)
.length = static_cast<int64_t>(device_p->put_rbuf.size()),
.user_context = nullptr,
}};
MPI_SAFECALL(MPI_Irecv(entry.context.buffer, entry.context.length, MPI_CHAR,
MPI_ANY_SOURCE, entry.context.tag,
device_p->comm_1sided, &entry.request));
mpi::push_cq(completion, entry);
MPI_SAFECALL(MPI_Irecv(entry.request->buffer, entry.request->length, MPI_CHAR,
MPI_ANY_SOURCE, entry.request->tag,
device_p->comm_1sided, &entry.mpi_req));
device_p->pengine.put_entry = entry;
}

device_t backend_mpi_t::alloc_device(int64_t max_put_length, comp_t put_comp)
@@ -91,6 +122,8 @@ device_t backend_mpi_t::alloc_device(int64_t max_put_length, comp_t put_comp)
device_p->put_rbuf.resize(max_put_length);
MPI_SAFECALL(MPI_Comm_dup(MPI_COMM_WORLD, &device_p->comm_1sided));
post_put_recv(device, put_comp);
} else {
device_p->pengine.put_entry.mpi_req = MPI_REQUEST_NULL;
}
return device;
}
@@ -102,131 +135,160 @@ void backend_mpi_t::free_device(device_t device)
if (!device_p->put_rbuf.empty()) {
MPI_Comm_free(&device_p->comm_1sided);
}
delete device_p;
}

bool backend_mpi_t::do_progress(device_t device)
{
auto* device_p = reinterpret_cast<mpi::device_t*>(device);
return false;
// work on put
int succeed = 0;
MPI_Status status;
mpi::progress_entry_t entry;
if (device_p->pengine.put_entry.mpi_req != MPI_REQUEST_NULL &&
device_p->pengine.put_entry_lock.try_lock()) {
if (device_p->pengine.put_entry.mpi_req != MPI_REQUEST_NULL) {
MPI_SAFECALL(
MPI_Test(&device_p->pengine.put_entry.mpi_req, &succeed, &status));
if (succeed) {
entry = device_p->pengine.put_entry;
int count;
MPI_SAFECALL(MPI_Get_count(&status, MPI_CHAR, &count));
entry.request->length = count;
entry.request->tag = status.MPI_TAG;
entry.request->rank = status.MPI_SOURCE;
// Copy the data out and repost the receive
LCW_Assert(entry.request->op == op_t::PUT_SIGNAL, "Unexpected op\n");
void* buffer;
int ret =
posix_memalign(&buffer, LCW_CACHE_LINE, entry.request->length);
LCW_Assert(ret == 0, "posix_memalign(%ld) failed!\n",
entry.request->length);
memcpy(buffer, entry.request->buffer, entry.request->length);
entry.request->buffer = buffer;
post_put_recv(entry.request->device, entry.completion);
}
}
device_p->pengine.put_entry_lock.unlock();
}
if (!succeed) {
// work on sendrecv
if (device_p->pengine.entries.empty() ||
!device_p->pengine.entries_lock.try_lock())
return false;
if (device_p->pengine.entries.empty()) {
device_p->pengine.entries_lock.unlock();
return false;
}
entry = device_p->pengine.entries.front();
device_p->pengine.entries.pop_front();
if (entry.mpi_req == MPI_REQUEST_NULL)
succeed = 1;
else {
MPI_SAFECALL(MPI_Test(&entry.mpi_req, &succeed, &status));
}
if (!succeed) {
device_p->pengine.entries.push_back(entry);
}
device_p->pengine.entries_lock.unlock();
if (!succeed) return false;
// We have got something
LCW_Assert(entry.request->op != op_t::PUT_SIGNAL, "Unexpected op\n");
if (entry.request->op == op_t::RECV) {
int count;
MPI_SAFECALL(MPI_Get_count(&status, MPI_CHAR, &count));
entry.request->length = count;
entry.request->tag = status.MPI_TAG;
entry.request->rank = status.MPI_SOURCE;
}
}
auto cq = reinterpret_cast<LCT_queue_t>(entry.completion);
LCT_queue_push(cq, entry.request);
return true;
}
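
A minimal sketch of how a caller might now drive the backend after this refactor (assumed usage, not code from this commit): do_progress() tests the pending MPI requests and pushes each completed request_t onto the LCT completion queue, and poll_cq() pops them.

// Hypothetical caller-side wait loop, assuming this split: poll the
// completion queue and keep the progress engine running until one
// request completes.
void wait_one(backend_mpi_t& backend, device_t device, comp_t cq,
              request_t* request)
{
  while (!backend.poll_cq(cq, request)) {
    backend.do_progress(device);  // may move a finished request onto cq
  }
}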

comp_t backend_mpi_t::alloc_cq()
{
auto* cq = new mpi::cq_t;
LCT_queue_t cq = LCT_queue_alloc(cq_type, LCW_MPI_DEFAULT_QUEUE_LENGTH);
return reinterpret_cast<comp_t>(cq);
}

void backend_mpi_t::free_cq(comp_t completion)
{
auto* cq = reinterpret_cast<mpi::cq_t*>(completion);
while (!cq->entries.empty()) {
auto entry = cq->entries.front();
cq->entries.pop_front();
}
delete cq;
auto cq = reinterpret_cast<LCT_queue_t>(completion);
LCT_queue_free(&cq);
}

bool backend_mpi_t::poll_cq(comp_t completion, request_t* request)
{
auto* cq = reinterpret_cast<mpi::cq_t*>(completion);
if (cq->entries.empty() || !cq->lock.try_lock()) return false;
if (cq->entries.empty()) {
cq->lock.unlock();
return false;
}
auto entry = cq->entries.front();
cq->entries.pop_front();
int succeed = 0;
MPI_Status status;
if (entry.request == MPI_REQUEST_NULL)
succeed = 1;
else {
MPI_SAFECALL(MPI_Test(&entry.request, &succeed, &status));
}
if (succeed) {
*request = entry.context;
} else {
cq->entries.push_back(entry);
}
cq->lock.unlock();
if (!succeed) return false;
if (request->op == op_t::RECV || request->op == op_t::PUT_SIGNAL) {
int count;
MPI_SAFECALL(MPI_Get_count(&status, MPI_CHAR, &count));
request->length = count;
request->tag = status.MPI_TAG;
request->rank = status.MPI_SOURCE;
}
if (request->op == op_t::PUT_SIGNAL) {
// Copy the data out and repost the receive
void* buffer;
int ret = posix_memalign(&buffer, LCW_CACHE_LINE, request->length);
LCW_Assert(ret == 0, "posix_memalign(%ld) failed!\n", request->length);
memcpy(buffer, request->buffer, request->length);
request->buffer = buffer;
post_put_recv(request->device, completion);
}
return succeed;
auto cq = reinterpret_cast<LCT_queue_t>(completion);
auto* req = static_cast<request_t*>(LCT_queue_pop(cq));
if (req == nullptr) return false;
*request = *req;
delete req;  // the request_t was heap-allocated when the operation was posted
return true;
}

bool backend_mpi_t::send(device_t device, rank_t rank, tag_t tag, void* buf,
int64_t length, comp_t completion, void* user_context)
{
auto* device_p = reinterpret_cast<mpi::device_t*>(device);
mpi::cq_entry_t entry = {.request = MPI_REQUEST_NULL,
.context = {
.op = op_t::SEND,
.device = device,
.rank = rank,
.tag = tag,
.buffer = buf,
.length = length,
.user_context = user_context,
}};
mpi::progress_entry_t entry = {.mpi_req = MPI_REQUEST_NULL,
.completion = completion,
.request = new request_t{
.op = op_t::SEND,
.device = device,
.rank = rank,
.tag = tag,
.buffer = buf,
.length = length,
.user_context = user_context,
}};
MPI_SAFECALL(MPI_Isend(buf, length, MPI_CHAR, rank, tag,
device_p->comm_2sided, &entry.request));
push_cq(completion, entry);
device_p->comm_2sided, &entry.mpi_req));
add_to_progress_engine(device_p, entry);
return true;
}

bool backend_mpi_t::recv(device_t device, rank_t rank, tag_t tag, void* buf,
int64_t length, comp_t completion, void* user_context)
{
auto* device_p = reinterpret_cast<mpi::device_t*>(device);
mpi::cq_entry_t entry = {.request = MPI_REQUEST_NULL,
.context = {
.op = op_t::RECV,
.device = device,
.rank = rank,
.tag = tag,
.buffer = buf,
.length = length,
.user_context = user_context,
}};
mpi::progress_entry_t entry = {.mpi_req = MPI_REQUEST_NULL,
.completion = completion,
.request = new request_t{
.op = op_t::RECV,
.device = device,
.rank = rank,
.tag = tag,
.buffer = buf,
.length = length,
.user_context = user_context,
}};
MPI_SAFECALL(MPI_Irecv(buf, length, MPI_CHAR, rank, tag,
device_p->comm_2sided, &entry.request));
push_cq(completion, entry);
device_p->comm_2sided, &entry.mpi_req));
add_to_progress_engine(device_p, entry);
return true;
}

bool backend_mpi_t::put(device_t device, rank_t rank, void* buf, int64_t length,
comp_t completion, void* user_context)
{
auto* device_p = reinterpret_cast<mpi::device_t*>(device);
mpi::cq_entry_t entry = {.request = MPI_REQUEST_NULL,
.context = {
.op = op_t::PUT,
.device = device,
.rank = rank,
.tag = mpi::PUT_SIGNAL_TAG,
.buffer = buf,
.length = length,
.user_context = user_context,
}};
MPI_SAFECALL(MPI_Isend(entry.context.buffer, entry.context.length, MPI_CHAR,
entry.context.rank, entry.context.tag,
device_p->comm_1sided, &entry.request));
push_cq(completion, entry);
mpi::progress_entry_t entry = {.mpi_req = MPI_REQUEST_NULL,
.completion = completion,
.request = new request_t{
.op = op_t::PUT,
.device = device,
.rank = rank,
.tag = mpi::PUT_SIGNAL_TAG,
.buffer = buf,
.length = length,
.user_context = user_context,
}};
MPI_SAFECALL(MPI_Isend(entry.request->buffer, entry.request->length, MPI_CHAR,
entry.request->rank, entry.request->tag,
device_p->comm_1sided, &entry.mpi_req));
add_to_progress_engine(device_p, entry);
return true;
}
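
Taken together, the put path now works like this: the initiator's put() posts an MPI_Isend on comm_1sided, while the target's pre-posted wildcard receive completes inside do_progress(), which copies the payload into a freshly allocated buffer and reposts the receive. An end-to-end sketch under stated assumptions (not code from this commit; it assumes a get_rank() accessor mirroring the get_nranks() shown above, and for brevity the initiator does not drain its own PUT completion):

#include <cstdlib>

// Hypothetical two-rank demo of the put/PUT_SIGNAL protocol. The target
// frees the bounce buffer that do_progress() allocated with posix_memalign().
void demo_put(backend_mpi_t& backend, device_t device, comp_t put_cq,
              void* buf, int64_t len)
{
  if (backend.get_rank() == 0) {
    backend.put(device, /*rank=*/1, buf, len, put_cq, /*user_context=*/nullptr);
  } else {
    request_t req;
    while (!backend.poll_cq(put_cq, &req)) backend.do_progress(device);
    // req.op == op_t::PUT_SIGNAL; ownership of req.buffer passes to the
    // consumer, which releases it here.
    free(req.buffer);
  }
}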
