diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4 index fbea98cd7b3..01e39aaf968 100644 --- a/config/ompi_check_ucx.m4 +++ b/config/ompi_check_ucx.m4 @@ -108,7 +108,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ UCP_PARAM_FIELD_ESTIMATED_NUM_PPN, UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK, UCP_OP_ATTR_FLAG_MULTI_SEND, - UCS_MEMORY_TYPE_RDMA], + UCS_MEMORY_TYPE_RDMA, + UCP_MEM_MAP_SYMMETRIC_RKEY], [], [], [#include ]) AC_CHECK_DECLS([UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS], @@ -124,7 +125,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [#include ]) AC_CHECK_DECLS([ucp_tag_send_nbx, ucp_tag_send_sync_nbx, - ucp_tag_recv_nbx], + ucp_tag_recv_nbx, + ucp_rkey_compare], [], [], [#include ]) AC_CHECK_TYPES([ucp_request_param_t], diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 570b4d25a7a..5493d78e661 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -22,6 +22,7 @@ #include "opal/datatype/opal_convertor.h" #include "opal/mca/common/ucx/common_ucx.h" #include "opal/util/opal_environ.h" +#include "opal/util/minmax.h" #include "ompi/datatype/ompi_datatype.h" #include "ompi/mca/pml/pml.h" @@ -126,6 +127,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = { }; #endif +unsigned +mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx) +{ +#if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY + if (spml_ucx->symmetric_rkey_max_count > 0) { + return UCP_MEM_MAP_SYMMETRIC_RKEY; + } +#endif + + return 0; +} + +void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store) +{ + store->array = NULL; + store->count = 0; + store->size = 0; +} + +void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store) +{ + int i; + + for (i = 0; i < store->count; i++) { + if (store->array[i].refcnt != 0) { + SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0", + i, store->count, store->array[i].refcnt); + } + + ucp_rkey_destroy(store->array[i].rkey); + } + + free(store->array); +} + +/** + * Find position in sorted array for existing or future entry + * + * @param[in] store Store of the entries + * @param[in] worker Common worker for rkeys used + * @param[in] rkey Remote key to search for + * @param[out] index Index of entry + * + * @return + * OSHMEM_ERR_NOT_FOUND: index contains the position where future element + * should be inserted to keep array sorted + * OSHMEM_SUCCESS : index contains the position of the element + * Other error : index is not valid + */ +static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store, + const ucp_worker_h worker, + const ucp_rkey_h rkey, + int *index) +{ +#if HAVE_DECL_UCP_RKEY_COMPARE + ucp_rkey_compare_params_t params; + int i, result, m, end; + ucs_status_t status; + + for (i = 0, end = store->count; i < end;) { + m = (i + end) / 2; + + params.field_mask = 0; + status = ucp_rkey_compare(worker, store->array[m].rkey, + rkey, ¶ms, &result); + if (status != UCS_OK) { + return OSHMEM_ERROR; + } else if (result == 0) { + *index = m; + return OSHMEM_SUCCESS; + } else if (result > 0) { + end = m; + } else { + i = m + 1; + } + } + + *index = i; + return OSHMEM_ERR_NOT_FOUND; +#else + return OSHMEM_ERROR; +#endif +} + +static void mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store, + int i, ucp_rkey_h rkey) +{ + int size; + mca_spml_ucx_rkey_t *tmp; + + if (store->count >= mca_spml_ucx.symmetric_rkey_max_count) { + return; + } + + if (store->count >= store->size) { + size = opal_min(opal_max(store->size, 8) * 2, + mca_spml_ucx.symmetric_rkey_max_count); + tmp = realloc(store->array, size * sizeof(*store->array)); + if (tmp == NULL) { + return; + } + + store->array = tmp; + store->size = size; + } + + memmove(&store->array[i + 1], &store->array[i], + (store->count - i) * sizeof(*store->array)); + store->array[i].rkey = rkey; + store->array[i].refcnt = 1; + store->count++; + return; +} + +/* Takes ownership of input ucp remote key */ +static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store, + ucp_worker_h worker, + ucp_rkey_h rkey) +{ + int ret, i; + + if (mca_spml_ucx.symmetric_rkey_max_count == 0) { + return rkey; + } + + ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i); + if (ret == OSHMEM_SUCCESS) { + ucp_rkey_destroy(rkey); + store->array[i].refcnt++; + return store->array[i].rkey; + } + + if (ret == OSHMEM_ERR_NOT_FOUND) { + mca_spml_ucx_rkey_store_insert(store, i, rkey); + } + + return rkey; +} + +static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store, + ucp_worker_h worker, + ucp_rkey_h rkey) +{ + mca_spml_ucx_rkey_t *entry; + int ret, i; + + ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i); + if (ret != OSHMEM_SUCCESS) { + goto out; + } + + entry = &store->array[i]; + assert(entry->rkey == rkey); + if (--entry->refcnt > 0) { + return; + } + + memmove(&store->array[i], &store->array[i + 1], + (store->count - (i + 1)) * sizeof(*store->array)); + store->count--; + +out: + ucp_rkey_destroy(rkey); +} + int mca_spml_ucx_enable(bool enable) { SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****"); @@ -240,6 +406,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn { int rc; ucs_status_t err; + ucp_rkey_h rkey; rc = mca_spml_ucx_ctx_mkey_new(ucx_ctx, pe, segno, ucx_mkey); if (OSHMEM_SUCCESS != rc) { @@ -248,11 +415,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn } if (mkey->u.data) { - err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &((*ucx_mkey)->rkey)); + err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &rkey); if (UCS_OK != err) { SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err)); return OSHMEM_ERROR; } + + if (!oshmem_proc_on_local_node(pe)) { + rkey = mca_spml_ucx_rkey_store_get(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], rkey); + } + + (*ucx_mkey)->rkey = rkey; + rc = mca_spml_ucx_ctx_mkey_cache(ucx_ctx, mkey, segno, pe); if (OSHMEM_SUCCESS != rc) { SPML_UCX_ERROR("mca_spml_ucx_ctx_mkey_cache failed"); @@ -267,7 +441,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn ucp_peer_t *ucp_peer; int rc; ucp_peer = &(ucx_ctx->ucp_peers[pe]); - ucp_rkey_destroy(ucx_mkey->rkey); + mca_spml_ucx_rkey_store_put(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], ucx_mkey->rkey); ucx_mkey->rkey = NULL; rc = mca_spml_ucx_peer_mkey_cache_del(ucp_peer, segno); if(OSHMEM_SUCCESS != rc){ @@ -725,7 +899,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr, UCP_MEM_MAP_PARAM_FIELD_FLAGS; mem_map_params.address = addr; mem_map_params.length = size; - mem_map_params.flags = flags; + mem_map_params.flags = flags | + mca_spml_ucx_mem_map_flags_symmetric_rkey(&mca_spml_ucx); status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &mem_h); if (UCS_OK != status) { @@ -917,6 +1092,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx } } + mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store); + *ucx_ctx_p = ucx_ctx; return OSHMEM_SUCCESS; diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index a93ff3756a3..2fec131ad2d 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ b/oshmem/mca/spml/ucx/spml_ucx.h @@ -76,18 +76,31 @@ struct ucp_peer { size_t mkeys_cnt; }; typedef struct ucp_peer ucp_peer_t; - + +/* An rkey_store entry */ +typedef struct mca_spml_ucx_rkey { + ucp_rkey_h rkey; + int refcnt; +} mca_spml_ucx_rkey_t; + +typedef struct mca_spml_ucx_rkey_store { + mca_spml_ucx_rkey_t *array; + int size; + int count; +} mca_spml_ucx_rkey_store_t; + struct mca_spml_ucx_ctx { - ucp_worker_h *ucp_worker; - ucp_peer_t *ucp_peers; - long options; - opal_bitmap_t put_op_bitmap; - unsigned long nb_progress_cnt; - unsigned int ucp_workers; - int *put_proc_indexes; - unsigned put_proc_count; - bool synchronized_quiet; - int strong_sync; + ucp_worker_h *ucp_worker; + ucp_peer_t *ucp_peers; + long options; + opal_bitmap_t put_op_bitmap; + unsigned long nb_progress_cnt; + unsigned int ucp_workers; + int *put_proc_indexes; + unsigned put_proc_count; + bool synchronized_quiet; + int strong_sync; + mca_spml_ucx_rkey_store_t rkey_store; }; typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t; @@ -128,6 +141,7 @@ struct mca_spml_ucx { unsigned long nb_ucp_worker_progress; unsigned int ucp_workers; unsigned int ucp_worker_cnt; + int symmetric_rkey_max_count; }; typedef struct mca_spml_ucx mca_spml_ucx_t; @@ -280,6 +294,11 @@ extern int mca_spml_ucx_team_fcollect(shmem_team_t team, void extern int mca_spml_ucx_team_reduce(shmem_team_t team, void *dest, const void *source, size_t nreduce, int operation, int datatype); +extern unsigned +mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx); + +extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store); +extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store); static inline int mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey) diff --git a/oshmem/mca/spml/ucx/spml_ucx_component.c b/oshmem/mca/spml/ucx/spml_ucx_component.c index 1ab00ac1786..e44a800a8be 100644 --- a/oshmem/mca/spml/ucx/spml_ucx_component.c +++ b/oshmem/mca/spml/ucx/spml_ucx_component.c @@ -153,6 +153,10 @@ static int mca_spml_ucx_component_register(void) "Enable asynchronous progress thread", &mca_spml_ucx.async_progress); + mca_spml_ucx_param_register_int("symmetric_rkey_max_count", 0, + "Size of the symmetric key store. Non-zero to enable, typical use 5000", + &mca_spml_ucx.symmetric_rkey_max_count); + mca_spml_ucx_param_register_int("async_tick_usec", 3000, "Asynchronous progress tick granularity (in usec)", &mca_spml_ucx.async_tick); @@ -332,6 +336,8 @@ static int spml_ucx_init(void) mca_spml_ucx_ctx_default.ucp_workers++; } + mca_spml_ucx_rkey_store_init(&mca_spml_ucx_ctx_default.rkey_store); + wrk_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; err = ucp_worker_query(mca_spml_ucx_ctx_default.ucp_worker[0], &wrk_attr); @@ -436,10 +442,25 @@ static void _ctx_cleanup(mca_spml_ucx_ctx_t *ctx) free(ctx->ucp_peers); } +static void mca_spml_ucx_ctx_fini(mca_spml_ucx_ctx_t *ctx) +{ + unsigned int i; + + mca_spml_ucx_rkey_store_cleanup(&ctx->rkey_store); + for (i = 0; i < ctx->ucp_workers; i++) { + ucp_worker_destroy(ctx->ucp_worker[i]); + } + free(ctx->ucp_worker); + if (ctx != &mca_spml_ucx_ctx_default) { + free(ctx); + } +} + static int mca_spml_ucx_component_fini(void) { int fenced = 0, i; int ret = OSHMEM_SUCCESS; + mca_spml_ucx_ctx_t *ctx; opal_progress_unregister(spml_ucx_default_progress); if (mca_spml_ucx.active_array.ctxs_count) { @@ -492,36 +513,26 @@ static int mca_spml_ucx_component_fini(void) } } - /* delete all workers */ for (i = 0; i < mca_spml_ucx.active_array.ctxs_count; i++) { - ucp_worker_destroy(mca_spml_ucx.active_array.ctxs[i]->ucp_worker[0]); - free(mca_spml_ucx.active_array.ctxs[i]->ucp_worker); - free(mca_spml_ucx.active_array.ctxs[i]); + mca_spml_ucx_ctx_fini(mca_spml_ucx.active_array.ctxs[i]); } for (i = 0; i < mca_spml_ucx.idle_array.ctxs_count; i++) { - ucp_worker_destroy(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker[0]); - free(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker); - free(mca_spml_ucx.idle_array.ctxs[i]); + mca_spml_ucx_ctx_fini(mca_spml_ucx.idle_array.ctxs[i]); } if (mca_spml_ucx_ctx_default.ucp_worker) { - for (i = 0; i < (signed int)mca_spml_ucx.ucp_workers; i++) { - ucp_worker_destroy(mca_spml_ucx_ctx_default.ucp_worker[i]); - } - free(mca_spml_ucx_ctx_default.ucp_worker); + mca_spml_ucx_ctx_fini(&mca_spml_ucx_ctx_default); } if (mca_spml_ucx.aux_ctx != NULL) { - ucp_worker_destroy(mca_spml_ucx.aux_ctx->ucp_worker[0]); - free(mca_spml_ucx.aux_ctx->ucp_worker); + mca_spml_ucx_ctx_fini(mca_spml_ucx.aux_ctx); } mca_spml_ucx.enabled = false; /* not anymore */ free(mca_spml_ucx.active_array.ctxs); free(mca_spml_ucx.idle_array.ctxs); - free(mca_spml_ucx.aux_ctx); SHMEM_MUTEX_DESTROY(mca_spml_ucx.internal_mutex); pthread_mutex_destroy(&mca_spml_ucx.ctx_create_mutex); diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c index 262bef5ffe6..688bfce6f19 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c @@ -118,7 +118,8 @@ segment_create_internal(map_segment_t *ds_buf, void *address, size_t size, mem_map_params.address = address; mem_map_params.length = size; - mem_map_params.flags = flags; + mem_map_params.flags = flags | + mca_spml_ucx_mem_map_flags_symmetric_rkey(spml); mem_map_params.memory_type = mem_type; status = ucp_mem_map(spml->ucp_context, &mem_map_params, &mem_h);