diff --git a/src/ucp/core/ucp_mm.c b/src/ucp/core/ucp_mm.c
index f132e57d164..6cbf1a0b987 100644
--- a/src/ucp/core/ucp_mm.c
+++ b/src/ucp/core/ucp_mm.c
@@ -1998,3 +1998,170 @@ ucp_memh_import(ucp_context_h context, const void *export_mkey_buffer,
 out:
     return status;
 }
+
+/* Nonzero if @a uct_flags contains an INVALIDATE_RMA or INVALIDATE_AMO flag */
+static UCS_F_ALWAYS_INLINE int ucp_is_invalidate_cap(unsigned uct_flags)
+{
+    return uct_flags & (UCT_MD_MKEY_PACK_FLAG_INVALIDATE_RMA |
+                        UCT_MD_MKEY_PACK_FLAG_INVALIDATE_AMO);
+}
+
+/* Nonzero if the pack flags mask of MD @a md_idx has an invalidation flag */
+static UCS_F_ALWAYS_INLINE int
+ucp_memh_is_invalidate_cap(const ucp_mem_h memh, ucp_md_index_t md_idx)
+{
+    return ucp_is_invalidate_cap(memh->context->tl_mds[md_idx].pack_flags_mask);
+}
+
+/*
+ * Release a derived memh: de-register the UCT handles it owns and free it.
+ * Entries which were never registered, or which are shallow copies of the
+ * parent's UCT handles, are left untouched since the derived memh does not
+ * own them. Expects derived->parent to be the memh it was derived from.
+ */
+static void ucp_memh_derived_destroy(ucp_mem_h derived)
+{
+    ucp_context_h context = derived->context;
+    ucp_mem_h parent      = derived->parent;
+    uct_md_mem_dereg_params_t params;
+    ucp_md_index_t md_index;
+    ucs_status_t status;
+
+    ucs_trace("destroying derived memh=%p", derived);
+
+    /* Only the 'memh' field of the de-registration parameters is filled */
+    params.field_mask = UCT_MD_MEM_DEREG_FIELD_MEMH;
+
+    ucs_for_each_bit(md_index, derived->md_map) {
+        if (!ucp_memh_is_invalidate_cap(derived, md_index) ||
+            (derived->uct[md_index] == NULL) ||
+            (derived->uct[md_index] == parent->uct[md_index])) {
+            continue;
+        }
+
+        params.memh = derived->uct[md_index];
+        ucs_trace("de-registering derived memh[%d]=%p", md_index,
+                  derived->uct[md_index]);
+        status = uct_md_mem_dereg_v2(context->tl_mds[md_index].md, &params);
+        if (status != UCS_OK) {
+            ucs_warn("derived memh %p failed to dereg from md[%d]=%s: %s",
+                     derived, md_index, context->tl_mds[md_index].rsc.md_name,
+                     ucs_status_string(status));
+        }
+    }
+
+    ucs_free(derived);
+}
+
+/*
+ * Create a derived memh of @a memh and put it at the head of the list of
+ * derived handles of @a memh. For every MD in the md_map whose pack flags
+ * advertise invalidation, a new UCT memory handle is registered from the
+ * parent's one. For the other MDs, or if the MD returns UCS_ERR_UNSUPPORTED,
+ * the parent's UCT handle is shared as a shallow copy.
+ *
+ * @return The derived memh, or NULL on allocation or registration failure.
+ */
+static ucp_mem_h ucp_memh_derived_create(ucp_mem_h memh)
+{
+    ucp_context_h context = memh->context;
+    uct_md_mem_reg_params_t params;
+    ucp_md_index_t md_index;
+    ucs_status_t status;
+    ucp_mem_h derived;
+
+    ucs_trace("creating derived memh from %p", memh);
+    derived = ucs_calloc(1, ucp_memh_size(context), "ucp_memh_derived");
+    if (derived == NULL) {
+        ucs_error("failed to allocate memory for derived memh");
+        return NULL;
+    }
+
+    /* Intentionally copy UCP memh data without UCT memory handles */
+    memcpy(derived, memh, sizeof(ucp_mem_t));
+
+    /* Set the parent before registering anything, so that a partially
+     * initialized derived memh can be destroyed on the error path without
+     * touching the UCT handles which are shared with the parent */
+    derived->flags |= UCP_MEMH_FLAG_DERIVED;
+    derived->parent = memh;
+
+    params.field_mask = UCT_MD_MEM_REG_FIELD_MEMH;
+
+    /* Now fill the UCT memory handles from the original UCP memh */
+    ucs_for_each_bit(md_index, derived->md_map) {
+        if (!ucp_memh_is_invalidate_cap(memh, md_index)) {
+            /* Invalidation is not supported: do shallow copy */
+            ucs_trace("shallow copy memh[%d]=%p", md_index,
+                      memh->uct[md_index]);
+            derived->uct[md_index] = memh->uct[md_index];
+            continue;
+        }
+
+        /* Invalidation is supported: create derived UCT memh */
+        params.memh = &memh->uct[md_index];
+
+        ucs_trace("registering derived memh[%d]=%p", md_index, params.memh);
+        status = uct_md_mem_reg_v2(context->tl_mds[md_index].md, NULL, 0ul,
+                                   &params, &derived->uct[md_index]);
+        if (status == UCS_ERR_UNSUPPORTED) {
+            /* Invalidation is declared but not supported: shallow copy */
+            ucs_trace("unsupported derived memh[%d]=%p, shallow copy",
+                      md_index, params.memh);
+            derived->uct[md_index] = memh->uct[md_index];
+        } else if (status != UCS_OK) {
+            ucs_error("failed to register derived memh[%d] of memh %p: %s",
+                      md_index, memh, ucs_status_string(status));
+            /* Do not rely on the failed call leaving the output untouched */
+            derived->uct[md_index] = NULL;
+            ucp_memh_derived_destroy(derived);
+            return NULL;
+        }
+    }
+
+    derived->derived = memh->derived;
+    memh->derived    = derived;
+
+    ucs_trace("created derived memh=%p from memh=%p", derived, memh);
+    return derived;
+}
+
+/*
+ * Return the most recent derived memh of @a memh if it exists and is not
+ * flagged as invalidated. Otherwise create a new one if @a create is nonzero,
+ * or report an error and return NULL if @a create is zero.
+ */
+static ucp_mem_h ucp_memh_derived_get(ucp_mem_h memh, int create)
+{
+    ucp_mem_h derived = memh->derived;
+
+    if ((derived != NULL) && !(derived->flags & UCP_MEMH_FLAG_INVALIDATED)) {
+        return derived;
+    }
+
+    if (!create) {
+        ucs_error("no valid derived memory handle for %p", memh);
+        return NULL;
+    }
+
+    return ucp_memh_derived_create(memh);
+}
+
+ucp_mem_h ucp_memh_get_pack_memh(ucp_mem_h memh, ucp_md_map_t md_map,
+                                 unsigned uct_flags, int create)
+{
+    ucp_md_index_t md_index;
+
+    if (!ucp_is_invalidate_cap(uct_flags)) {
+        return memh;
+    }
+
+    /* Derive only if any of the requested MDs supports invalidation */
+    ucs_for_each_bit(md_index, md_map) {
+        if (ucp_memh_is_invalidate_cap(memh, md_index)) {
+            return ucp_memh_derived_get(memh, create);
+        }
+    }
+
+    return memh;
+}
diff --git a/src/ucp/core/ucp_mm.h b/src/ucp/core/ucp_mm.h
index 8743e5342c4..69fea7b6b6e 100644
--- a/src/ucp/core/ucp_mm.h
+++ b/src/ucp/core/ucp_mm.h
@@ -43,7 +43,17 @@ enum {
     /**
      * Avoid using registration cache for the particular memory region.
      */
-    UCP_MEMH_FLAG_NO_RCACHE = UCS_BIT(3)
+    UCP_MEMH_FLAG_NO_RCACHE = UCS_BIT(3),
+
+    /**
+     * Memory handle is a derived copy created from another memory handle.
+     */
+    UCP_MEMH_FLAG_DERIVED = UCS_BIT(4),
+
+    /**
+     * Derived memory handle is invalidated and is not reused for packing.
+     */
+    UCP_MEMH_FLAG_INVALIDATED = UCS_BIT(5),
 };
 
 
@@ -78,6 +88,7 @@ typedef struct ucp_mem {
                                          - pointer to rcache memh if entry is a user memh
                                          - pointer to self if entry is a user memh
                                            and rcache is disabled */
+    ucp_mem_h           derived;      /* Latest derived memh, list head */
     uint64_t            reg_id;       /* Registration ID */
     uct_mem_h           uct[0];       /* Sparse memory handles array num_mds in size */
 } ucp_mem_t;
@@ -206,6 +217,23 @@ void ucp_mem_rcache_cleanup(ucp_context_h context);
 
 void ucp_memh_disable_gva(ucp_mem_h memh, ucp_md_map_t md_map);
 
+/**
+ * Get the memory handle to pack from: the handle itself or its derived one.
+ *
+ * @param [in] memh       Original memory handle.
+ * @param [in] md_map     Map of memory domains checked for invalidation support.
+ * @param [in] uct_flags  Pack flags, checked for
+ *                        UCT_MD_MKEY_PACK_FLAG_INVALIDATE_RMA/AMO.
+ * @param [in] create     If nonzero, create a derived memory handle when no
+ *                        valid one exists.
+ *
+ * @return @a memh itself if @a uct_flags have no invalidation flag or none of
+ *         the MDs in @a md_map supports invalidation. Otherwise a derived
+ *         memory handle of @a memh which is not invalidated, or NULL if there
+ *         is none and @a create is zero, or if creating it failed.
+ */
+ucp_mem_h ucp_memh_get_pack_memh(ucp_mem_h memh, ucp_md_map_t md_map,
+                                 unsigned uct_flags, int create);
+
 /**
  * Get memory domain index that is used to allocate certain memory type.
  *
@@ -257,7 +285,6 @@ static UCS_F_ALWAYS_INLINE size_t ucp_memh_length(const ucp_mem_h memh)
     return memh->super.super.end - memh->super.super.start;
 }
 
-
 #define UCP_MEM_IS_HOST(_mem_type) ((_mem_type) == UCS_MEMORY_TYPE_HOST)
 #define UCP_MEM_IS_ROCM(_mem_type) ((_mem_type) == UCS_MEMORY_TYPE_ROCM)
 #define UCP_MEM_IS_CUDA(_mem_type) ((_mem_type) == UCS_MEMORY_TYPE_CUDA)
diff --git a/src/ucp/core/ucp_rkey.c b/src/ucp/core/ucp_rkey.c
index a2928417d54..8401d5d07e0 100644
--- a/src/ucp/core/ucp_rkey.c
+++ b/src/ucp/core/ucp_rkey.c
@@ -129,7 +129,7 @@ UCS_PROFILE_FUNC(ssize_t, ucp_rkey_pack_memh,
                  (context, md_map, memh, address, length, mem_info, sys_dev_map,
                   sys_distance, uct_flags, buffer),
                  ucp_context_h context, ucp_md_map_t md_map,
-                 const ucp_mem_h memh, void *address, size_t length,
+                 ucp_mem_h memh, void *address, size_t length,
                  const ucp_memory_info_t *mem_info,
                  ucp_sys_dev_map_t sys_dev_map,
                  const ucs_sys_dev_distance_t *sys_distance, unsigned uct_flags,
diff --git a/src/ucp/core/ucp_rkey.h b/src/ucp/core/ucp_rkey.h
index b1816df7a62..f4b09b121f7 100644
--- a/src/ucp/core/ucp_rkey.h
+++ b/src/ucp/core/ucp_rkey.h
@@ -203,7 +203,7 @@ void ucp_rkey_packed_copy(ucp_context_h context, ucp_md_map_t md_map,
 
 
 ssize_t ucp_rkey_pack_memh(ucp_context_h context, ucp_md_map_t md_map,
-                           const ucp_mem_h memh, void *address, size_t length,
+                           ucp_mem_h memh, void *address, size_t length,
                            const ucp_memory_info_t *mem_info,
                            ucp_sys_dev_map_t sys_dev_map,
                            const ucs_sys_dev_distance_t *sys_distance,