Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCP/RKEY: Acquire context lock when calling ucp_rkey_pack_memh #10462

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ucp/core/ucp_rkey.c
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ ucp_rkey_unpack_distance(const ucp_rkey_packed_distance_t *packed_distance,
distance->bandwidth = UCS_FP8_UNPACK(BANDWIDTH, packed_distance->bandwidth);
}

/* context->mt_lock must be held */
UCS_PROFILE_FUNC(ssize_t, ucp_rkey_pack_memh,
(context, md_map, memh, address, length, mem_info, sys_dev_map,
sys_distance, uct_flags, buffer),
Expand Down
21 changes: 12 additions & 9 deletions src/ucp/proto/proto_common.inl
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ ucp_proto_request_set_stage(ucp_request_t *req, uint8_t proto_stage)
{
const ucp_proto_t *proto = req->send.proto_config->proto;

ucs_assertv(proto_stage < UCP_PROTO_STAGE_LAST, "stage=%"PRIu8,
ucs_assertv(proto_stage < UCP_PROTO_STAGE_LAST, "stage=%" PRIu8,
proto_stage);
ucs_assert(proto->progress[proto_stage] != NULL);

ucp_trace_req(req, "set to stage %u, progress function '%s'", proto_stage,
ucs_debug_get_symbol_name(proto->progress[proto_stage]));
ucs_debug_get_symbol_name((void *)proto->progress[proto_stage]));
req->send.proto_stage = proto_stage;

/* Set pointer to progress function */
Expand All @@ -186,7 +186,7 @@ static void ucp_proto_request_set_proto(ucp_request_t *req,
const ucp_proto_config_t *proto_config,
size_t msg_length)
{
ucs_assertv(req->flags & UCP_REQUEST_FLAG_PROTO_SEND, "flags=0x%"PRIx32,
ucs_assertv(req->flags & UCP_REQUEST_FLAG_PROTO_SEND, "flags=0x%" PRIx32,
req->flags);

req->send.proto_config = proto_config;
Expand Down Expand Up @@ -346,6 +346,7 @@ ucp_proto_request_pack_rkey(ucp_request_t *req, ucp_md_map_t md_map,
const ucs_sys_dev_distance_t *dev_distance,
void *rkey_buffer)
{
ucp_context_h context = req->send.ep->worker->context;
const ucp_datatype_iter_t *dt_iter = &req->send.state.dt_iter;
ucp_mem_h memh;
ssize_t packed_rkey_size;
Expand All @@ -366,17 +367,19 @@ ucp_proto_request_pack_rkey(ucp_request_t *req, ucp_md_map_t md_map,
ucs_unlikely(memh->flags & UCP_MEMH_FLAG_HAS_AUTO_GVA)) {
ucp_memh_disable_gva(memh, md_map);
}

if (!ucs_test_all_flags(memh->md_map, md_map)) {
ucs_trace("dt_iter_md_map=0x%"PRIx64" md_map=0x%"PRIx64, memh->md_map,
md_map);
ucs_trace("dt_iter_md_map=0x%" PRIx64 " md_map=0x%" PRIx64,
memh->md_map, md_map);
}

/* TODO: context lock is not scalable. Consider fine-grained lock per memh,
* immutable memh with rkey cache, RCU/COW */
UCP_THREAD_CS_ENTER(&context->mt_lock);
packed_rkey_size = ucp_rkey_pack_memh(
req->send.ep->worker->context, md_map & memh->md_map, memh,
dt_iter->type.contig.buffer, dt_iter->length, &dt_iter->mem_info,
distance_dev_map, dev_distance,
context, md_map & memh->md_map, memh, dt_iter->type.contig.buffer,
dt_iter->length, &dt_iter->mem_info, distance_dev_map, dev_distance,
ucp_ep_config(req->send.ep)->uct_rkey_pack_flags, rkey_buffer);
UCP_THREAD_CS_EXIT(&context->mt_lock);

if (packed_rkey_size < 0) {
ucs_error("failed to pack remote key: %s",
Expand Down
12 changes: 10 additions & 2 deletions src/ucp/rndv/rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,15 @@ size_t ucp_rndv_rts_pack(ucp_request_t *sreq, ucp_rndv_rts_hdr_t *rndv_rts_hdr,
rndv_rts_hdr->address = (uintptr_t)sreq->send.buffer;
rkey_buf = UCS_PTR_BYTE_OFFSET(rndv_rts_hdr,
sizeof(*rndv_rts_hdr));
packed_rkey_size = ucp_rkey_pack_memh(

UCP_THREAD_CS_ENTER(&worker->context->mt_lock);
packed_rkey_size = ucp_rkey_pack_memh(
worker->context, sreq->send.rndv.md_map,
sreq->send.state.dt.dt.contig.memh, sreq->send.buffer,
sreq->send.length, &mem_info, 0, NULL,
ucp_ep_config(sreq->send.ep)->uct_rkey_pack_flags, rkey_buf);
UCP_THREAD_CS_EXIT(&worker->context->mt_lock);

if (packed_rkey_size < 0) {
ucs_fatal("failed to pack rendezvous remote key: %s",
ucs_status_string((ucs_status_t)packed_rkey_size));
Expand All @@ -205,6 +209,7 @@ static size_t ucp_rndv_rtr_pack(void *dest, void *arg)
ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = dest;
ucp_request_t *rreq = ucp_request_get_super(rndv_req);
ucp_ep_h ep = rndv_req->send.ep;
ucp_context_h context = ep->worker->context;
ucp_memory_info_t mem_info;
ssize_t packed_rkey_size;

Expand All @@ -221,12 +226,15 @@ static size_t ucp_rndv_rtr_pack(void *dest, void *arg)
mem_info.type = rreq->recv.dt_iter.mem_info.type;
mem_info.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;

UCP_THREAD_CS_ENTER(&context->mt_lock);
packed_rkey_size = ucp_rkey_pack_memh(
ep->worker->context, rndv_req->send.rndv.md_map,
context, rndv_req->send.rndv.md_map,
rreq->recv.dt_iter.type.contig.memh,
rreq->recv.dt_iter.type.contig.buffer, rndv_req->send.length,
&mem_info, 0, NULL, ucp_ep_config(ep)->uct_rkey_pack_flags,
rndv_rtr_hdr + 1);
UCP_THREAD_CS_EXIT(&context->mt_lock);

if (packed_rkey_size < 0) {
return packed_rkey_size;
}
Expand Down
9 changes: 7 additions & 2 deletions src/ucp/rndv/rndv_rtr.c
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ static size_t ucp_proto_rndv_rtr_mtype_pack(void *dest, void *arg)
{
ucp_rndv_rtr_hdr_t *rtr = dest;
ucp_request_t *req = arg;
ucp_context_h context = req->send.ep->worker->context;
const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv;
ucp_md_map_t md_map = rpriv->super.md_map;
ucp_mem_desc_t *mdesc = req->send.rndv.mdesc;
Expand All @@ -266,10 +267,14 @@ static size_t ucp_proto_rndv_rtr_mtype_pack(void *dest, void *arg)
/* Pack remote key for the fragment */
mem_info.type = mdesc->memh->mem_type;
mem_info.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
packed_rkey_size = ucp_rkey_pack_memh(req->send.ep->worker->context, md_map,
mdesc->memh, mdesc->ptr,

UCP_THREAD_CS_ENTER(&context->mt_lock);
packed_rkey_size = ucp_rkey_pack_memh(context, md_map, mdesc->memh,
mdesc->ptr,
req->send.state.dt_iter.length,
&mem_info, 0, NULL, 0, rtr + 1);
UCP_THREAD_CS_EXIT(&context->mt_lock);

if (packed_rkey_size < 0) {
ucs_error("failed to pack remote key: %s",
ucs_status_string((ucs_status_t)packed_rkey_size));
Expand Down
64 changes: 52 additions & 12 deletions test/gtest/ucp/test_ucp_rma_mt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

#include <common/test_helpers.h>

extern "C" {
#include <ucp/proto/proto_common.inl>
}

#if _OPENMP
#include "omp.h"
#endif
Expand All @@ -35,6 +39,22 @@ class test_ucp_rma_mt : public ucp_test {
add_variant(variants, UCP_FEATURE_RMA, MULTI_THREAD_CONTEXT);
add_variant(variants, UCP_FEATURE_RMA, MULTI_THREAD_WORKER);
}

/*
 * Map the user buffer [data, data + size) on the given context and return
 * the resulting memory handle. The mapping flags are taken from the current
 * test variant value. A failing ucp_mem_map() is reported via ASSERT_UCS_OK.
 */
ucp_mem_h mem_map(ucp_context_h context, void *data, size_t size)
{
    ucp_mem_h mapped_memh;
    ucp_mem_map_params_t map_params;

    /* Describe the region first, then declare which fields are valid */
    map_params.address    = data;
    map_params.length     = size;
    map_params.flags      = get_variant_value();
    map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_FLAGS |
                            UCP_MEM_MAP_PARAM_FIELD_LENGTH |
                            UCP_MEM_MAP_PARAM_FIELD_ADDRESS;

    ASSERT_UCS_OK(ucp_mem_map(context, &map_params, &mapped_memh));
    return mapped_memh;
}
};

UCS_TEST_P(test_ucp_rma_mt, put_get) {
Expand All @@ -43,19 +63,9 @@ UCS_TEST_P(test_ucp_rma_mt, put_get) {
uint64_t orig_data[num_threads] GTEST_ATTRIBUTE_UNUSED_;
uint64_t target_data[num_threads] GTEST_ATTRIBUTE_UNUSED_;

ucp_mem_map_params_t params;
ucp_mem_h memh;
void *memheap = target_data;

params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
params.address = memheap;
params.length = sizeof(uint64_t) * num_threads;
params.flags = get_variant_value();

st = ucp_mem_map(receiver().ucph(), &params, &memh);
ASSERT_UCS_OK(st);
ucp_mem_h memh = mem_map(receiver().ucph(), memheap,
sizeof(uint64_t) * num_threads);

void *rkey_buffer;
size_t rkey_buffer_size;
Expand Down Expand Up @@ -200,4 +210,34 @@ UCS_TEST_P(test_ucp_rma_mt, put_get) {
ASSERT_UCS_OK(st);
}

/*
 * Packs remote keys for the same memory handle from an OpenMP parallel loop,
 * alternating between two packing paths:
 *  - even iterations use the public ucp_rkey_pack() / ucp_rkey_buffer_release()
 *  - odd iterations call ucp_proto_request_pack_rkey() on a hand-built request
 * The results of the packing are not inspected; only ucp_rkey_pack() and
 * ucp_mem_unmap() statuses are asserted.
 * NOTE(review): presumably meant to expose races between the two paths when
 * the context lock is missing - confirm against the PR description.
 */
UCS_TEST_P(test_ucp_rma_mt, rkey_pack) {
/* Stack buffer registered on the sender's context for the whole test */
uint8_t data[1024] GTEST_ATTRIBUTE_UNUSED_;
ucp_context_h context = sender().ucph();
ucp_mem_h memh = mem_map(context, data, sizeof(data));

/* Without both OpenMP and ENABLE_MT the test only maps and unmaps the buffer */
#if _OPENMP && ENABLE_MT
#pragma omp parallel for
for (int i = 0; i < mt_num_threads(); i++) {
if (i % 2 == 0) {
/* Public API path: pack into a library-allocated buffer and release it */
void *rkey;
size_t rkey_size;
ASSERT_UCS_OK(ucp_rkey_pack(context, memh, &rkey, &rkey_size));
ucp_rkey_buffer_release(rkey);
} else {
/* Internal proto path: zero-initialized request with only the fields set
 * below filled in (endpoint, memh, buffer and length) */
ucs_sys_dev_distance_t sys_dev = {};
ucp_request req = {};
req.send.ep = sender().ep();
req.send.state.dt_iter.type.contig.memh = memh;
req.send.state.dt_iter.type.contig.buffer = data;
req.send.state.dt_iter.length = sizeof(data);

/* Per-iteration output buffer; the return value of the call is ignored.
 * NOTE(review): assumes 1024 bytes is enough for the packed rkey - confirm */
uint8_t rkey[1024];
ucp_proto_request_pack_rkey(&req, memh->md_map, 0, &sys_dev, rkey);
}
}
#endif

ASSERT_UCS_OK(ucp_mem_unmap(context, memh));
}

UCP_INSTANTIATE_TEST_CASE(test_ucp_rma_mt)
Loading