Skip to content

Commit 776e8ba

Browse files
committed
oshmem: Add symmetric remote key handling code
At very high scale, having each rank store every other rank's remote keys for each segment can lead to high memory consumption. We activate the symmetric remote key option to generate remote keys that can be deduplicated and then used interchangeably. Signed-off-by: Thomas Vegas <[email protected]>
1 parent 3de90b1 commit 776e8ba

File tree

5 files changed

+241
-31
lines changed

5 files changed

+241
-31
lines changed

config/ompi_check_ucx.m4

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
108108
UCP_PARAM_FIELD_ESTIMATED_NUM_PPN,
109109
UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK,
110110
UCP_OP_ATTR_FLAG_MULTI_SEND,
111-
UCS_MEMORY_TYPE_RDMA],
111+
UCS_MEMORY_TYPE_RDMA,
112+
UCP_MEM_MAP_SYMMETRIC_RKEY],
112113
[], [],
113114
[#include <ucp/api/ucp.h>])
114115
AC_CHECK_DECLS([UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS],
@@ -124,7 +125,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
124125
[#include <ucp/api/ucp.h>])
125126
AC_CHECK_DECLS([ucp_tag_send_nbx,
126127
ucp_tag_send_sync_nbx,
127-
ucp_tag_recv_nbx],
128+
ucp_tag_recv_nbx,
129+
ucp_rkey_compare],
128130
[], [],
129131
[#include <ucp/api/ucp.h>])
130132
AC_CHECK_TYPES([ucp_request_param_t],

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 180 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "opal/datatype/opal_convertor.h"
2323
#include "opal/mca/common/ucx/common_ucx.h"
2424
#include "opal/util/opal_environ.h"
25+
#include "opal/util/minmax.h"
2526
#include "ompi/datatype/ompi_datatype.h"
2627
#include "ompi/mca/pml/pml.h"
2728

@@ -126,6 +127,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
126127
};
127128
#endif
128129

130+
/**
 * Select the extra memory-map flag that requests symmetric remote keys.
 *
 * @param[in] spml_ucx  SPML UCX module state; only symmetric_rkey_max_count
 *                      is read.
 *
 * @return UCP_MEM_MAP_SYMMETRIC_RKEY when the UCX headers declare it and
 *         symmetric_rkey_max_count is positive, 0 otherwise.
 */
unsigned
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx)
{
    unsigned flags = 0;

#if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
    if (spml_ucx->symmetric_rkey_max_count > 0) {
        flags = UCP_MEM_MAP_SYMMETRIC_RKEY;
    }
#endif

    return flags;
}
141+
142+
/**
 * Put an rkey store into its empty state: no backing array, zero capacity
 * and zero entries.
 *
 * @param[out] store  Store to initialize. Previous contents are overwritten
 *                    without being released.
 */
void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store)
{
    store->size  = 0;
    store->count = 0;
    store->array = NULL;
}
148+
149+
/**
 * Release every remote key still held by the store and free its array.
 *
 * An entry whose reference count is not zero is reported through
 * SPML_UCX_ERROR, but its remote key is destroyed regardless.
 *
 * On return the store is back in the empty state (NULL array, zero count and
 * size), so calling this function again, or reusing the store, does not
 * touch freed memory or destroy the same remote keys twice.
 *
 * @param[in,out] store  Store to tear down.
 */
void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store)
{
    int i;

    for (i = 0; i < store->count; i++) {
        if (store->array[i].refcnt != 0) {
            SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0",
                           i, store->count, store->array[i].refcnt);
        }

        ucp_rkey_destroy(store->array[i].rkey);
    }

    free(store->array);

    /* Do not leave a dangling pointer or a stale entry count behind */
    store->array = NULL;
    store->count = 0;
    store->size  = 0;
}
164+
165+
/**
166+
* Find position in sorted array for existing or future entry
167+
*
168+
* @param[in] store Store of the entries
169+
* @param[in] worker Common worker for rkeys used
170+
* @param[in] rkey Remote key to search for
171+
* @param[out] index Index of entry
172+
*
173+
* @return
174+
* OSHMEM_ERR_NOT_FOUND: index contains the position where future element
175+
* should be inserted to keep array sorted
176+
* OSHMEM_SUCCESS : index contains the position of the element
177+
* Other error : index is not valid
178+
*/
179+
static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store,
180+
const ucp_worker_h worker,
181+
const ucp_rkey_h rkey,
182+
int *index)
183+
{
184+
#if HAVE_DECL_UCP_RKEY_COMPARE
185+
ucp_rkey_compare_params_t params;
186+
int i, result, m, end;
187+
ucs_status_t status;
188+
189+
for (i = 0, end = store->count; i < end;) {
190+
m = (i + end) / 2;
191+
192+
params.field_mask = 0;
193+
status = ucp_rkey_compare(worker, store->array[m].rkey,
194+
rkey, &params, &result);
195+
if (status != UCS_OK) {
196+
return OSHMEM_ERROR;
197+
} else if (result == 0) {
198+
*index = m;
199+
return OSHMEM_SUCCESS;
200+
} else if (result > 0) {
201+
end = m;
202+
} else {
203+
i = m + 1;
204+
}
205+
}
206+
207+
*index = i;
208+
return OSHMEM_ERR_NOT_FOUND;
209+
#else
210+
return OSHMEM_ERROR;
211+
#endif
212+
}
213+
214+
/**
 * Insert a remote key at a given position of the store with a reference
 * count of 1, shifting the entries from that position onwards up by one.
 *
 * The insertion is best effort. Nothing is stored, and no error is reported,
 * when the store already holds mca_spml_ucx.symmetric_rkey_max_count entries
 * or when growing the array fails.
 *
 * @param[in,out] store  Store to insert into
 * @param[in]     i      Position of the new entry
 * @param[in]     rkey   Remote key to store
 */
static void mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store,
                                           int i, ucp_rkey_h rkey)
{
    const int limit = mca_spml_ucx.symmetric_rkey_max_count;
    mca_spml_ucx_rkey_t *grown;
    mca_spml_ucx_rkey_t *slot;
    int capacity;

    if (store->count >= limit) {
        return;
    }

    if (store->count >= store->size) {
        /* Double the capacity (starting from at least 8), capped at limit */
        capacity = opal_max(store->size, 8) * 2;
        capacity = opal_min(capacity, limit);

        grown = realloc(store->array, capacity * sizeof(*store->array));
        if (grown == NULL) {
            return;
        }

        store->size  = capacity;
        store->array = grown;
    }

    /* Open a gap at position i, then fill it with the new entry */
    slot = &store->array[i];
    memmove(slot + 1, slot, (store->count - i) * sizeof(*slot));
    slot->refcnt = 1;
    slot->rkey   = rkey;
    store->count++;
}
243+
244+
/* Takes ownership of input ucp remote key */
245+
static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store,
246+
ucp_worker_h worker,
247+
ucp_rkey_h rkey)
248+
{
249+
int ret, i;
250+
251+
if (mca_spml_ucx.symmetric_rkey_max_count == 0) {
252+
return rkey;
253+
}
254+
255+
ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
256+
if (ret == OSHMEM_SUCCESS) {
257+
ucp_rkey_destroy(rkey);
258+
store->array[i].refcnt++;
259+
return store->array[i].rkey;
260+
}
261+
262+
if (ret == OSHMEM_ERR_NOT_FOUND) {
263+
mca_spml_ucx_rkey_store_insert(store, i, rkey);
264+
}
265+
266+
return rkey;
267+
}
268+
269+
static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store,
270+
ucp_worker_h worker,
271+
ucp_rkey_h rkey)
272+
{
273+
mca_spml_ucx_rkey_t *entry;
274+
int ret, i;
275+
276+
ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
277+
if (ret != OSHMEM_SUCCESS) {
278+
goto out;
279+
}
280+
281+
entry = &store->array[i];
282+
assert(entry->rkey == rkey);
283+
if (--entry->refcnt > 0) {
284+
return;
285+
}
286+
287+
memmove(&store->array[i], &store->array[i + 1],
288+
(store->count - (i + 1)) * sizeof(*store->array));
289+
store->count--;
290+
291+
out:
292+
ucp_rkey_destroy(rkey);
293+
}
294+
129295
int mca_spml_ucx_enable(bool enable)
130296
{
131297
SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****");
@@ -240,6 +406,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
240406
{
241407
int rc;
242408
ucs_status_t err;
409+
ucp_rkey_h rkey;
243410

244411
rc = mca_spml_ucx_ctx_mkey_new(ucx_ctx, pe, segno, ucx_mkey);
245412
if (OSHMEM_SUCCESS != rc) {
@@ -248,11 +415,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
248415
}
249416

250417
if (mkey->u.data) {
251-
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &((*ucx_mkey)->rkey));
418+
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &rkey);
252419
if (UCS_OK != err) {
253420
SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err));
254421
return OSHMEM_ERROR;
255422
}
423+
424+
if (!oshmem_proc_on_local_node(pe)) {
425+
rkey = mca_spml_ucx_rkey_store_get(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], rkey);
426+
}
427+
428+
(*ucx_mkey)->rkey = rkey;
429+
256430
rc = mca_spml_ucx_ctx_mkey_cache(ucx_ctx, mkey, segno, pe);
257431
if (OSHMEM_SUCCESS != rc) {
258432
SPML_UCX_ERROR("mca_spml_ucx_ctx_mkey_cache failed");
@@ -267,7 +441,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
267441
ucp_peer_t *ucp_peer;
268442
int rc;
269443
ucp_peer = &(ucx_ctx->ucp_peers[pe]);
270-
ucp_rkey_destroy(ucx_mkey->rkey);
444+
mca_spml_ucx_rkey_store_put(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], ucx_mkey->rkey);
271445
ucx_mkey->rkey = NULL;
272446
rc = mca_spml_ucx_peer_mkey_cache_del(ucp_peer, segno);
273447
if(OSHMEM_SUCCESS != rc){
@@ -725,7 +899,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
725899
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
726900
mem_map_params.address = addr;
727901
mem_map_params.length = size;
728-
mem_map_params.flags = flags;
902+
mem_map_params.flags = flags |
903+
mca_spml_ucx_mem_map_flags_symmetric_rkey(&mca_spml_ucx);
729904

730905
status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &mem_h);
731906
if (UCS_OK != status) {
@@ -917,6 +1092,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
9171092
}
9181093
}
9191094

1095+
mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store);
1096+
9201097
*ucx_ctx_p = ucx_ctx;
9211098

9221099
return OSHMEM_SUCCESS;

oshmem/mca/spml/ucx/spml_ucx.h

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,31 @@ struct ucp_peer {
7676
size_t mkeys_cnt;
7777
};
7878
typedef struct ucp_peer ucp_peer_t;
79-
79+
80+
/* An rkey_store entry */
81+
typedef struct mca_spml_ucx_rkey {
82+
ucp_rkey_h rkey;
83+
int refcnt;
84+
} mca_spml_ucx_rkey_t;
85+
86+
typedef struct mca_spml_ucx_rkey_store {
87+
mca_spml_ucx_rkey_t *array;
88+
int size;
89+
int count;
90+
} mca_spml_ucx_rkey_store_t;
91+
8092
struct mca_spml_ucx_ctx {
81-
ucp_worker_h *ucp_worker;
82-
ucp_peer_t *ucp_peers;
83-
long options;
84-
opal_bitmap_t put_op_bitmap;
85-
unsigned long nb_progress_cnt;
86-
unsigned int ucp_workers;
87-
int *put_proc_indexes;
88-
unsigned put_proc_count;
89-
bool synchronized_quiet;
90-
int strong_sync;
93+
ucp_worker_h *ucp_worker;
94+
ucp_peer_t *ucp_peers;
95+
long options;
96+
opal_bitmap_t put_op_bitmap;
97+
unsigned long nb_progress_cnt;
98+
unsigned int ucp_workers;
99+
int *put_proc_indexes;
100+
unsigned put_proc_count;
101+
bool synchronized_quiet;
102+
int strong_sync;
103+
mca_spml_ucx_rkey_store_t rkey_store;
91104
};
92105
typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t;
93106

@@ -128,6 +141,7 @@ struct mca_spml_ucx {
128141
unsigned long nb_ucp_worker_progress;
129142
unsigned int ucp_workers;
130143
unsigned int ucp_worker_cnt;
144+
int symmetric_rkey_max_count;
131145
};
132146
typedef struct mca_spml_ucx mca_spml_ucx_t;
133147

@@ -280,6 +294,11 @@ extern int mca_spml_ucx_team_fcollect(shmem_team_t team, void
280294
extern int mca_spml_ucx_team_reduce(shmem_team_t team, void
281295
*dest, const void *source, size_t nreduce, int operation, int datatype);
282296

297+
extern unsigned
298+
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx);
299+
300+
extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store);
301+
extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store);
283302

284303
static inline int
285304
mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey)

0 commit comments

Comments
 (0)