22
22
#include "opal/datatype/opal_convertor.h"
23
23
#include "opal/mca/common/ucx/common_ucx.h"
24
24
#include "opal/util/opal_environ.h"
25
+ #include "opal/util/minmax.h"
25
26
#include "ompi/datatype/ompi_datatype.h"
26
27
#include "ompi/mca/pml/pml.h"
27
28
@@ -126,6 +127,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
126
127
};
127
128
#endif
128
129
130
+ unsigned
131
+ mca_spml_ucx_mem_map_flags_symmetric_rkey (struct mca_spml_ucx * spml_ucx )
132
+ {
133
+ #if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
134
+ if (spml_ucx -> symmetric_rkey_max_count > 0 ) {
135
+ return UCP_MEM_MAP_SYMMETRIC_RKEY ;
136
+ }
137
+ #endif
138
+
139
+ return 0 ;
140
+ }
141
+
142
+ void mca_spml_ucx_rkey_store_init (mca_spml_ucx_rkey_store_t * store )
143
+ {
144
+ store -> array = NULL ;
145
+ store -> count = 0 ;
146
+ store -> size = 0 ;
147
+ }
148
+
149
+ void mca_spml_ucx_rkey_store_cleanup (mca_spml_ucx_rkey_store_t * store )
150
+ {
151
+ int i ;
152
+
153
+ for (i = 0 ; i < store -> count ; i ++ ) {
154
+ if (store -> array [i ].refcnt != 0 ) {
155
+ SPML_UCX_ERROR ("rkey store destroy: %d/%d has refcnt %d > 0" ,
156
+ i , store -> count , store -> array [i ].refcnt );
157
+ }
158
+
159
+ ucp_rkey_destroy (store -> array [i ].rkey );
160
+ }
161
+
162
+ free (store -> array );
163
+ }
164
+
165
+ /**
166
+ * Find position in sorted array for existing or future entry
167
+ *
168
+ * @param[in] store Store of the entries
169
+ * @param[in] worker Common worker for rkeys used
170
+ * @param[in] rkey Remote key to search for
171
+ * @param[out] index Index of entry
172
+ *
173
+ * @return
174
+ * OSHMEM_ERR_NOT_FOUND: index contains the position where future element
175
+ * should be inserted to keep array sorted
176
+ * OSHMEM_SUCCESS : index contains the position of the element
177
+ * Other error : index is not valid
178
+ */
179
+ static int mca_spml_ucx_rkey_store_find (const mca_spml_ucx_rkey_store_t * store ,
180
+ const ucp_worker_h worker ,
181
+ const ucp_rkey_h rkey ,
182
+ int * index )
183
+ {
184
+ #if HAVE_DECL_UCP_RKEY_COMPARE
185
+ ucp_rkey_compare_params_t params ;
186
+ int i , result , m , end ;
187
+ ucs_status_t status ;
188
+
189
+ for (i = 0 , end = store -> count ; i < end ;) {
190
+ m = (i + end ) / 2 ;
191
+
192
+ params .field_mask = 0 ;
193
+ status = ucp_rkey_compare (worker , store -> array [m ].rkey ,
194
+ rkey , & params , & result );
195
+ if (status != UCS_OK ) {
196
+ return OSHMEM_ERROR ;
197
+ } else if (result == 0 ) {
198
+ * index = m ;
199
+ return OSHMEM_SUCCESS ;
200
+ } else if (result > 0 ) {
201
+ end = m ;
202
+ } else {
203
+ i = m + 1 ;
204
+ }
205
+ }
206
+
207
+ * index = i ;
208
+ return OSHMEM_ERR_NOT_FOUND ;
209
+ #else
210
+ return OSHMEM_ERROR ;
211
+ #endif
212
+ }
213
+
214
+ static void mca_spml_ucx_rkey_store_insert (mca_spml_ucx_rkey_store_t * store ,
215
+ int i , ucp_rkey_h rkey )
216
+ {
217
+ int size ;
218
+ mca_spml_ucx_rkey_t * tmp ;
219
+
220
+ if (store -> count >= mca_spml_ucx .symmetric_rkey_max_count ) {
221
+ return ;
222
+ }
223
+
224
+ if (store -> count >= store -> size ) {
225
+ size = opal_min (opal_max (store -> size , 8 ) * 2 ,
226
+ mca_spml_ucx .symmetric_rkey_max_count );
227
+ tmp = realloc (store -> array , size * sizeof (* store -> array ));
228
+ if (tmp == NULL ) {
229
+ return ;
230
+ }
231
+
232
+ store -> array = tmp ;
233
+ store -> size = size ;
234
+ }
235
+
236
+ memmove (& store -> array [i + 1 ], & store -> array [i ],
237
+ (store -> count - i ) * sizeof (* store -> array ));
238
+ store -> array [i ].rkey = rkey ;
239
+ store -> array [i ].refcnt = 1 ;
240
+ store -> count ++ ;
241
+ return ;
242
+ }
243
+
244
+ /* Takes ownership of input ucp remote key */
245
+ static ucp_rkey_h mca_spml_ucx_rkey_store_get (mca_spml_ucx_rkey_store_t * store ,
246
+ ucp_worker_h worker ,
247
+ ucp_rkey_h rkey )
248
+ {
249
+ int ret , i ;
250
+
251
+ if (mca_spml_ucx .symmetric_rkey_max_count == 0 ) {
252
+ return rkey ;
253
+ }
254
+
255
+ ret = mca_spml_ucx_rkey_store_find (store , worker , rkey , & i );
256
+ if (ret == OSHMEM_SUCCESS ) {
257
+ ucp_rkey_destroy (rkey );
258
+ store -> array [i ].refcnt ++ ;
259
+ return store -> array [i ].rkey ;
260
+ }
261
+
262
+ if (ret == OSHMEM_ERR_NOT_FOUND ) {
263
+ mca_spml_ucx_rkey_store_insert (store , i , rkey );
264
+ }
265
+
266
+ return rkey ;
267
+ }
268
+
269
+ static void mca_spml_ucx_rkey_store_put (mca_spml_ucx_rkey_store_t * store ,
270
+ ucp_worker_h worker ,
271
+ ucp_rkey_h rkey )
272
+ {
273
+ mca_spml_ucx_rkey_t * entry ;
274
+ int ret , i ;
275
+
276
+ ret = mca_spml_ucx_rkey_store_find (store , worker , rkey , & i );
277
+ if (ret != OSHMEM_SUCCESS ) {
278
+ goto out ;
279
+ }
280
+
281
+ entry = & store -> array [i ];
282
+ assert (entry -> rkey == rkey );
283
+ if (-- entry -> refcnt > 0 ) {
284
+ return ;
285
+ }
286
+
287
+ memmove (& store -> array [i ], & store -> array [i + 1 ],
288
+ (store -> count - (i + 1 )) * sizeof (* store -> array ));
289
+ store -> count -- ;
290
+
291
+ out :
292
+ ucp_rkey_destroy (rkey );
293
+ }
294
+
129
295
int mca_spml_ucx_enable (bool enable )
130
296
{
131
297
SPML_UCX_VERBOSE (50 , "*** ucx ENABLED ****" );
@@ -240,6 +406,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
240
406
{
241
407
int rc ;
242
408
ucs_status_t err ;
409
+ ucp_rkey_h rkey ;
243
410
244
411
rc = mca_spml_ucx_ctx_mkey_new (ucx_ctx , pe , segno , ucx_mkey );
245
412
if (OSHMEM_SUCCESS != rc ) {
@@ -248,11 +415,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
248
415
}
249
416
250
417
if (mkey -> u .data ) {
251
- err = ucp_ep_rkey_unpack (ucx_ctx -> ucp_peers [pe ].ucp_conn , mkey -> u .data , & (( * ucx_mkey ) -> rkey ) );
418
+ err = ucp_ep_rkey_unpack (ucx_ctx -> ucp_peers [pe ].ucp_conn , mkey -> u .data , & rkey );
252
419
if (UCS_OK != err ) {
253
420
SPML_UCX_ERROR ("failed to unpack rkey: %s" , ucs_status_string (err ));
254
421
return OSHMEM_ERROR ;
255
422
}
423
+
424
+ if (!oshmem_proc_on_local_node (pe )) {
425
+ rkey = mca_spml_ucx_rkey_store_get (& ucx_ctx -> rkey_store , ucx_ctx -> ucp_worker [0 ], rkey );
426
+ }
427
+
428
+ (* ucx_mkey )-> rkey = rkey ;
429
+
256
430
rc = mca_spml_ucx_ctx_mkey_cache (ucx_ctx , mkey , segno , pe );
257
431
if (OSHMEM_SUCCESS != rc ) {
258
432
SPML_UCX_ERROR ("mca_spml_ucx_ctx_mkey_cache failed" );
@@ -267,7 +441,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
267
441
ucp_peer_t * ucp_peer ;
268
442
int rc ;
269
443
ucp_peer = & (ucx_ctx -> ucp_peers [pe ]);
270
- ucp_rkey_destroy ( ucx_mkey -> rkey );
444
+ mca_spml_ucx_rkey_store_put ( & ucx_ctx -> rkey_store , ucx_ctx -> ucp_worker [ 0 ], ucx_mkey -> rkey );
271
445
ucx_mkey -> rkey = NULL ;
272
446
rc = mca_spml_ucx_peer_mkey_cache_del (ucp_peer , segno );
273
447
if (OSHMEM_SUCCESS != rc ){
@@ -725,7 +899,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
725
899
UCP_MEM_MAP_PARAM_FIELD_FLAGS ;
726
900
mem_map_params .address = addr ;
727
901
mem_map_params .length = size ;
728
- mem_map_params .flags = flags ;
902
+ mem_map_params .flags = flags |
903
+ mca_spml_ucx_mem_map_flags_symmetric_rkey (& mca_spml_ucx );
729
904
730
905
status = ucp_mem_map (mca_spml_ucx .ucp_context , & mem_map_params , & mem_h );
731
906
if (UCS_OK != status ) {
@@ -917,6 +1092,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
917
1092
}
918
1093
}
919
1094
1095
+ mca_spml_ucx_rkey_store_init (& ucx_ctx -> rkey_store );
1096
+
920
1097
* ucx_ctx_p = ucx_ctx ;
921
1098
922
1099
return OSHMEM_SUCCESS ;
0 commit comments