From 8c692ca724976b2ba5153b34dc92bc760f7e1a2f Mon Sep 17 00:00:00 2001 From: Ilia Yastrebov Date: Thu, 30 Jan 2025 13:27:22 +0000 Subject: [PATCH 1/2] UCP/RKEY: Acquire context lock when calling ucp_rkey_pack_memh --- src/ucp/core/ucp_rkey.c | 1 + src/ucp/proto/proto_common.inl | 21 +++++----- src/ucp/rndv/rndv.c | 12 +++++- src/ucp/rndv/rndv_rtr.c | 9 ++++- test/gtest/ucp/test_ucp_rma_mt.cc | 64 +++++++++++++++++++++++++------ 5 files changed, 82 insertions(+), 25 deletions(-) diff --git a/src/ucp/core/ucp_rkey.c b/src/ucp/core/ucp_rkey.c index a2928417d54..0efa24ce027 100644 --- a/src/ucp/core/ucp_rkey.c +++ b/src/ucp/core/ucp_rkey.c @@ -125,6 +125,7 @@ ucp_rkey_unpack_distance(const ucp_rkey_packed_distance_t *packed_distance, distance->bandwidth = UCS_FP8_UNPACK(BANDWIDTH, packed_distance->bandwidth); } +/* context->mt_lock must be held */ UCS_PROFILE_FUNC(ssize_t, ucp_rkey_pack_memh, (context, md_map, memh, address, length, mem_info, sys_dev_map, sys_distance, uct_flags, buffer), diff --git a/src/ucp/proto/proto_common.inl b/src/ucp/proto/proto_common.inl index 00aa629fa1b..322acd94def 100644 --- a/src/ucp/proto/proto_common.inl +++ b/src/ucp/proto/proto_common.inl @@ -165,12 +165,12 @@ ucp_proto_request_set_stage(ucp_request_t *req, uint8_t proto_stage) { const ucp_proto_t *proto = req->send.proto_config->proto; - ucs_assertv(proto_stage < UCP_PROTO_STAGE_LAST, "stage=%"PRIu8, + ucs_assertv(proto_stage < UCP_PROTO_STAGE_LAST, "stage=%" PRIu8, proto_stage); ucs_assert(proto->progress[proto_stage] != NULL); ucp_trace_req(req, "set to stage %u, progress function '%s'", proto_stage, - ucs_debug_get_symbol_name(proto->progress[proto_stage])); + ucs_debug_get_symbol_name((void *)proto->progress[proto_stage])); req->send.proto_stage = proto_stage; /* Set pointer to progress function */ @@ -186,7 +186,7 @@ static void ucp_proto_request_set_proto(ucp_request_t *req, const ucp_proto_config_t *proto_config, size_t msg_length) { - ucs_assertv(req->flags & 
UCP_REQUEST_FLAG_PROTO_SEND, "flags=0x%"PRIx32, + ucs_assertv(req->flags & UCP_REQUEST_FLAG_PROTO_SEND, "flags=0x%" PRIx32, req->flags); req->send.proto_config = proto_config; @@ -346,6 +346,7 @@ ucp_proto_request_pack_rkey(ucp_request_t *req, ucp_md_map_t md_map, const ucs_sys_dev_distance_t *dev_distance, void *rkey_buffer) { + ucp_context_h context = req->send.ep->worker->context; const ucp_datatype_iter_t *dt_iter = &req->send.state.dt_iter; ucp_mem_h memh; ssize_t packed_rkey_size; @@ -366,17 +367,19 @@ ucp_proto_request_pack_rkey(ucp_request_t *req, ucp_md_map_t md_map, ucs_unlikely(memh->flags & UCP_MEMH_FLAG_HAS_AUTO_GVA)) { ucp_memh_disable_gva(memh, md_map); } - if (!ucs_test_all_flags(memh->md_map, md_map)) { - ucs_trace("dt_iter_md_map=0x%"PRIx64" md_map=0x%"PRIx64, memh->md_map, - md_map); + ucs_trace("dt_iter_md_map=0x%" PRIx64 " md_map=0x%" PRIx64, + memh->md_map, md_map); } + /* TODO: context lock is not scalable. Consider fine-grained lock per memh, + * immutable memh with rkey cache, RCU/COW */ + UCP_THREAD_CS_ENTER(&context->mt_lock); packed_rkey_size = ucp_rkey_pack_memh( - req->send.ep->worker->context, md_map & memh->md_map, memh, - dt_iter->type.contig.buffer, dt_iter->length, &dt_iter->mem_info, - distance_dev_map, dev_distance, + context, md_map & memh->md_map, memh, dt_iter->type.contig.buffer, + dt_iter->length, &dt_iter->mem_info, distance_dev_map, dev_distance, ucp_ep_config(req->send.ep)->uct_rkey_pack_flags, rkey_buffer); + UCP_THREAD_CS_EXIT(&context->mt_lock); if (packed_rkey_size < 0) { ucs_error("failed to pack remote key: %s", diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index ea510e20b22..edde9176d3c 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -178,11 +178,15 @@ size_t ucp_rndv_rts_pack(ucp_request_t *sreq, ucp_rndv_rts_hdr_t *rndv_rts_hdr, rndv_rts_hdr->address = (uintptr_t)sreq->send.buffer; rkey_buf = UCS_PTR_BYTE_OFFSET(rndv_rts_hdr, sizeof(*rndv_rts_hdr)); - packed_rkey_size = ucp_rkey_pack_memh( 
+ + UCP_THREAD_CS_ENTER(&worker->context->mt_lock); + packed_rkey_size = ucp_rkey_pack_memh( worker->context, sreq->send.rndv.md_map, sreq->send.state.dt.dt.contig.memh, sreq->send.buffer, sreq->send.length, &mem_info, 0, NULL, ucp_ep_config(sreq->send.ep)->uct_rkey_pack_flags, rkey_buf); + UCP_THREAD_CS_EXIT(&worker->context->mt_lock); + if (packed_rkey_size < 0) { ucs_fatal("failed to pack rendezvous remote key: %s", ucs_status_string((ucs_status_t)packed_rkey_size)); @@ -205,6 +209,7 @@ static size_t ucp_rndv_rtr_pack(void *dest, void *arg) ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = dest; ucp_request_t *rreq = ucp_request_get_super(rndv_req); ucp_ep_h ep = rndv_req->send.ep; + ucp_context_h context = ep->worker->context; ucp_memory_info_t mem_info; ssize_t packed_rkey_size; @@ -221,12 +226,15 @@ static size_t ucp_rndv_rtr_pack(void *dest, void *arg) mem_info.type = rreq->recv.dt_iter.mem_info.type; mem_info.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN; + UCP_THREAD_CS_ENTER(&context->mt_lock); packed_rkey_size = ucp_rkey_pack_memh( - ep->worker->context, rndv_req->send.rndv.md_map, + context, rndv_req->send.rndv.md_map, rreq->recv.dt_iter.type.contig.memh, rreq->recv.dt_iter.type.contig.buffer, rndv_req->send.length, &mem_info, 0, NULL, ucp_ep_config(ep)->uct_rkey_pack_flags, rndv_rtr_hdr + 1); + UCP_THREAD_CS_EXIT(&context->mt_lock); + if (packed_rkey_size < 0) { return packed_rkey_size; } diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index 38853d7a39d..8ca225d74da 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -252,6 +252,7 @@ static size_t ucp_proto_rndv_rtr_mtype_pack(void *dest, void *arg) { ucp_rndv_rtr_hdr_t *rtr = dest; ucp_request_t *req = arg; + ucp_context_h context = req->send.ep->worker->context; const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; ucp_md_map_t md_map = rpriv->super.md_map; ucp_mem_desc_t *mdesc = req->send.rndv.mdesc; @@ -266,10 +267,14 @@ static size_t ucp_proto_rndv_rtr_mtype_pack(void 
*dest, void *arg) /* Pack remote key for the fragment */ mem_info.type = mdesc->memh->mem_type; mem_info.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN; - packed_rkey_size = ucp_rkey_pack_memh(req->send.ep->worker->context, md_map, - mdesc->memh, mdesc->ptr, + + UCP_THREAD_CS_ENTER(&context->mt_lock); + packed_rkey_size = ucp_rkey_pack_memh(context, md_map, mdesc->memh, + mdesc->ptr, req->send.state.dt_iter.length, &mem_info, 0, NULL, 0, rtr + 1); + UCP_THREAD_CS_EXIT(&context->mt_lock); + if (packed_rkey_size < 0) { ucs_error("failed to pack remote key: %s", ucs_status_string((ucs_status_t)packed_rkey_size)); diff --git a/test/gtest/ucp/test_ucp_rma_mt.cc b/test/gtest/ucp/test_ucp_rma_mt.cc index d1acb05c2eb..a38e3479f84 100644 --- a/test/gtest/ucp/test_ucp_rma_mt.cc +++ b/test/gtest/ucp/test_ucp_rma_mt.cc @@ -9,6 +9,10 @@ #include +extern "C" { +#include +} + #if _OPENMP #include "omp.h" #endif @@ -35,6 +39,22 @@ class test_ucp_rma_mt : public ucp_test { add_variant(variants, UCP_FEATURE_RMA, MULTI_THREAD_CONTEXT); add_variant(variants, UCP_FEATURE_RMA, MULTI_THREAD_WORKER); } + + ucp_mem_h mem_map(ucp_context_h context, void *data, size_t size) + { + ucp_mem_map_params_t params; + ucp_mem_h memh; + + params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_FLAGS; + params.address = data; + params.length = size; + params.flags = get_variant_value(); + + ASSERT_UCS_OK(ucp_mem_map(context, &params, &memh)); + return memh; + } }; UCS_TEST_P(test_ucp_rma_mt, put_get) { @@ -43,19 +63,9 @@ UCS_TEST_P(test_ucp_rma_mt, put_get) { uint64_t orig_data[num_threads] GTEST_ATTRIBUTE_UNUSED_; uint64_t target_data[num_threads] GTEST_ATTRIBUTE_UNUSED_; - ucp_mem_map_params_t params; - ucp_mem_h memh; void *memheap = target_data; - - params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | - UCP_MEM_MAP_PARAM_FIELD_LENGTH | - UCP_MEM_MAP_PARAM_FIELD_FLAGS; - params.address = memheap; - params.length = sizeof(uint64_t) * num_threads; -
params.flags = get_variant_value(); - - st = ucp_mem_map(receiver().ucph(), &params, &memh); - ASSERT_UCS_OK(st); + ucp_mem_h memh = mem_map(receiver().ucph(), memheap, + sizeof(uint64_t) * num_threads); void *rkey_buffer; size_t rkey_buffer_size; @@ -200,4 +210,34 @@ UCS_TEST_P(test_ucp_rma_mt, put_get) { ASSERT_UCS_OK(st); } +UCS_TEST_P(test_ucp_rma_mt, rkey_pack) { + uint8_t data[1024] GTEST_ATTRIBUTE_UNUSED_; + ucp_context_h context = sender().ucph(); + ucp_mem_h memh = mem_map(context, data, sizeof(data)); + +#if _OPENMP && ENABLE_MT +#pragma omp parallel for + for (int i = 0; i < mt_num_threads(); i++) { + if (i % 2 == 0) { + void *rkey; + size_t rkey_size; + ASSERT_UCS_OK(ucp_rkey_pack(context, memh, &rkey, &rkey_size)); + ucp_rkey_buffer_release(rkey); + } else { + ucs_sys_dev_distance_t sys_dev = {}; + ucp_request req = {}; + req.send.ep = sender().ep(); + req.send.state.dt_iter.type.contig.memh = memh; + req.send.state.dt_iter.type.contig.buffer = data; + req.send.state.dt_iter.length = sizeof(data); + + uint8_t rkey[1024]; + ucp_proto_request_pack_rkey(&req, memh->md_map, 0, &sys_dev, rkey); + } + } +#endif + + ASSERT_UCS_OK(ucp_mem_unmap(context, memh)); +} + UCP_INSTANTIATE_TEST_CASE(test_ucp_rma_mt) From 8e5530a02f33b197b71ea255fc25107ae08af44a Mon Sep 17 00:00:00 2001 From: Ilia Yastrebov Date: Tue, 4 Feb 2025 15:22:41 +0000 Subject: [PATCH 2/2] UCP: Added new MT mode for context: UCP_MT_TYPE_WORKER_ASYNC --- src/ucp/core/ucp_context.c | 28 +++++++++++++++ src/ucp/core/ucp_context.h | 4 +++ src/ucp/core/ucp_rkey.c | 4 +-- src/ucp/core/ucp_thread.h | 57 ++++++++++++++++++++++++------- src/ucp/core/ucp_worker.c | 10 ++++++ test/gtest/ucp/test_ucp_rma_mt.cc | 10 +++++- 6 files changed, 97 insertions(+), 16 deletions(-) diff --git a/src/ucp/core/ucp_context.c b/src/ucp/core/ucp_context.c index 8b55eb2445b..697f75071fa 100644 --- a/src/ucp/core/ucp_context.c +++ b/src/ucp/core/ucp_context.c @@ -1963,6 +1963,34 @@ static void ucp_apply_params(ucp_context_h
context, const ucp_params_t *params, } } +void ucp_context_set_worker_async(ucp_context_h context, + ucs_async_context_t *async) +{ + if (async != NULL) { + /* Setting new worker async mutex */ + if (context->mt_lock.mt_type == UCP_MT_TYPE_WORKER_ASYNC) { + ucs_error("worker async %p is already set for context %p", + context->mt_lock.lock.mt_worker_async, context); + } else if (context->mt_lock.mt_type != UCP_MT_TYPE_NONE) { + ucs_debug("context %p is already set with mutex mt_type %d", + context, context->mt_lock.mt_type); + } else { + context->mt_lock.mt_type = UCP_MT_TYPE_WORKER_ASYNC; + context->mt_lock.lock.mt_worker_async = async; + } + } else { + /* Resetting existing worker async mutex */ + if (context->mt_lock.mt_type == UCP_MT_TYPE_WORKER_ASYNC) { + if (context->mt_lock.lock.mt_worker_async != NULL) { + context->mt_lock.mt_type = UCP_MT_TYPE_NONE; + context->mt_lock.lock.mt_worker_async = NULL; + } else { + ucs_error("worker async is not set for context %p", context); + } + } + } +} + static ucs_status_t ucp_fill_rndv_frag_config(const ucp_context_config_names_t *config, const size_t *default_sizes, size_t *sizes) diff --git a/src/ucp/core/ucp_context.h b/src/ucp/core/ucp_context.h index e19b82c9d19..41e8ab30460 100644 --- a/src/ucp/core/ucp_context.h +++ b/src/ucp/core/ucp_context.h @@ -755,4 +755,8 @@ ucp_config_modify_internal(ucp_config_t *config, const char *name, void ucp_apply_uct_config_list(ucp_context_h context, void *config); + +void ucp_context_set_worker_async(ucp_context_h context, + ucs_async_context_t *async); + #endif diff --git a/src/ucp/core/ucp_rkey.c b/src/ucp/core/ucp_rkey.c index 0efa24ce027..285a68027b6 100644 --- a/src/ucp/core/ucp_rkey.c +++ b/src/ucp/core/ucp_rkey.c @@ -642,7 +642,7 @@ ucp_memh_pack_internal(ucp_mem_h memh, const ucp_memh_pack_params_t *params, return UCS_OK; } - UCP_THREAD_CS_ENTER(&context->mt_lock); + UCP_THREAD_CS_ASYNC_ENTER(&context->mt_lock); size = ucp_memh_packed_size(memh, flags, rkey_compat); @@ 
-677,7 +677,7 @@ ucp_memh_pack_internal(ucp_mem_h memh, const ucp_memh_pack_params_t *params, err_destroy: ucs_free(memh_buffer); out: - UCP_THREAD_CS_EXIT(&context->mt_lock); + UCP_THREAD_CS_ASYNC_EXIT(&context->mt_lock); return status; } diff --git a/src/ucp/core/ucp_thread.h b/src/ucp/core/ucp_thread.h index 3b46e8200a0..9f658c29762 100644 --- a/src/ucp/core/ucp_thread.h +++ b/src/ucp/core/ucp_thread.h @@ -22,7 +22,8 @@ typedef enum ucp_mt_type { UCP_MT_TYPE_NONE = 0, UCP_MT_TYPE_SPINLOCK, - UCP_MT_TYPE_MUTEX + UCP_MT_TYPE_MUTEX, + UCP_MT_TYPE_WORKER_ASYNC } ucp_mt_type_t; @@ -36,6 +37,13 @@ typedef struct ucp_mt_lock { at one time. Spinlock is the default option. */ ucs_recursive_spinlock_t mt_spinlock; pthread_mutex_t mt_mutex; + /* Lock for MULTI_THREAD_WORKER case, when mt-single context is used by + * a single mt-shared worker. In this case the worker progress flow is + * already protected by worker mutex, and we don't need to lock inside + * that flow. This is to protect certain API calls that can be triggered + * from the user thread without holding a worker mutex. 
+ * Essentially this mutex is a pointer to a worker mutex */ + ucs_async_context_t *mt_worker_async; } lock; } ucp_mt_lock_t; @@ -58,21 +66,44 @@ typedef struct ucp_mt_lock { pthread_mutex_destroy(&((_lock_ptr)->lock.mt_mutex)); \ } \ } while (0) -#define UCP_THREAD_CS_ENTER(_lock_ptr) \ + +static UCS_F_ALWAYS_INLINE void ucp_mt_lock_lock(ucp_mt_lock_t *lock) +{ + if (lock->mt_type == UCP_MT_TYPE_SPINLOCK) { + ucs_recursive_spin_lock(&lock->lock.mt_spinlock); + } else if (lock->mt_type == UCP_MT_TYPE_MUTEX) { + pthread_mutex_lock(&lock->lock.mt_mutex); + } +} + +static UCS_F_ALWAYS_INLINE void ucp_mt_lock_unlock(ucp_mt_lock_t *lock) +{ + if (lock->mt_type == UCP_MT_TYPE_SPINLOCK) { + ucs_recursive_spin_unlock(&lock->lock.mt_spinlock); + } else if (lock->mt_type == UCP_MT_TYPE_MUTEX) { + pthread_mutex_unlock(&lock->lock.mt_mutex); + } +} + +#define UCP_THREAD_CS_ENTER(_lock_ptr) ucp_mt_lock_lock(_lock_ptr) +#define UCP_THREAD_CS_EXIT(_lock_ptr) ucp_mt_lock_unlock(_lock_ptr) + +#define UCP_THREAD_CS_ASYNC_ENTER(_lock_ptr) \ do { \ - if ((_lock_ptr)->mt_type == UCP_MT_TYPE_SPINLOCK) { \ - ucs_recursive_spin_lock(&((_lock_ptr)->lock.mt_spinlock)); \ - } else if ((_lock_ptr)->mt_type == UCP_MT_TYPE_MUTEX) { \ - pthread_mutex_lock(&((_lock_ptr)->lock.mt_mutex)); \ + if ((_lock_ptr)->mt_type == UCP_MT_TYPE_WORKER_ASYNC) { \ + UCS_ASYNC_BLOCK((_lock_ptr)->lock.mt_worker_async); \ + } else { \ + ucp_mt_lock_lock(_lock_ptr); \ } \ - } while (0) -#define UCP_THREAD_CS_EXIT(_lock_ptr) \ + } while(0) + +#define UCP_THREAD_CS_ASYNC_EXIT(_lock_ptr) \ do { \ - if ((_lock_ptr)->mt_type == UCP_MT_TYPE_SPINLOCK) { \ - ucs_recursive_spin_unlock(&((_lock_ptr)->lock.mt_spinlock)); \ - } else if ((_lock_ptr)->mt_type == UCP_MT_TYPE_MUTEX) { \ - pthread_mutex_unlock(&((_lock_ptr)->lock.mt_mutex)); \ + if ((_lock_ptr)->mt_type == UCP_MT_TYPE_WORKER_ASYNC) { \ + UCS_ASYNC_UNBLOCK((_lock_ptr)->lock.mt_worker_async); \ + } else { \ + ucp_mt_lock_unlock(_lock_ptr); \ } \ - } while (0) + } 
while(0) #endif diff --git a/src/ucp/core/ucp_worker.c b/src/ucp/core/ucp_worker.c index 28e874d68ae..dbcb86955d4 100644 --- a/src/ucp/core/ucp_worker.c +++ b/src/ucp/core/ucp_worker.c @@ -2569,6 +2569,10 @@ ucs_status_t ucp_worker_create(ucp_context_h context, goto err_free_tm_offload_stats; } + if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) { + ucp_context_set_worker_async(context, &worker->async); + } + /* Create the underlying UCT worker */ status = uct_worker_create(&worker->async, uct_thread_mode, &worker->uct); if (status != UCS_OK) { @@ -2668,6 +2672,9 @@ ucs_status_t ucp_worker_create(ucp_context_h context, err_destroy_uct_worker: uct_worker_destroy(worker->uct); err_destroy_async: + if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) { + ucp_context_set_worker_async(context, NULL); + } ucs_async_context_cleanup(&worker->async); err_free_tm_offload_stats: UCS_STATS_NODE_FREE(worker->tm_offload_stats); @@ -2923,6 +2930,9 @@ void ucp_worker_destroy(ucp_worker_h worker) ucs_conn_match_cleanup(&worker->conn_match_ctx); ucp_worker_wakeup_cleanup(worker); uct_worker_destroy(worker->uct); + if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) { + ucp_context_set_worker_async(worker->context, NULL); + } ucs_async_context_cleanup(&worker->async); UCS_STATS_NODE_FREE(worker->tm_offload_stats); UCS_STATS_NODE_FREE(worker->stats); diff --git a/test/gtest/ucp/test_ucp_rma_mt.cc b/test/gtest/ucp/test_ucp_rma_mt.cc index a38e3479f84..0fd3e79bf34 100644 --- a/test/gtest/ucp/test_ucp_rma_mt.cc +++ b/test/gtest/ucp/test_ucp_rma_mt.cc @@ -218,6 +218,11 @@ UCS_TEST_P(test_ucp_rma_mt, rkey_pack) { #if _OPENMP && ENABLE_MT #pragma omp parallel for for (int i = 0; i < mt_num_threads(); i++) { + int worker_index = 0; + if (get_variant_thread_type() == MULTI_THREAD_CONTEXT) { + worker_index = i; + } + if (i % 2 == 0) { void *rkey; size_t rkey_size; @@ -226,13 +231,16 @@ UCS_TEST_P(test_ucp_rma_mt, rkey_pack) { } else { ucs_sys_dev_distance_t sys_dev = {}; ucp_request req = {}; - 
req.send.ep = sender().ep(); + req.send.ep = sender().ep(worker_index); req.send.state.dt_iter.type.contig.memh = memh; req.send.state.dt_iter.type.contig.buffer = data; req.send.state.dt_iter.length = sizeof(data); uint8_t rkey[1024]; + ucp_worker_h worker = sender().worker(worker_index); + UCP_WORKER_THREAD_CS_ENTER_CONDITIONAL(worker); ucp_proto_request_pack_rkey(&req, memh->md_map, 0, &sys_dev, rkey); + UCP_WORKER_THREAD_CS_EXIT_CONDITIONAL(worker); } } #endif