Skip to content

Commit

Permalink
oshmem: Add symmetric remote key handling code
Browse files Browse the repository at this point in the history
At very high scale, having each rank storing each other rank's remote keys for
each segment can lead to high memory consumption.

We activate symmetric remote key option to generate remote keys that will be
deduplicated and then used interchangeably.

Signed-off-by: Thomas Vegas <[email protected]>
  • Loading branch information
tvegas1 committed Oct 9, 2023
1 parent 3de90b1 commit 776e8ba
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 31 deletions.
6 changes: 4 additions & 2 deletions config/ompi_check_ucx.m4
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
UCP_PARAM_FIELD_ESTIMATED_NUM_PPN,
UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK,
UCP_OP_ATTR_FLAG_MULTI_SEND,
UCS_MEMORY_TYPE_RDMA],
UCS_MEMORY_TYPE_RDMA,
UCP_MEM_MAP_SYMMETRIC_RKEY],
[], [],
[#include <ucp/api/ucp.h>])
AC_CHECK_DECLS([UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS],
Expand All @@ -124,7 +125,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
[#include <ucp/api/ucp.h>])
AC_CHECK_DECLS([ucp_tag_send_nbx,
ucp_tag_send_sync_nbx,
ucp_tag_recv_nbx],
ucp_tag_recv_nbx,
ucp_rkey_compare],
[], [],
[#include <ucp/api/ucp.h>])
AC_CHECK_TYPES([ucp_request_param_t],
Expand Down
183 changes: 180 additions & 3 deletions oshmem/mca/spml/ucx/spml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "opal/datatype/opal_convertor.h"
#include "opal/mca/common/ucx/common_ucx.h"
#include "opal/util/opal_environ.h"
#include "opal/util/minmax.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/mca/pml/pml.h"

Expand Down Expand Up @@ -126,6 +127,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
};
#endif

unsigned
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx)
{
#if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
if (spml_ucx->symmetric_rkey_max_count > 0) {
return UCP_MEM_MAP_SYMMETRIC_RKEY;
}
#endif

return 0;
}

void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store)
{
store->array = NULL;
store->count = 0;
store->size = 0;
}

void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store)
{
int i;

for (i = 0; i < store->count; i++) {
if (store->array[i].refcnt != 0) {
SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0",
i, store->count, store->array[i].refcnt);
}

ucp_rkey_destroy(store->array[i].rkey);
}

free(store->array);
}

/**
* Find position in sorted array for existing or future entry
*
* @param[in] store Store of the entries
* @param[in] worker Common worker for rkeys used
* @param[in] rkey Remote key to search for
* @param[out] index Index of entry
*
* @return
* OSHMEM_ERR_NOT_FOUND: index contains the position where future element
* should be inserted to keep array sorted
* OSHMEM_SUCCESS : index contains the position of the element
* Other error : index is not valid
*/
static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store,
const ucp_worker_h worker,
const ucp_rkey_h rkey,
int *index)
{
#if HAVE_DECL_UCP_RKEY_COMPARE
ucp_rkey_compare_params_t params;
int i, result, m, end;
ucs_status_t status;

for (i = 0, end = store->count; i < end;) {
m = (i + end) / 2;

params.field_mask = 0;
status = ucp_rkey_compare(worker, store->array[m].rkey,
rkey, &params, &result);
if (status != UCS_OK) {
return OSHMEM_ERROR;
} else if (result == 0) {
*index = m;
return OSHMEM_SUCCESS;
} else if (result > 0) {
end = m;
} else {
i = m + 1;
}
}

*index = i;
return OSHMEM_ERR_NOT_FOUND;
#else
return OSHMEM_ERROR;
#endif
}

static void mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store,
int i, ucp_rkey_h rkey)
{
int size;
mca_spml_ucx_rkey_t *tmp;

if (store->count >= mca_spml_ucx.symmetric_rkey_max_count) {
return;
}

if (store->count >= store->size) {
size = opal_min(opal_max(store->size, 8) * 2,
mca_spml_ucx.symmetric_rkey_max_count);
tmp = realloc(store->array, size * sizeof(*store->array));
if (tmp == NULL) {
return;
}

store->array = tmp;
store->size = size;
}

memmove(&store->array[i + 1], &store->array[i],
(store->count - i) * sizeof(*store->array));
store->array[i].rkey = rkey;
store->array[i].refcnt = 1;
store->count++;
return;
}

/* Takes ownership of input ucp remote key */
static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store,
ucp_worker_h worker,
ucp_rkey_h rkey)
{
int ret, i;

if (mca_spml_ucx.symmetric_rkey_max_count == 0) {
return rkey;
}

ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
if (ret == OSHMEM_SUCCESS) {
ucp_rkey_destroy(rkey);
store->array[i].refcnt++;
return store->array[i].rkey;
}

if (ret == OSHMEM_ERR_NOT_FOUND) {
mca_spml_ucx_rkey_store_insert(store, i, rkey);
}

return rkey;
}

static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store,
ucp_worker_h worker,
ucp_rkey_h rkey)
{
mca_spml_ucx_rkey_t *entry;
int ret, i;

ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
if (ret != OSHMEM_SUCCESS) {
goto out;
}

entry = &store->array[i];
assert(entry->rkey == rkey);
if (--entry->refcnt > 0) {
return;
}

memmove(&store->array[i], &store->array[i + 1],
(store->count - (i + 1)) * sizeof(*store->array));
store->count--;

out:
ucp_rkey_destroy(rkey);
}

int mca_spml_ucx_enable(bool enable)
{
SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****");
Expand Down Expand Up @@ -240,6 +406,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
{
int rc;
ucs_status_t err;
ucp_rkey_h rkey;

rc = mca_spml_ucx_ctx_mkey_new(ucx_ctx, pe, segno, ucx_mkey);
if (OSHMEM_SUCCESS != rc) {
Expand All @@ -248,11 +415,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
}

if (mkey->u.data) {
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &((*ucx_mkey)->rkey));
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &rkey);
if (UCS_OK != err) {
SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err));
return OSHMEM_ERROR;
}

if (!oshmem_proc_on_local_node(pe)) {
rkey = mca_spml_ucx_rkey_store_get(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], rkey);
}

(*ucx_mkey)->rkey = rkey;

rc = mca_spml_ucx_ctx_mkey_cache(ucx_ctx, mkey, segno, pe);
if (OSHMEM_SUCCESS != rc) {
SPML_UCX_ERROR("mca_spml_ucx_ctx_mkey_cache failed");
Expand All @@ -267,7 +441,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
ucp_peer_t *ucp_peer;
int rc;
ucp_peer = &(ucx_ctx->ucp_peers[pe]);
ucp_rkey_destroy(ucx_mkey->rkey);
mca_spml_ucx_rkey_store_put(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], ucx_mkey->rkey);
ucx_mkey->rkey = NULL;
rc = mca_spml_ucx_peer_mkey_cache_del(ucp_peer, segno);
if(OSHMEM_SUCCESS != rc){
Expand Down Expand Up @@ -725,7 +899,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
mem_map_params.address = addr;
mem_map_params.length = size;
mem_map_params.flags = flags;
mem_map_params.flags = flags |
mca_spml_ucx_mem_map_flags_symmetric_rkey(&mca_spml_ucx);

status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &mem_h);
if (UCS_OK != status) {
Expand Down Expand Up @@ -917,6 +1092,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
}
}

mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store);

*ucx_ctx_p = ucx_ctx;

return OSHMEM_SUCCESS;
Expand Down
41 changes: 30 additions & 11 deletions oshmem/mca/spml/ucx/spml_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,31 @@ struct ucp_peer {
size_t mkeys_cnt;
};
typedef struct ucp_peer ucp_peer_t;


/* An rkey_store entry */
typedef struct mca_spml_ucx_rkey {
ucp_rkey_h rkey;
int refcnt;
} mca_spml_ucx_rkey_t;

typedef struct mca_spml_ucx_rkey_store {
mca_spml_ucx_rkey_t *array;
int size;
int count;
} mca_spml_ucx_rkey_store_t;

struct mca_spml_ucx_ctx {
ucp_worker_h *ucp_worker;
ucp_peer_t *ucp_peers;
long options;
opal_bitmap_t put_op_bitmap;
unsigned long nb_progress_cnt;
unsigned int ucp_workers;
int *put_proc_indexes;
unsigned put_proc_count;
bool synchronized_quiet;
int strong_sync;
ucp_worker_h *ucp_worker;
ucp_peer_t *ucp_peers;
long options;
opal_bitmap_t put_op_bitmap;
unsigned long nb_progress_cnt;
unsigned int ucp_workers;
int *put_proc_indexes;
unsigned put_proc_count;
bool synchronized_quiet;
int strong_sync;
mca_spml_ucx_rkey_store_t rkey_store;
};
typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t;

Expand Down Expand Up @@ -128,6 +141,7 @@ struct mca_spml_ucx {
unsigned long nb_ucp_worker_progress;
unsigned int ucp_workers;
unsigned int ucp_worker_cnt;
int symmetric_rkey_max_count;
};
typedef struct mca_spml_ucx mca_spml_ucx_t;

Expand Down Expand Up @@ -280,6 +294,11 @@ extern int mca_spml_ucx_team_fcollect(shmem_team_t team, void
extern int mca_spml_ucx_team_reduce(shmem_team_t team, void
*dest, const void *source, size_t nreduce, int operation, int datatype);

extern unsigned
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx);

extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store);
extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store);

static inline int
mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey)
Expand Down
Loading

0 comments on commit 776e8ba

Please sign in to comment.