Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCT/IB/MD: retry memory registration with reduced access flags on failure #10341

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e47f0ff
UCT/IB/MD: retry memory registration with reduced access flags on failure
amastbaum Dec 1, 2024
0ed3576
UCT/IB/MD: add myself to AUTHORS file
amastbaum Dec 2, 2024
5b5a4ba
UCT/IB/MD: added some changes
amastbaum Dec 2, 2024
ebd6386
UCT/IB/MD: set mem reg flags to read-only in proto rndv rts
amastbaum Dec 2, 2024
e2564b6
UCT/IB/MD: CR fixes
amastbaum Dec 3, 2024
bcf4d11
UCT/IB/MD: CR fixes 2
amastbaum Dec 3, 2024
7f0ad47
UCT/IB/MD: add ucp_mem_map params->prot to uct_flags translation
amastbaum Dec 8, 2024
0be06fc
Merge branch 'master' into support_mem_reg_of_a_read_only_address
amastbaum Dec 10, 2024
547432e
UCT/IB/MD: CR fixes
amastbaum Dec 11, 2024
f45911c
UCT/IB/MD: use only ucp defined flags in ib mem reg
amastbaum Dec 16, 2024
fdaa23a
UCT/IB/MD: revert changes in UCP
amastbaum Dec 16, 2024
a25d39b
Merge branch 'master' into support_mem_reg_of_a_read_only_address
amastbaum Dec 30, 2024
251cd9a
UCT/IB/MD: pass uct_flags to gva mem_reg
amastbaum Jan 1, 2025
2dc3f68
UCT/IB/MD: set uct_flags to ALL when calling ucp_datatype_iter_mem_reg
amastbaum Jan 6, 2025
26e991e
UCT/IB/MD: change mem_reg permissions in rndv_rtr to ALL
amastbaum Jan 6, 2025
bc3cf3c
UCT/IB/MD: change libperf memory registration permissions to ALL
amastbaum Jan 6, 2025
d2b1b60
UCT/IB/MD: add ucp_mem_map params->prot to uct_flags translation
amastbaum Jan 7, 2025
8aa1b4e
UCT/IB/MD: add full permissions to ucp_test mem_map
amastbaum Jan 7, 2025
0ad3eb4
UCT/IB/MD: decoupled atomic permission flags assignment from rma flags
amastbaum Jan 9, 2025
22ec2c2
UCT/IB/MD: added memory protection flags to map_buffer test function
amastbaum Jan 9, 2025
256bd62
UCT/IB/MD: specified memory protection flags in test_ucp_mem_type and…
amastbaum Jan 9, 2025
c6f8cbe
Merge branch 'master' into support_mem_reg_of_a_read_only_address
amastbaum Jan 9, 2025
e3aa7ed
UCT/IB/MD: don't fail in dereg_invalidate_rkey_check if the requested…
amastbaum Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Alex Margolin <[email protected]>
Alex Mikheev <[email protected]>
Alexey Rivkin <[email protected]>
Alina Sklarevich <[email protected]>
Alma Mastbaum <[email protected]>
Anatoly Vildemanov <[email protected]>
Andrey Maslennikov <[email protected]>
Artem Polyakov <[email protected]>
Expand Down
7 changes: 6 additions & 1 deletion src/tools/perf/lib/libperf_memory.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,16 @@ static ucs_status_t ucp_perf_mem_alloc(const ucx_perf_context_t *perf,
params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_FLAGS |
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE |
UCP_MEM_MAP_PARAM_FIELD_PROT;
params.address = NULL;
params.memory_type = mem_type;
params.length = length;
params.flags = UCP_MEM_MAP_ALLOCATE;
params.prot = UCP_MEM_MAP_PROT_LOCAL_READ |
UCP_MEM_MAP_PROT_LOCAL_WRITE |
UCP_MEM_MAP_PROT_REMOTE_READ |
UCP_MEM_MAP_PROT_REMOTE_WRITE;
if (perf->params.flags & UCX_PERF_TEST_FLAG_MAP_NONBLOCK) {
params.flags |= UCP_MEM_MAP_NONBLOCK;
}
Expand Down
7 changes: 6 additions & 1 deletion src/tools/perf/perftest_daemon.c
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,13 @@ ucp_perf_daemon_memh_import(ucp_perf_daemon_context_t *ctx, void *packed_memh)
ucs_status_t status;
ucp_mem_h memh;

params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER |
UCP_MEM_MAP_PARAM_FIELD_PROT;
params.exported_memh_buffer = packed_memh;
params.prot = UCP_MEM_MAP_PROT_LOCAL_READ |
UCP_MEM_MAP_PROT_LOCAL_WRITE |
UCP_MEM_MAP_PROT_REMOTE_READ |
UCP_MEM_MAP_PROT_REMOTE_WRITE;

status = ucp_mem_map(ctx->context, &params, &memh);
if (status != UCS_OK) {
Expand Down
28 changes: 23 additions & 5 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,26 @@ ucp_mem_map_params2uct_flags(const ucp_context_h context,
{
unsigned flags = 0;

if (context->config.features & UCP_FEATURE_RMA) {
flags |= UCT_MD_MEM_ACCESS_RMA;
if (params->field_mask & UCP_MEM_MAP_PARAM_FIELD_PROT) {
if (params->prot & UCP_MEM_MAP_PROT_LOCAL_READ) {
flags |= UCT_MD_MEM_ACCESS_LOCAL_READ;
}

if (params->prot & UCP_MEM_MAP_PROT_REMOTE_READ) {
flags |= UCT_MD_MEM_ACCESS_REMOTE_GET;
}

if (params->prot & UCP_MEM_MAP_PROT_LOCAL_WRITE) {
flags |= UCT_MD_MEM_ACCESS_LOCAL_WRITE;
}

if (params->prot & UCP_MEM_MAP_PROT_REMOTE_WRITE) {
flags |= UCT_MD_MEM_ACCESS_REMOTE_PUT;
}
} else {
if (context->config.features & UCP_FEATURE_RMA) {
flags |= UCT_MD_MEM_ACCESS_RMA;
}
}

if (context->config.features & UCP_FEATURE_AMO) {
Expand Down Expand Up @@ -462,7 +480,7 @@ static void ucp_memh_cleanup(ucp_context_h context, ucp_mem_h memh)
}

static ucs_status_t ucp_memh_register_gva(ucp_context_h context, ucp_mem_h memh,
ucp_md_map_t md_map)
ucp_md_map_t md_map, unsigned uct_flags)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand this is temporary, but, just in case, it would be good to replace gva_mr with a structure that keeps separate registrations for the different access flags.

{
ucp_md_map_t reg_md_map = context->gva_md_map[memh->mem_type] & md_map;
void *address = ucp_memh_address(memh);
Expand All @@ -477,7 +495,7 @@ static ucs_status_t ucp_memh_register_gva(ucp_context_h context, ucp_mem_h memh,
}

params.field_mask = UCT_MD_MEM_REG_FIELD_FLAGS;
params.flags = UCT_MD_MEM_GVA;
params.flags = UCT_MD_MEM_GVA | uct_flags;

if (context->config.ext.gva_mlock &&
!(memh->flags & UCP_MEMH_FLAG_MLOCKED)) {
Expand Down Expand Up @@ -537,7 +555,7 @@ ucp_memh_register_internal(ucp_context_h context, ucp_mem_h memh,
size_t reg_align;

if (gva_enable) {
status = ucp_memh_register_gva(context, memh, md_map);
status = ucp_memh_register_gva(context, memh, md_map, uct_flags);
if ((status != UCS_OK) && !(uct_flags & UCT_MD_MEM_FLAG_HIDE_ERRORS)) {
return status;
}
Expand Down
3 changes: 2 additions & 1 deletion src/ucp/rndv/rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -1730,7 +1730,8 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_receive, (worker, rreq, rndv_rts_hdr, rkey_buf),
ucp_datatype_iter_mem_reg(context, &rreq->recv.dt_iter,
context->reg_md_map[mem_type] &
ep_config->key.rma_bw_md_map,
0, UCS_BIT(UCP_DATATYPE_CONTIG));
UCT_MD_MEM_ACCESS_ALL,
UCS_BIT(UCP_DATATYPE_CONTIG));

rtr_md_map = ep_config->key.rma_bw_md_map;
}
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/rndv/rndv_rtr.c
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ static ucs_status_t ucp_proto_rndv_rtr_progress(uct_pending_req_t *self)
status = ucp_datatype_iter_mem_reg(worker->context,
&req->send.state.dt_iter,
rpriv->super.md_map,
UCT_MD_MEM_ACCESS_REMOTE_PUT |
UCT_MD_MEM_ACCESS_ALL |
UCT_MD_MEM_FLAG_HIDE_ERRORS,
UCP_DT_MASK_ALL);
if (status != UCS_OK) {
Expand Down
1 change: 1 addition & 0 deletions src/ucp/tag/offload.c
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ ucp_tag_offload_do_post(ucp_request_t *req)
/* register the whole buffer to support SW RNDV fallback */
status = ucp_datatype_iter_mem_reg(context, &req->recv.dt_iter,
UCS_BIT(mdi),
UCT_MD_MEM_ACCESS_ALL |
UCT_MD_MEM_FLAG_HIDE_ERRORS,
UCS_BIT(UCP_DATATYPE_CONTIG));
if (status != UCS_OK) {
Expand Down
41 changes: 33 additions & 8 deletions src/uct/ib/base/ib_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -626,9 +626,34 @@ ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
return UCS_OK;
}

/**
 * Translate UCT memory access flags into the matching ibverbs access flags.
 *
 * Only the UCT_MD_MEM_ACCESS_LOCAL_WRITE, UCT_MD_MEM_ACCESS_REMOTE_GET,
 * UCT_MD_MEM_ACCESS_REMOTE_PUT and UCT_MD_MEM_ACCESS_REMOTE_ATOMIC bits are
 * examined; every other bit of @a uct_flags is ignored. In particular,
 * UCT_MD_MEM_ACCESS_LOCAL_READ has no ibverbs counterpart here, since verbs
 * memory regions always allow local read access.
 *
 * @param [in] uct_flags  Bitmask of UCT memory registration flags.
 *
 * @return Bitmask of IBV_ACCESS_xx flags corresponding to @a uct_flags.
 */
static uint64_t uct_ib_flags_to_ibv_mem_access_flags(uint64_t uct_flags)
{
    uint64_t ibv_flags = 0;

    if (uct_flags & UCT_MD_MEM_ACCESS_LOCAL_WRITE) {
        ibv_flags |= IBV_ACCESS_LOCAL_WRITE;
    }

    if (uct_flags & UCT_MD_MEM_ACCESS_REMOTE_GET) {
        ibv_flags |= IBV_ACCESS_REMOTE_READ;
    }

    if (uct_flags & UCT_MD_MEM_ACCESS_REMOTE_PUT) {
        ibv_flags |= IBV_ACCESS_REMOTE_WRITE;
    }

    if (uct_flags & UCT_MD_MEM_ACCESS_REMOTE_ATOMIC) {
        ibv_flags |= IBV_ACCESS_REMOTE_ATOMIC;
    }

    /* ibv_reg_mr() rejects a region that grants remote write or remote atomic
     * access without local write access, so a caller asking only for a remote
     * permission (e.g. UCT_MD_MEM_ACCESS_REMOTE_PUT alone) would otherwise get
     * a registration failure. Add the mandatory local write permission. */
    if (ibv_flags & (IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC)) {
        ibv_flags |= IBV_ACCESS_LOCAL_WRITE;
    }

    return ibv_flags;
}

uint64_t uct_ib_memh_access_flags(uct_ib_mem_t *memh, int relaxed_order,
uint64_t access_flags)
uint64_t access_flags, uint64_t uct_flags)
{
access_flags |= uct_ib_flags_to_ibv_mem_access_flags(uct_flags);

if (memh->flags & UCT_IB_MEM_FLAG_ODP) {
access_flags |= IBV_ACCESS_ON_DEMAND;
}
Expand All @@ -644,27 +669,27 @@ ucs_status_t uct_ib_verbs_mem_reg(uct_md_h uct_md, void *address, size_t length,
const uct_md_mem_reg_params_t *params,
uct_mem_h *memh_p)
{
uct_ib_md_t *md = ucs_derived_of(uct_md, uct_ib_md_t);
uct_ib_md_t *md = ucs_derived_of(uct_md, uct_ib_md_t);
uint64_t uct_flags = UCT_MD_MEM_REG_FIELD_VALUE(params, flags,
FIELD_FLAGS, 0);
struct ibv_mr *mr_default;
uct_ib_verbs_mem_t *memh;
uct_ib_mem_t *ib_memh;
uint64_t access_flags;
ucs_status_t status;

status = uct_ib_memh_alloc(md, length,
UCT_MD_MEM_REG_FIELD_VALUE(params, flags,
FIELD_FLAGS, 0),
status = uct_ib_memh_alloc(md, length, uct_flags,
sizeof(*memh), sizeof(memh->mrs[0]), &ib_memh);
if (status != UCS_OK) {
goto err;
}

memh = ucs_derived_of(ib_memh, uct_ib_verbs_mem_t);
access_flags = uct_ib_memh_access_flags(&memh->super, md->relaxed_order,
md->dev.mr_access_flags);
md->dev.mr_access_flags, uct_flags);

status = uct_ib_reg_mr(md, address, length, params, access_flags, NULL,
&mr_default);
status = uct_ib_reg_mr(md, address, length, params, access_flags,
NULL, &mr_default);
if (status != UCS_OK) {
goto err_free;
}
Expand Down
2 changes: 1 addition & 1 deletion src/uct/ib/base/ib_md.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ uct_ib_md_handle_mr_list_mt(uct_ib_md_t *md, void *address, size_t length,
struct ibv_mr **mrs);

uint64_t uct_ib_memh_access_flags(uct_ib_mem_t *memh, int relaxed_order,
uint64_t access_flags);
uint64_t access_flags, uint64_t uct_flags);

ucs_status_t uct_ib_verbs_mem_reg(uct_md_h uct_md, void *address, size_t length,
const uct_md_mem_reg_params_t *params,
Expand Down
13 changes: 7 additions & 6 deletions src/uct/ib/mlx5/dv/ib_mlx5dv_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -703,12 +703,12 @@ uct_ib_mlx5_devx_reg_mr(uct_ib_mlx5_md_t *md, uct_ib_mlx5_devx_mem_t *memh,
uct_ib_mr_type_t mr_type, uint64_t access_mask,
uint32_t *lkey_p, uint32_t *rkey_p)
{
uint64_t access_flags =
uct_ib_memh_access_flags(&memh->super, md->super.relaxed_order,
md->super.dev.mr_access_flags) &
access_mask;
unsigned flags = UCT_MD_MEM_REG_FIELD_VALUE(params, flags,
FIELD_FLAGS, 0);
uint64_t access_flags =
uct_ib_memh_access_flags(&memh->super, md->super.relaxed_order,
md->super.dev.mr_access_flags, flags) &
access_mask;
ucs_status_t status;
uint32_t mkey;

Expand Down Expand Up @@ -807,7 +807,8 @@ uct_ib_mlx5_devx_mem_reg_gva(uct_md_h uct_md, unsigned flags, uct_mem_h *memh_p)

relaxed_order = md->flags & UCT_IB_MLX5_MD_FLAG_GVA_RO;
access_flags = uct_ib_memh_access_flags(&memh->super, relaxed_order,
md->super.dev.mr_access_flags);
md->super.dev.mr_access_flags,
flags);
status = uct_ib_reg_mr(&md->super, NULL, SIZE_MAX, &params, access_flags,
NULL, &memh->mrs[UCT_IB_MR_DEFAULT].super.ib);
if (status != UCS_OK) {
Expand Down Expand Up @@ -908,7 +909,7 @@ uct_ib_devx_dereg_invalidate_rkey_check(uct_ib_mlx5_md_t *md,
if (!(md->super.cap_flags & cap_mask)) {
ucs_debug("%s: invalidate %s is not supported (rkey=0x%x)",
uct_ib_device_name(&md->super.dev), name, rkey);
return UCS_ERR_UNSUPPORTED;
return UCS_OK;
}

if (rkey == UCT_IB_INVALID_MKEY) {
Expand Down
7 changes: 6 additions & 1 deletion test/apps/iodemo/ucx_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,9 +761,14 @@ bool UcxContext::map_buffer(size_t length, void *address, ucp_mem_h *memh_p)
ucp_mem_map_params_t mem_map_params;

mem_map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH;
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_PROT;
mem_map_params.address = address;
mem_map_params.length = length;
mem_map_params.prot = UCP_MEM_MAP_PROT_LOCAL_READ |
UCP_MEM_MAP_PROT_LOCAL_WRITE |
UCP_MEM_MAP_PROT_REMOTE_READ |
UCP_MEM_MAP_PROT_REMOTE_WRITE;

return ucp_mem_map(_context, &mem_map_params, memh_p) == UCS_OK;
}
Expand Down
4 changes: 2 additions & 2 deletions test/gtest/ucp/test_ucp_dt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ class test_ucp_dt_iter : public ucs::test_with_param<ucp_datatype_t> {

ucp_md_map_t md_map = m_ucph->reg_md_map[UCS_MEMORY_TYPE_HOST] &
m_ucph->cache_md_map[UCS_MEMORY_TYPE_HOST];
status = ucp_datatype_iter_mem_reg(m_ucph, &m_dt_iter, md_map, 0,
UINT_MAX);
status = ucp_datatype_iter_mem_reg(m_ucph, &m_dt_iter, md_map,
UCT_MD_MEM_ACCESS_ALL, UINT_MAX);
ASSERT_UCS_OK(status);

UCS_STRING_BUFFER_ONSTACK(strb, 64);
Expand Down
7 changes: 6 additions & 1 deletion test/gtest/ucp/test_ucp_mem_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,14 @@ UCS_TEST_P(test_ucp_cuda, sparse_regions) {

ucp_mem_map_params_t params;
params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH;
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_PROT;
params.address = ptr[i];
params.length = size;
params.prot = UCP_MEM_MAP_PROT_LOCAL_READ |
UCP_MEM_MAP_PROT_LOCAL_WRITE |
UCP_MEM_MAP_PROT_REMOTE_READ |
UCP_MEM_MAP_PROT_REMOTE_WRITE;

status = ucp_mem_map(sender().ucph(), &params, &memh[i]);
ASSERT_UCS_OK(status);
Expand Down
7 changes: 6 additions & 1 deletion test/gtest/ucp/test_ucp_memheap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ void test_ucp_memheap::test_xfer(send_func_t send_func, size_t size,
/* Allocate heap */
params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
UCP_MEM_MAP_PARAM_FIELD_FLAGS |
UCP_MEM_MAP_PARAM_FIELD_PROT;
params.address = memheap.ptr();
params.length = memheap.size();
params.flags = mem_map_flags;
params.prot = UCP_MEM_MAP_PROT_LOCAL_WRITE |
UCP_MEM_MAP_PROT_REMOTE_WRITE;

status = ucp_mem_map(receiver().ucph(), &params, &memheap_memh);
ASSERT_UCS_OK(status);
Expand All @@ -71,6 +74,8 @@ void test_ucp_memheap::test_xfer(send_func_t send_func, size_t size,
if (user_memh) {
params.address = expected_data.ptr();
params.length = expected_data.size();
params.prot = UCP_MEM_MAP_PROT_LOCAL_READ |
UCP_MEM_MAP_PROT_REMOTE_READ;
status = ucp_mem_map(sender().ucph(), &params, &send_memh);
ASSERT_UCS_OK(status);
}
Expand Down
14 changes: 12 additions & 2 deletions test/gtest/ucp/ucp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1054,9 +1054,14 @@ ucp_mem_h ucp_test_base::entity::mem_map(void *address, size_t length)
ucp_mem_map_params_t params;

params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH;
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_PROT;
params.address = address;
params.length = length;
params.prot = UCP_MEM_MAP_PROT_LOCAL_READ |
UCP_MEM_MAP_PROT_LOCAL_WRITE |
UCP_MEM_MAP_PROT_REMOTE_READ |
UCP_MEM_MAP_PROT_REMOTE_WRITE;

ucp_mem_h memh;
ucs_status_t status = ucp_mem_map(ucph(), &params, &memh);
Expand Down Expand Up @@ -1252,10 +1257,15 @@ ucp_test::mapped_buffer::mapped_buffer(size_t size, const entity& entity,
ucp_mem_map_params_t params;
params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
UCP_MEM_MAP_PARAM_FIELD_FLAGS |
UCP_MEM_MAP_PARAM_FIELD_PROT;
params.flags = flags;
params.address = ptr();
params.length = size;
params.prot = UCP_MEM_MAP_PROT_LOCAL_READ |
UCP_MEM_MAP_PROT_LOCAL_WRITE |
UCP_MEM_MAP_PROT_REMOTE_READ |
UCP_MEM_MAP_PROT_REMOTE_WRITE;

status = ucp_mem_map(m_entity.ucph(), &params, &m_memh);
ASSERT_UCS_OK(status);
Expand Down
Loading