UCP: Added new MT mode for context: UCP_MT_TYPE_WORKER_ASYNC
iyastreb committed Feb 4, 2025
1 parent 8c692ca commit 8e5530a
Showing 6 changed files with 97 additions and 16 deletions.
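The change introduces a third locking mode for the UCP context, UCP_MT_TYPE_WORKER_ASYNC, used when a context that was not configured for multi-threaded access is driven by a single thread-multi worker. In that case the worker progress path is already serialized by the worker's async lock, so context-level calls that a user thread can issue directly (for example memory-handle packing in ucp_rkey.c below) serialize by blocking that same async context instead of taking a separate context spinlock or mutex. A minimal sketch of the setup this targets is shown below; it is not part of the commit, the feature selection and error handling are placeholders, but the UCP calls and parameter fields are the standard public API.

#include <ucp/api/ucp.h>

/* Sketch only: a context created for single-threaded use, driven by one
 * thread-multi worker. With this commit, ucp_worker_create() installs the
 * worker's async context as the context lock via
 * ucp_context_set_worker_async(). */
static ucs_status_t example_setup(ucp_context_h *context_p, ucp_worker_h *worker_p)
{
    ucp_params_t params = {
        .field_mask = UCP_PARAM_FIELD_FEATURES,
        .features   = UCP_FEATURE_RMA  /* feature choice is arbitrary here */
    };
    ucp_worker_params_t worker_params = {
        .field_mask  = UCP_WORKER_PARAM_FIELD_THREAD_MODE,
        .thread_mode = UCS_THREAD_MODE_MULTI  /* request a thread-safe worker */
    };
    ucs_status_t status;

    status = ucp_init(&params, NULL, context_p);
    if (status != UCS_OK) {
        return status;
    }

    /* From here on, context-level calls issued from user threads
     * block/unblock this worker's async context. */
    return ucp_worker_create(*context_p, &worker_params, worker_p);
}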
28 changes: 28 additions & 0 deletions src/ucp/core/ucp_context.c
@@ -1963,6 +1963,34 @@ static void ucp_apply_params(ucp_context_h context, const ucp_params_t *params,
}
}

void ucp_context_set_worker_async(ucp_context_h context,
ucs_async_context_t *async)
{
if (async != NULL) {
/* Setting new worker async mutex */
if (context->mt_lock.mt_type == UCP_MT_TYPE_WORKER_ASYNC) {
ucs_error("worker async %p is already set for context %p",
context->mt_lock.lock.mt_worker_async, context);
} else if (context->mt_lock.mt_type != UCP_MT_TYPE_NONE) {
ucs_debug("context %p is already set with mutex mt_type %d",
context, context->mt_lock.mt_type);
} else {
context->mt_lock.mt_type = UCP_MT_TYPE_WORKER_ASYNC;
context->mt_lock.lock.mt_worker_async = async;
}
} else {
/* Resetting existing worker async mutex */
if (context->mt_lock.mt_type == UCP_MT_TYPE_WORKER_ASYNC) {
if (context->mt_lock.lock.mt_worker_async != NULL) {
context->mt_lock.mt_type = UCP_MT_TYPE_NONE;
context->mt_lock.lock.mt_worker_async = NULL;
} else {
ucs_error("worker async is not set for context %p", context);
}
}
}
}
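
The helper implements a small state machine: the worker async pointer is installed only when the context has no lock configured (UCP_MT_TYPE_NONE), an existing spinlock or mutex is left in place with a debug message, a duplicate install is reported as an error, and passing NULL reverts the context to UCP_MT_TYPE_NONE. The pairing is easiest to see in isolation; the sketch below uses hypothetical wrapper names, while the real call sites are in ucp_worker.c later in this diff.

/* Hypothetical wrappers, only to illustrate the install/reset pairing. */
static void example_attach_worker_lock(ucp_context_h context, ucp_worker_h worker)
{
    /* Install: the context now serializes through the worker's async lock */
    ucp_context_set_worker_async(context, &worker->async);
}

static void example_detach_worker_lock(ucp_context_h context)
{
    /* Reset: NULL reverts the context lock type to UCP_MT_TYPE_NONE */
    ucp_context_set_worker_async(context, NULL);
}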

static ucs_status_t
ucp_fill_rndv_frag_config(const ucp_context_config_names_t *config,
const size_t *default_sizes, size_t *sizes)
4 changes: 4 additions & 0 deletions src/ucp/core/ucp_context.h
@@ -755,4 +755,8 @@ ucp_config_modify_internal(ucp_config_t *config, const char *name,

void ucp_apply_uct_config_list(ucp_context_h context, void *config);


void ucp_context_set_worker_async(ucp_context_h context,
ucs_async_context_t *async);

#endif
4 changes: 2 additions & 2 deletions src/ucp/core/ucp_rkey.c
@@ -642,7 +642,7 @@ ucp_memh_pack_internal(ucp_mem_h memh, const ucp_memh_pack_params_t *params,
return UCS_OK;
}

UCP_THREAD_CS_ENTER(&context->mt_lock);
UCP_THREAD_CS_ASYNC_ENTER(&context->mt_lock);

size = ucp_memh_packed_size(memh, flags, rkey_compat);

@@ -677,7 +677,7 @@ ucp_memh_pack_internal(ucp_mem_h memh, const ucp_memh_pack_params_t *params,
err_destroy:
ucs_free(memh_buffer);
out:
UCP_THREAD_CS_EXIT(&context->mt_lock);
UCP_THREAD_CS_ASYNC_EXIT(&context->mt_lock);
return status;
}
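
Switching ucp_memh_pack_internal() from UCP_THREAD_CS_ENTER/EXIT to the ASYNC variants matters because memory-handle packing can be invoked straight from a user thread that does not hold the worker lock. In the new UCP_MT_TYPE_WORKER_ASYNC mode the async-aware critical section blocks the worker's async context for the duration of the call; in the spinlock and mutex modes it falls back to the same locking as before. A hedged sketch of the same guard pattern around a hypothetical context-level operation:

/* Hypothetical example of the pattern used above: a context-level operation
 * reachable from user threads wraps its body in the async-aware critical
 * section. Under UCP_MT_TYPE_WORKER_ASYNC this expands to
 * UCS_ASYNC_BLOCK()/UCS_ASYNC_UNBLOCK() on the worker's async context;
 * otherwise it takes the context spinlock/mutex (or nothing at all). */
static ucs_status_t example_context_op(ucp_context_h context)
{
    ucs_status_t status;

    UCP_THREAD_CS_ASYNC_ENTER(&context->mt_lock);
    status = UCS_OK; /* ... read or update context state here ... */
    UCP_THREAD_CS_ASYNC_EXIT(&context->mt_lock);

    return status;
}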

57 changes: 44 additions & 13 deletions src/ucp/core/ucp_thread.h
@@ -22,7 +22,8 @@
typedef enum ucp_mt_type {
UCP_MT_TYPE_NONE = 0,
UCP_MT_TYPE_SPINLOCK,
UCP_MT_TYPE_MUTEX
UCP_MT_TYPE_MUTEX,
UCP_MT_TYPE_WORKER_ASYNC
} ucp_mt_type_t;


@@ -36,6 +37,13 @@ typedef struct ucp_mt_lock {
at one time. Spinlock is the default option. */
ucs_recursive_spinlock_t mt_spinlock;
pthread_mutex_t mt_mutex;
/* Lock for MULTI_THREAD_WORKER case, when mt-single context is used by
* a single mt-shared worker. In this case the worker progress flow is
* already protected by worker mutex, and we don't need to lock inside
* that flow. This is to protect certain API calls that can be triggered
* from the user thread without holding a worker mutex.
* Essentially this mutex is a pointer to a worker mutex */
ucs_async_context_t *mt_worker_async;
} lock;
} ucp_mt_lock_t;

@@ -58,21 +66,44 @@
pthread_mutex_destroy(&((_lock_ptr)->lock.mt_mutex)); \
} \
} while (0)
#define UCP_THREAD_CS_ENTER(_lock_ptr) \

static UCS_F_ALWAYS_INLINE void ucp_mt_lock_lock(ucp_mt_lock_t *lock)
{
if (lock->mt_type == UCP_MT_TYPE_SPINLOCK) {
ucs_recursive_spin_lock(&lock->lock.mt_spinlock);
} else if (lock->mt_type == UCP_MT_TYPE_MUTEX) {
pthread_mutex_lock(&lock->lock.mt_mutex);
}
}

static UCS_F_ALWAYS_INLINE void ucp_mt_lock_unlock(ucp_mt_lock_t *lock)
{
if (lock->mt_type == UCP_MT_TYPE_SPINLOCK) {
ucs_recursive_spin_unlock(&lock->lock.mt_spinlock);
} else if (lock->mt_type == UCP_MT_TYPE_MUTEX) {
pthread_mutex_unlock(&lock->lock.mt_mutex);
}
}

#define UCP_THREAD_CS_ENTER(_lock_ptr) ucp_mt_lock_lock(_lock_ptr)
#define UCP_THREAD_CS_EXIT(_lock_ptr) ucp_mt_lock_unlock(_lock_ptr)

#define UCP_THREAD_CS_ASYNC_ENTER(_lock_ptr) \
do { \
if ((_lock_ptr)->mt_type == UCP_MT_TYPE_SPINLOCK) { \
ucs_recursive_spin_lock(&((_lock_ptr)->lock.mt_spinlock)); \
} else if ((_lock_ptr)->mt_type == UCP_MT_TYPE_MUTEX) { \
pthread_mutex_lock(&((_lock_ptr)->lock.mt_mutex)); \
if ((_lock_ptr)->mt_type == UCP_MT_TYPE_WORKER_ASYNC) { \
UCS_ASYNC_BLOCK((_lock_ptr)->lock.mt_worker_async); \
} else { \
ucp_mt_lock_lock(_lock_ptr); \
} \
} while (0)
#define UCP_THREAD_CS_EXIT(_lock_ptr) \
} while(0)

#define UCP_THREAD_CS_ASYNC_EXIT(_lock_ptr) \
do { \
if ((_lock_ptr)->mt_type == UCP_MT_TYPE_SPINLOCK) { \
ucs_recursive_spin_unlock(&((_lock_ptr)->lock.mt_spinlock)); \
} else if ((_lock_ptr)->mt_type == UCP_MT_TYPE_MUTEX) { \
pthread_mutex_unlock(&((_lock_ptr)->lock.mt_mutex)); \
if ((_lock_ptr)->mt_type == UCP_MT_TYPE_WORKER_ASYNC) { \
UCS_ASYNC_UNBLOCK((_lock_ptr)->lock.mt_worker_async); \
} else { \
ucp_mt_lock_unlock(_lock_ptr); \
} \
} while (0)
} while(0)

#endif
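
Folding the old macro bodies into the ucp_mt_lock_lock()/ucp_mt_lock_unlock() inline helpers keeps UCP_THREAD_CS_ENTER/EXIT behavior unchanged for existing callers while letting the new async-aware macros reuse the same spinlock/mutex fallback. The key difference is that the plain macros ignore UCP_MT_TYPE_WORKER_ASYNC, so only call sites converted to the ASYNC variants (such as the rkey packing path above) are protected in that mode. An illustrative snippet, not part of the commit:

/* Illustrative only: with a WORKER_ASYNC lock the plain macros do nothing,
 * while the ASYNC variants block/unblock the worker's async context. */
static void example_dispatch(ucp_mt_lock_t *mt_lock, ucs_async_context_t *async)
{
    mt_lock->mt_type              = UCP_MT_TYPE_WORKER_ASYNC;
    mt_lock->lock.mt_worker_async = async;

    UCP_THREAD_CS_ENTER(mt_lock);       /* no-op: ucp_mt_lock_lock() only
                                         * handles SPINLOCK and MUTEX */
    UCP_THREAD_CS_EXIT(mt_lock);        /* no-op */

    UCP_THREAD_CS_ASYNC_ENTER(mt_lock); /* UCS_ASYNC_BLOCK(async) */
    UCP_THREAD_CS_ASYNC_EXIT(mt_lock);  /* UCS_ASYNC_UNBLOCK(async) */
}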
10 changes: 10 additions & 0 deletions src/ucp/core/ucp_worker.c
@@ -2569,6 +2569,10 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
goto err_free_tm_offload_stats;
}

if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
ucp_context_set_worker_async(context, &worker->async);
}

/* Create the underlying UCT worker */
status = uct_worker_create(&worker->async, uct_thread_mode, &worker->uct);
if (status != UCS_OK) {
@@ -2668,6 +2672,9 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
err_destroy_uct_worker:
uct_worker_destroy(worker->uct);
err_destroy_async:
if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
ucp_context_set_worker_async(context, NULL);
}
ucs_async_context_cleanup(&worker->async);
err_free_tm_offload_stats:
UCS_STATS_NODE_FREE(worker->tm_offload_stats);
Expand Down Expand Up @@ -2923,6 +2930,9 @@ void ucp_worker_destroy(ucp_worker_h worker)
ucs_conn_match_cleanup(&worker->conn_match_ctx);
ucp_worker_wakeup_cleanup(worker);
uct_worker_destroy(worker->uct);
if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
ucp_context_set_worker_async(worker->context, NULL);
}
ucs_async_context_cleanup(&worker->async);
UCS_STATS_NODE_FREE(worker->tm_offload_stats);
UCS_STATS_NODE_FREE(worker->stats);
10 changes: 9 additions & 1 deletion test/gtest/ucp/test_ucp_rma_mt.cc
@@ -218,6 +218,11 @@ UCS_TEST_P(test_ucp_rma_mt, rkey_pack) {
#if _OPENMP && ENABLE_MT
#pragma omp parallel for
for (int i = 0; i < mt_num_threads(); i++) {
int worker_index = 0;
if (get_variant_thread_type() == MULTI_THREAD_CONTEXT) {
worker_index = i;
}

if (i % 2 == 0) {
void *rkey;
size_t rkey_size;
@@ -226,13 +231,16 @@
} else {
ucs_sys_dev_distance_t sys_dev = {};
ucp_request req = {};
req.send.ep = sender().ep();
req.send.ep = sender().ep(worker_index);
req.send.state.dt_iter.type.contig.memh = memh;
req.send.state.dt_iter.type.contig.buffer = data;
req.send.state.dt_iter.length = sizeof(data);

uint8_t rkey[1024];
ucp_worker_h worker = sender().worker(worker_index);
UCP_WORKER_THREAD_CS_ENTER_CONDITIONAL(worker);
ucp_proto_request_pack_rkey(&req, memh->md_map, 0, &sys_dev, rkey);
UCP_WORKER_THREAD_CS_EXIT_CONDITIONAL(worker);
}
}
#endif
