From 776e8babd6868b968d1724161a6999861723b08a Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Tue, 19 Sep 2023 09:47:42 +0300 Subject: [PATCH] oshmem: Add symmetric remote key handling code At very high scale, having each rank store every other rank's remote keys for each segment can lead to high memory consumption. We activate the symmetric remote key option to generate remote keys that can be deduplicated and then used interchangeably. Signed-off-by: Thomas Vegas --- config/ompi_check_ucx.m4 | 6 +- oshmem/mca/spml/ucx/spml_ucx.c | 183 +++++++++++++++++++++- oshmem/mca/spml/ucx/spml_ucx.h | 41 +++-- oshmem/mca/spml/ucx/spml_ucx_component.c | 39 +++-- oshmem/mca/sshmem/ucx/sshmem_ucx_module.c | 3 +- 5 files changed, 241 insertions(+), 31 deletions(-) diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4 index fbea98cd7b3..01e39aaf968 100644 --- a/config/ompi_check_ucx.m4 +++ b/config/ompi_check_ucx.m4 @@ -108,7 +108,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ UCP_PARAM_FIELD_ESTIMATED_NUM_PPN, UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK, UCP_OP_ATTR_FLAG_MULTI_SEND, - UCS_MEMORY_TYPE_RDMA], + UCS_MEMORY_TYPE_RDMA, + UCP_MEM_MAP_SYMMETRIC_RKEY], [], [], [#include <ucp/api/ucp.h>]) AC_CHECK_DECLS([UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS], @@ -124,7 +125,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [#include <ucp/api/ucp.h>]) AC_CHECK_DECLS([ucp_tag_send_nbx, ucp_tag_send_sync_nbx, - ucp_tag_recv_nbx], + ucp_tag_recv_nbx, + ucp_rkey_compare], [], [], [#include <ucp/api/ucp.h>]) AC_CHECK_TYPES([ucp_request_param_t], diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 570b4d25a7a..5493d78e661 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -22,6 +22,7 @@ #include "opal/datatype/opal_convertor.h" #include "opal/mca/common/ucx/common_ucx.h" #include "opal/util/opal_environ.h" +#include "opal/util/minmax.h" #include "ompi/datatype/ompi_datatype.h" #include "ompi/mca/pml/pml.h" @@ -126,6 +127,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = { }; #endif 
+/**
+ * Extra ucp_mem_map() flag that requests symmetric remote keys.
+ *
+ * @param[in] spml_ucx  SPML UCX component state; only
+ *                      symmetric_rkey_max_count is read
+ *
+ * @return UCP_MEM_MAP_SYMMETRIC_RKEY when UCX declares that flag and
+ *         symmetric_rkey_max_count is positive, 0 otherwise, so the result
+ *         can always be OR-ed into the caller's mem_map flags
+ */
+unsigned
+mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx)
+{
+#if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
+    if (spml_ucx->symmetric_rkey_max_count > 0) {
+        return UCP_MEM_MAP_SYMMETRIC_RKEY;
+    }
+#endif
+
+    return 0;
+}
+
+/**
+ * Put a store into the empty state: no array, no entries, no capacity.
+ *
+ * @param[out] store  Store to initialize
+ */
+void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store)
+{
+    store->array = NULL;
+    store->count = 0;
+    store->size  = 0;
+}
+
+/**
+ * Destroy every rkey still held by the store and release its array.
+ *
+ * An entry whose refcnt is not zero is reported through SPML_UCX_ERROR and
+ * destroyed anyway. The store is left in the empty state, so calling the
+ * function a second time is harmless.
+ *
+ * @param[in,out] store  Store to tear down
+ */
+void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store)
+{
+    int i;
+
+    for (i = 0; i < store->count; i++) {
+        if (store->array[i].refcnt != 0) {
+            SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0",
+                           i, store->count, store->array[i].refcnt);
+        }
+
+        ucp_rkey_destroy(store->array[i].rkey);
+    }
+
+    free(store->array);
+
+    /* Do not leave a dangling array pointer or a stale count behind */
+    store->array = NULL;
+    store->count = 0;
+    store->size  = 0;
+}
+
+/**
+ * Find position in sorted array for existing or future entry
+ *
+ * The array is bisected with ucp_rkey_compare(), so a "found" result means
+ * "compares equal", not "same handle".
+ *
+ * @param[in]  store   Store of the entries
+ * @param[in]  worker  Common worker for rkeys used
+ * @param[in]  rkey    Remote key to search for
+ * @param[out] index   Index of entry
+ *
+ * @return
+ *   OSHMEM_ERR_NOT_FOUND: index contains the position where future element
+ *                         should be inserted to keep array sorted
+ *   OSHMEM_SUCCESS      : index contains the position of the element
+ *   Other error         : index is not valid (comparison failed, or UCX does
+ *                         not declare ucp_rkey_compare)
+ */
+static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store,
+                                        const ucp_worker_h worker,
+                                        const ucp_rkey_h rkey,
+                                        int *index)
+{
+#if HAVE_DECL_UCP_RKEY_COMPARE
+    ucp_rkey_compare_params_t params;
+    int i, result, m, end;
+    ucs_status_t status;
+
+    for (i = 0, end = store->count; i < end;) {
+        m = (i + end) / 2;
+
+        params.field_mask = 0;
+        status = ucp_rkey_compare(worker, store->array[m].rkey,
+                                  rkey, &params, &result);
+        if (status != UCS_OK) {
+            return OSHMEM_ERROR;
+        } else if (result == 0) {
+            *index = m;
+            return OSHMEM_SUCCESS;
+        } else if (result > 0) {
+            end = m;
+        } else {
+            i = m + 1;
+        }
+    }
+
+    *index = i;
+    return OSHMEM_ERR_NOT_FOUND;
+#else
+    return OSHMEM_ERROR;
+#endif
+}
+
+/**
+ * Insert a new entry with refcnt 1 at position i.
+ *
+ * The insertion is silently skipped when the store already holds
+ * mca_spml_ucx.symmetric_rkey_max_count entries or when the array cannot be
+ * grown; the caller is not told which happened and must treat rkey as
+ * possibly untracked.
+ *
+ * @param[in,out] store  Store to insert into
+ * @param[in]     i      Position reported by mca_spml_ucx_rkey_store_find()
+ * @param[in]     rkey   Remote key to record
+ */
+static void
+mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store,
+                               int i, ucp_rkey_h rkey)
+{
+    int size;
+    mca_spml_ucx_rkey_t *tmp;
+
+    if (store->count >= mca_spml_ucx.symmetric_rkey_max_count) {
+        return;
+    }
+
+    if (store->count >= store->size) {
+        /* Grow geometrically, but never beyond the configured maximum */
+        size = opal_min(opal_max(store->size, 8) * 2,
+                        mca_spml_ucx.symmetric_rkey_max_count);
+        tmp  = realloc(store->array, size * sizeof(*store->array));
+        if (tmp == NULL) {
+            return;
+        }
+
+        store->array = tmp;
+        store->size  = size;
+    }
+
+    /* Open a gap at i; the regions overlap, hence memmove */
+    memmove(&store->array[i + 1], &store->array[i],
+            (store->count - i) * sizeof(*store->array));
+    store->array[i].rkey   = rkey;
+    store->array[i].refcnt = 1;
+    store->count++;
+    return;
+}
+
+/**
+ * Exchange a freshly unpacked rkey for the handle the caller should use.
+ *
+ * Takes ownership of input ucp remote key: when an equal key is already
+ * stored, the input is destroyed, the stored entry's refcnt is incremented
+ * and the stored handle is returned. Otherwise the input is returned, after
+ * an insertion attempt if the lookup reported "not found". When
+ * symmetric_rkey_max_count is 0 the store is bypassed entirely.
+ *
+ * @param[in,out] store   Store to consult
+ * @param[in]     worker  Worker passed to the rkey comparison
+ * @param[in]     rkey    Remote key just unpacked
+ *
+ * @return Remote key handle to use in place of the input
+ */
+static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store,
+                                              ucp_worker_h worker,
+                                              ucp_rkey_h rkey)
+{
+    int ret, i;
+
+    if (mca_spml_ucx.symmetric_rkey_max_count == 0) {
+        return rkey;
+    }
+
+    ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
+    if (ret == OSHMEM_SUCCESS) {
+        ucp_rkey_destroy(rkey);
+        store->array[i].refcnt++;
+        return store->array[i].rkey;
+    }
+
+    if (ret == OSHMEM_ERR_NOT_FOUND) {
+        mca_spml_ucx_rkey_store_insert(store, i, rkey);
+    }
+
+    return rkey;
+}
+
+/**
+ * Release one reference to an rkey previously handed out for use.
+ *
+ * A handle that the store does not track is simply destroyed. That covers
+ * a failed lookup as well as a lookup that lands on an entry which compares
+ * equal but holds a different handle (the key never went through
+ * mca_spml_ucx_rkey_store_get(), or its insertion was skipped). A tracked
+ * handle is destroyed, and its entry removed, only when the last reference
+ * is dropped.
+ *
+ * @param[in,out] store   Store to update
+ * @param[in]     worker  Worker passed to the rkey comparison
+ * @param[in]     rkey    Remote key being released
+ */
+static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store,
+                                        ucp_worker_h worker,
+                                        ucp_rkey_h rkey)
+{
+    mca_spml_ucx_rkey_t *entry;
+    int ret, i;
+
+    ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
+    if (ret != OSHMEM_SUCCESS) {
+        goto out;
+    }
+
+    entry = &store->array[i];
+    if (entry->rkey != rkey) {
+        /* Equal by comparison but not the stored handle: touching this
+         * entry's refcnt would unbalance it and leak the caller's rkey */
+        goto out;
+    }
+
+    if (--entry->refcnt > 0) {
+        return;
+    }
+
+    /* Last user gone: close the gap, then destroy the handle below */
+    memmove(&store->array[i], &store->array[i + 1],
+            (store->count - (i + 1)) * sizeof(*store->array));
+    store->count--;
+
+out:
+    ucp_rkey_destroy(rkey);
+}
+
 int mca_spml_ucx_enable(bool enable)
 {
     SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****");
@@ -240,6 +406,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
 {
     int rc;
     ucs_status_t err;
+
ucp_rkey_h rkey; rc = mca_spml_ucx_ctx_mkey_new(ucx_ctx, pe, segno, ucx_mkey); if (OSHMEM_SUCCESS != rc) { @@ -248,11 +415,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn } if (mkey->u.data) { - err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &((*ucx_mkey)->rkey)); + err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &rkey); if (UCS_OK != err) { SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err)); return OSHMEM_ERROR; } + + if (!oshmem_proc_on_local_node(pe)) { + rkey = mca_spml_ucx_rkey_store_get(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], rkey); + } + + (*ucx_mkey)->rkey = rkey; + rc = mca_spml_ucx_ctx_mkey_cache(ucx_ctx, mkey, segno, pe); if (OSHMEM_SUCCESS != rc) { SPML_UCX_ERROR("mca_spml_ucx_ctx_mkey_cache failed"); @@ -267,7 +441,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn ucp_peer_t *ucp_peer; int rc; ucp_peer = &(ucx_ctx->ucp_peers[pe]); - ucp_rkey_destroy(ucx_mkey->rkey); + mca_spml_ucx_rkey_store_put(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], ucx_mkey->rkey); ucx_mkey->rkey = NULL; rc = mca_spml_ucx_peer_mkey_cache_del(ucp_peer, segno); if(OSHMEM_SUCCESS != rc){ @@ -725,7 +899,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr, UCP_MEM_MAP_PARAM_FIELD_FLAGS; mem_map_params.address = addr; mem_map_params.length = size; - mem_map_params.flags = flags; + mem_map_params.flags = flags | + mca_spml_ucx_mem_map_flags_symmetric_rkey(&mca_spml_ucx); status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &mem_h); if (UCS_OK != status) { @@ -917,6 +1092,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx } } + mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store); + *ucx_ctx_p = ucx_ctx; return OSHMEM_SUCCESS; diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index a93ff3756a3..2fec131ad2d 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ 
b/oshmem/mca/spml/ucx/spml_ucx.h @@ -76,18 +76,31 @@ struct ucp_peer { size_t mkeys_cnt; }; typedef struct ucp_peer ucp_peer_t; - + +/* An rkey_store entry: one remote key handle plus the number of users currently sharing it */ +typedef struct mca_spml_ucx_rkey { + ucp_rkey_h rkey; /* handle recorded by mca_spml_ucx_rkey_store_insert() */ + int refcnt; /* 1 on insert, incremented on each store hit, decremented on put */ +} mca_spml_ucx_rkey_t; + /* Store of deduplicated remote keys; each mca_spml_ucx_ctx embeds one */ +typedef struct mca_spml_ucx_rkey_store { + mca_spml_ucx_rkey_t *array; /* entries, bisected with ucp_rkey_compare() by the lookup code */ + int size; /* allocated capacity of array, in entries */ + int count; /* number of entries in use */ +} mca_spml_ucx_rkey_store_t; + struct mca_spml_ucx_ctx { - ucp_worker_h *ucp_worker; - ucp_peer_t *ucp_peers; - long options; - opal_bitmap_t put_op_bitmap; - unsigned long nb_progress_cnt; - unsigned int ucp_workers; - int *put_proc_indexes; - unsigned put_proc_count; - bool synchronized_quiet; - int strong_sync; + ucp_worker_h *ucp_worker; + ucp_peer_t *ucp_peers; + long options; + opal_bitmap_t put_op_bitmap; + unsigned long nb_progress_cnt; + unsigned int ucp_workers; + int *put_proc_indexes; + unsigned put_proc_count; + bool synchronized_quiet; + int strong_sync; + mca_spml_ucx_rkey_store_t rkey_store; /* remote keys unpacked for this context, deduplicated */ }; typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t; @@ -128,6 +141,7 @@ struct mca_spml_ucx { unsigned long nb_ucp_worker_progress; unsigned int ucp_workers; unsigned int ucp_worker_cnt; + int symmetric_rkey_max_count; /* cap on entries per rkey store; 0 bypasses the store */ }; typedef struct mca_spml_ucx mca_spml_ucx_t; @@ -280,6 +294,11 @@ extern int mca_spml_ucx_team_fcollect(shmem_team_t team, void extern int mca_spml_ucx_team_reduce(shmem_team_t team, void *dest, const void *source, size_t nreduce, int operation, int datatype); +extern unsigned +mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx); /* flag bits to OR into ucp_mem_map() flags; 0 when symmetric rkeys are off */ + +extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store); /* reset store to empty */ +extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store); /* destroy stored rkeys and free the array */ static inline int mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey) diff --git a/oshmem/mca/spml/ucx/spml_ucx_component.c b/oshmem/mca/spml/ucx/spml_ucx_component.c index 1ab00ac1786..e44a800a8be 100644 --- a/oshmem/mca/spml/ucx/spml_ucx_component.c 
+++ b/oshmem/mca/spml/ucx/spml_ucx_component.c @@ -153,6 +153,10 @@ static int mca_spml_ucx_component_register(void) "Enable asynchronous progress thread", &mca_spml_ucx.async_progress); + mca_spml_ucx_param_register_int("symmetric_rkey_max_count", 0, + "Size of the symmetric key store. Non-zero to enable, typical use 5000", + &mca_spml_ucx.symmetric_rkey_max_count); /* value also gates the symmetric-rkey mem_map flag and the rkey store */ + mca_spml_ucx_param_register_int("async_tick_usec", 3000, "Asynchronous progress tick granularity (in usec)", &mca_spml_ucx.async_tick); @@ -332,6 +336,8 @@ static int spml_ucx_init(void) mca_spml_ucx_ctx_default.ucp_workers++; } + mca_spml_ucx_rkey_store_init(&mca_spml_ucx_ctx_default.rkey_store); /* default context starts with an empty rkey store */ + wrk_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; err = ucp_worker_query(mca_spml_ucx_ctx_default.ucp_worker[0], &wrk_attr); @@ -436,10 +442,25 @@ static void _ctx_cleanup(mca_spml_ucx_ctx_t *ctx) free(ctx->ucp_peers); } +static void mca_spml_ucx_ctx_fini(mca_spml_ucx_ctx_t *ctx) /* Tear down one context: clean up its rkey store, destroy its ctx->ucp_workers workers, free the worker array, then free ctx itself unless it is the default context */ +{ + unsigned int i; + + mca_spml_ucx_rkey_store_cleanup(&ctx->rkey_store); /* stored rkeys are destroyed first, before any worker */ + for (i = 0; i < ctx->ucp_workers; i++) { + ucp_worker_destroy(ctx->ucp_worker[i]); + } + free(ctx->ucp_worker); + if (ctx != &mca_spml_ucx_ctx_default) { /* the default context is referenced by address and is not freed here */ + free(ctx); + } +} + static int mca_spml_ucx_component_fini(void) { int fenced = 0, i; int ret = OSHMEM_SUCCESS; + mca_spml_ucx_ctx_t *ctx; opal_progress_unregister(spml_ucx_default_progress); if (mca_spml_ucx.active_array.ctxs_count) { @@ -492,36 +513,26 @@ static int mca_spml_ucx_component_fini(void) } } - /* delete all workers */ for (i = 0; i < mca_spml_ucx.active_array.ctxs_count; i++) { - ucp_worker_destroy(mca_spml_ucx.active_array.ctxs[i]->ucp_worker[0]); - free(mca_spml_ucx.active_array.ctxs[i]->ucp_worker); - free(mca_spml_ucx.active_array.ctxs[i]); + mca_spml_ucx_ctx_fini(mca_spml_ucx.active_array.ctxs[i]); } for (i = 0; i < mca_spml_ucx.idle_array.ctxs_count; i++) { - ucp_worker_destroy(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker[0]); - 
free(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker); - free(mca_spml_ucx.idle_array.ctxs[i]); + mca_spml_ucx_ctx_fini(mca_spml_ucx.idle_array.ctxs[i]); } if (mca_spml_ucx_ctx_default.ucp_worker) { - for (i = 0; i < (signed int)mca_spml_ucx.ucp_workers; i++) { - ucp_worker_destroy(mca_spml_ucx_ctx_default.ucp_worker[i]); - } - free(mca_spml_ucx_ctx_default.ucp_worker); + mca_spml_ucx_ctx_fini(&mca_spml_ucx_ctx_default); } if (mca_spml_ucx.aux_ctx != NULL) { - ucp_worker_destroy(mca_spml_ucx.aux_ctx->ucp_worker[0]); - free(mca_spml_ucx.aux_ctx->ucp_worker); + mca_spml_ucx_ctx_fini(mca_spml_ucx.aux_ctx); } mca_spml_ucx.enabled = false; /* not anymore */ free(mca_spml_ucx.active_array.ctxs); free(mca_spml_ucx.idle_array.ctxs); - free(mca_spml_ucx.aux_ctx); SHMEM_MUTEX_DESTROY(mca_spml_ucx.internal_mutex); pthread_mutex_destroy(&mca_spml_ucx.ctx_create_mutex); diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c index 262bef5ffe6..688bfce6f19 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c @@ -118,7 +118,8 @@ segment_create_internal(map_segment_t *ds_buf, void *address, size_t size, mem_map_params.address = address; mem_map_params.length = size; - mem_map_params.flags = flags; + mem_map_params.flags = flags | + mca_spml_ucx_mem_map_flags_symmetric_rkey(spml); mem_map_params.memory_type = mem_type; status = ucp_mem_map(spml->ucp_context, &mem_map_params, &mem_h);