Skip to content

Commit

Permalink
UCP: UCP derived memh skeleton
Browse files Browse the repository at this point in the history
  • Loading branch information
iyastreb committed Jan 9, 2025
1 parent 0196e96 commit 06b6131
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 4 deletions.
128 changes: 128 additions & 0 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -1998,3 +1998,131 @@ ucp_memh_import(ucp_context_h context, const void *export_mkey_buffer,
out:
return status;
}

static UCS_F_ALWAYS_INLINE int ucp_is_invalidate_cap(unsigned uct_flags)
{
return uct_flags & (UCT_MD_MKEY_PACK_FLAG_INVALIDATE_RMA |
UCT_MD_MKEY_PACK_FLAG_INVALIDATE_AMO);
}

static UCS_F_ALWAYS_INLINE int
ucp_memh_is_invalidate_cap(const ucp_mem_h memh, ucp_md_index_t md_idx)
{
return ucp_is_invalidate_cap(memh->context->tl_mds[md_idx].pack_flags_mask);
}

static void ucp_memh_derived_destroy(ucp_mem_h derived)
{
ucp_context_h context = derived->context;
uct_md_mem_dereg_params_t params;
ucp_md_index_t md_index;
ucs_status_t status;

ucs_trace("destroying derived memh=%p", derived);
ucs_for_each_bit(md_index, derived->md_map) {
if (ucp_memh_is_invalidate_cap(derived, md_index)) {
params.memh = derived->uct[md_index];
ucs_trace("de-registering derived memh[%d]=%p", md_index,
derived->uct[md_index]);
status = uct_md_mem_dereg_v2(context->tl_mds[md_index].md, &params);
if (status != UCS_OK) {
ucs_warn("derived memh %p failed to dereg from md[%d]=%s: %s",
derived, md_index, context->tl_mds[md_index].rsc.md_name,
ucs_status_string(status));
}
}
}

ucs_free(derived);
}

static ucp_mem_h ucp_memh_derived_create(ucp_mem_h memh)
{
ucp_mem_h derived;
ucp_md_index_t md_index;
ucs_status_t status;
uct_md_mem_reg_params_t params;

ucs_trace("creating derived memh from %p", memh);
derived = ucs_calloc(1, ucp_memh_size(memh->context), "ucp_memh_derived");
if (derived == NULL) {
ucs_error("failed to allocate memory for derived memh");
return NULL;
}

/* Intentionally copy UCP memh data without UCT memory handles */
memcpy(derived, memh, sizeof(ucp_mem_t));

params.field_mask = UCT_MD_MEM_REG_FIELD_MEMH;

/* Now copy all the UCT memory handles from the original UCP memh */
ucs_for_each_bit(md_index, derived->md_map) {
if (ucp_memh_is_invalidate_cap(memh, md_index)) {
/* Invalidation is supported: create derived UCT memh */
params.memh = &memh->uct[md_index];

ucs_trace("registering derived memh[%d]=%p", md_index, params.memh);
status = uct_md_mem_reg_v2(memh->context->tl_mds[md_index].md, NULL,
0ul, &params, &derived->uct[md_index]);
if (status == UCS_ERR_UNSUPPORTED) {
/* Invalidation is declared but not supported: shallow copy */
ucs_trace("unsupported derived memh[%d]=%p, shallow copy",
md_index, params.memh);
derived->uct[md_index] = memh->uct[md_index];
} else if (status != UCS_OK) {
ucp_memh_derived_destroy(derived);
return NULL;
}
} else {
/* Invalidation is not supported: do shallow copy */
ucs_trace("shallow copy memh[%d]=%p", md_index, params.memh);
derived->uct[md_index] = memh->uct[md_index];
}
}

derived->flags |= UCP_MEMH_FLAG_DERIVED;
derived->parent = memh;
derived->derived = memh->derived;
memh->derived = derived;

ucs_trace("created derived memh=%p from memh=%p", derived, memh);
return derived;
}

static ucp_mem_h
ucp_memh_derived_get(ucp_mem_h memh, int create)
{
if ((memh->derived != NULL) &&
!(memh->derived->flags & UCP_MEMH_FLAG_INVALIDATED)) {
return memh->derived;
}

if (!create) {
ucs_error("no valid derived memory handle for %p", memh);
return NULL;
}

return ucp_memh_derived_create(memh);
}

ucp_mem_h ucp_memh_get_pack_memh(ucp_mem_h memh, ucp_md_map_t md_map,
unsigned uct_flags, int create)
{
int md_invalidate_cap = 0;
ucp_md_index_t md_index;

if (!ucp_is_invalidate_cap(uct_flags)) {
return memh;
}

/* Check if any of the requested MDs supports invalidation */
ucs_for_each_bit(md_index, md_map) {
md_invalidate_cap += ucp_memh_is_invalidate_cap(memh, md_index);
}

if (md_invalidate_cap == 0) {
return memh;
}

return ucp_memh_derived_get(memh, create);
}
20 changes: 18 additions & 2 deletions src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,17 @@ enum {
/**
* Avoid using registration cache for the particular memory region.
*/
UCP_MEMH_FLAG_NO_RCACHE = UCS_BIT(3)
UCP_MEMH_FLAG_NO_RCACHE = UCS_BIT(3),

/**
* TODO
*/
UCP_MEMH_FLAG_DERIVED = UCS_BIT(4),

/**
* TODO
*/
UCP_MEMH_FLAG_INVALIDATED = UCS_BIT(5),
};


Expand Down Expand Up @@ -78,6 +88,7 @@ typedef struct ucp_mem {
- pointer to rcache memh if entry is a user memh
- pointer to self if entry is a user memh
and rcache is disabled */
ucp_mem_h derived; /* TODO */
uint64_t reg_id; /* Registration ID */
uct_mem_h uct[0]; /* Sparse memory handles array num_mds in size */
} ucp_mem_t;
Expand Down Expand Up @@ -206,6 +217,12 @@ void ucp_mem_rcache_cleanup(ucp_context_h context);

void ucp_memh_disable_gva(ucp_mem_h memh, ucp_md_map_t md_map);

/**
* TODO
*/
ucp_mem_h ucp_memh_get_pack_memh(ucp_mem_h memh, ucp_md_map_t md_map,
unsigned uct_flags, int create);

/**
* Get memory domain index that is used to allocate certain memory type.
*
Expand Down Expand Up @@ -257,7 +274,6 @@ static UCS_F_ALWAYS_INLINE size_t ucp_memh_length(const ucp_mem_h memh)
return memh->super.super.end - memh->super.super.start;
}


#define UCP_MEM_IS_HOST(_mem_type) ((_mem_type) == UCS_MEMORY_TYPE_HOST)
#define UCP_MEM_IS_ROCM(_mem_type) ((_mem_type) == UCS_MEMORY_TYPE_ROCM)
#define UCP_MEM_IS_CUDA(_mem_type) ((_mem_type) == UCS_MEMORY_TYPE_CUDA)
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_rkey.c
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ UCS_PROFILE_FUNC(ssize_t, ucp_rkey_pack_memh,
(context, md_map, memh, address, length, mem_info, sys_dev_map,
sys_distance, uct_flags, buffer),
ucp_context_h context, ucp_md_map_t md_map,
const ucp_mem_h memh, void *address, size_t length,
ucp_mem_h memh, void *address, size_t length,
const ucp_memory_info_t *mem_info,
ucp_sys_dev_map_t sys_dev_map,
const ucs_sys_dev_distance_t *sys_distance, unsigned uct_flags,
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_rkey.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ void ucp_rkey_packed_copy(ucp_context_h context, ucp_md_map_t md_map,


ssize_t ucp_rkey_pack_memh(ucp_context_h context, ucp_md_map_t md_map,
const ucp_mem_h memh, void *address, size_t length,
ucp_mem_h memh, void *address, size_t length,
const ucp_memory_info_t *mem_info,
ucp_sys_dev_map_t sys_dev_map,
const ucs_sys_dev_distance_t *sys_distance,
Expand Down

0 comments on commit 06b6131

Please sign in to comment.