
Commit 8e5530a

UCP: Added new MT mode for context: UCP_MT_TYPE_WORKER_ASYNC
1 parent 8c692ca commit 8e5530a

File tree

6 files changed: +97 -16 lines

src/ucp/core/ucp_context.c

Lines changed: 28 additions & 0 deletions
@@ -1963,6 +1963,34 @@ static void ucp_apply_params(ucp_context_h context, const ucp_params_t *params,
     }
 }
 
+void ucp_context_set_worker_async(ucp_context_h context,
+                                  ucs_async_context_t *async)
+{
+    if (async != NULL) {
+        /* Setting new worker async mutex */
+        if (context->mt_lock.mt_type == UCP_MT_TYPE_WORKER_ASYNC) {
+            ucs_error("worker async %p is already set for context %p",
+                      context->mt_lock.lock.mt_worker_async, context);
+        } else if (context->mt_lock.mt_type != UCP_MT_TYPE_NONE) {
+            ucs_debug("context %p is already set with mutex mt_type %d",
+                      context, context->mt_lock.mt_type);
+        } else {
+            context->mt_lock.mt_type              = UCP_MT_TYPE_WORKER_ASYNC;
+            context->mt_lock.lock.mt_worker_async = async;
+        }
+    } else {
+        /* Resetting existing worker async mutex */
+        if (context->mt_lock.mt_type == UCP_MT_TYPE_WORKER_ASYNC) {
+            if (context->mt_lock.lock.mt_worker_async != NULL) {
+                context->mt_lock.mt_type              = UCP_MT_TYPE_NONE;
+                context->mt_lock.lock.mt_worker_async = NULL;
+            } else {
+                ucs_error("worker async is not set for context %p", context);
+            }
+        }
+    }
+}
+
 static ucs_status_t
 ucp_fill_rndv_frag_config(const ucp_context_config_names_t *config,
                           const size_t *default_sizes, size_t *sizes)
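
For orientation, a condensed sketch of how this setter is intended to be paired over a worker's lifetime, taken from the ucp_worker.c hunks later in this commit (the sketch itself is not part of the commit):

    /* When an mt-shared worker is created on an otherwise mt-single context,
     * point the context lock at the worker's async context ... */
    if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
        ucp_context_set_worker_async(context, &worker->async);
    }

    /* ... and on worker destruction (or a failed creation), restore the
     * context lock to UCP_MT_TYPE_NONE by passing NULL. */
    if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
        ucp_context_set_worker_async(context, NULL);
    }

Re-setting an already-set async logs an error, and a context that already uses a spinlock or mutex is left unchanged with a debug message.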

src/ucp/core/ucp_context.h

Lines changed: 4 additions & 0 deletions
@@ -755,4 +755,8 @@ ucp_config_modify_internal(ucp_config_t *config, const char *name,
 
 void ucp_apply_uct_config_list(ucp_context_h context, void *config);
 
+
+void ucp_context_set_worker_async(ucp_context_h context,
+                                  ucs_async_context_t *async);
+
 #endif

src/ucp/core/ucp_rkey.c

Lines changed: 2 additions & 2 deletions
@@ -642,7 +642,7 @@ ucp_memh_pack_internal(ucp_mem_h memh, const ucp_memh_pack_params_t *params,
         return UCS_OK;
     }
 
-    UCP_THREAD_CS_ENTER(&context->mt_lock);
+    UCP_THREAD_CS_ASYNC_ENTER(&context->mt_lock);
 
     size = ucp_memh_packed_size(memh, flags, rkey_compat);
 
@@ -677,7 +677,7 @@ ucp_memh_pack_internal(ucp_mem_h memh, const ucp_memh_pack_params_t *params,
 err_destroy:
     ucs_free(memh_buffer);
 out:
-    UCP_THREAD_CS_EXIT(&context->mt_lock);
+    UCP_THREAD_CS_ASYNC_EXIT(&context->mt_lock);
     return status;
 }
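
Why this file changes: rkey/memh packing is a context-level call that an application may issue from its own thread while a single mt-shared worker is progressing on another thread. The ASYNC variants added in ucp_thread.h below block the worker's async context in that case instead of taking a context spinlock or mutex. A minimal application-level sketch of that scenario, assuming a thread-safe (MT-enabled) UCX build; it is illustrative only and omits all error handling:

    #include <ucp/api/ucp.h>
    #include <pthread.h>

    static void *progress_loop(void *arg)
    {
        ucp_worker_h worker = (ucp_worker_h)arg;
        for (int i = 0; i < 100000; ++i) {
            ucp_worker_progress(worker); /* runs under the worker async lock */
        }
        return NULL;
    }

    int main(void)
    {
        ucp_params_t ctx_params         = {0};
        ucp_worker_params_t wrk_params  = {0};
        ucp_mem_map_params_t map_params = {0};
        ucp_context_h context;
        ucp_worker_h worker;
        ucp_mem_h memh;
        void *rkey_buffer;
        size_t rkey_size;
        pthread_t thread;

        /* mt-single context: mt_workers_shared is not requested */
        ctx_params.field_mask = UCP_PARAM_FIELD_FEATURES;
        ctx_params.features   = UCP_FEATURE_RMA;
        ucp_init(&ctx_params, NULL, &context);

        /* single mt-shared worker: the UCP_MT_TYPE_WORKER_ASYNC case */
        wrk_params.field_mask  = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
        wrk_params.thread_mode = UCS_THREAD_MODE_MULTI;
        ucp_worker_create(context, &wrk_params, &worker);

        map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_LENGTH |
                                UCP_MEM_MAP_PARAM_FIELD_FLAGS;
        map_params.length     = 4096;
        map_params.flags      = UCP_MEM_MAP_ALLOCATE;
        ucp_mem_map(context, &map_params, &memh);

        pthread_create(&thread, NULL, progress_loop, worker);

        /* User-thread call into the context while the worker progresses:
         * the packing path changed above now serializes against the
         * worker's async context. */
        ucp_rkey_pack(context, memh, &rkey_buffer, &rkey_size);
        ucp_rkey_buffer_release(rkey_buffer);

        pthread_join(thread, NULL);
        ucp_mem_unmap(context, memh);
        ucp_worker_destroy(worker);
        ucp_cleanup(context);
        return 0;
    }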

src/ucp/core/ucp_thread.h

Lines changed: 44 additions & 13 deletions
@@ -22,7 +22,8 @@
 typedef enum ucp_mt_type {
     UCP_MT_TYPE_NONE = 0,
     UCP_MT_TYPE_SPINLOCK,
-    UCP_MT_TYPE_MUTEX
+    UCP_MT_TYPE_MUTEX,
+    UCP_MT_TYPE_WORKER_ASYNC
 } ucp_mt_type_t;
 
 
@@ -36,6 +37,13 @@ typedef struct ucp_mt_lock {
            at one time. Spinlock is the default option. */
         ucs_recursive_spinlock_t mt_spinlock;
         pthread_mutex_t          mt_mutex;
+        /* Lock for the MULTI_THREAD_WORKER case, when an mt-single context is
+         * used by a single mt-shared worker. The worker progress flow is
+         * already protected by the worker mutex, so no locking is needed
+         * inside that flow. This lock protects certain API calls that can be
+         * triggered from the user thread without holding the worker mutex.
+         * Essentially it is a pointer to the worker's async context. */
+        ucs_async_context_t      *mt_worker_async;
     } lock;
 } ucp_mt_lock_t;
 
@@ -58,21 +66,44 @@
         pthread_mutex_destroy(&((_lock_ptr)->lock.mt_mutex)); \
     } \
 } while (0)
-#define UCP_THREAD_CS_ENTER(_lock_ptr) \
+
+static UCS_F_ALWAYS_INLINE void ucp_mt_lock_lock(ucp_mt_lock_t *lock)
+{
+    if (lock->mt_type == UCP_MT_TYPE_SPINLOCK) {
+        ucs_recursive_spin_lock(&lock->lock.mt_spinlock);
+    } else if (lock->mt_type == UCP_MT_TYPE_MUTEX) {
+        pthread_mutex_lock(&lock->lock.mt_mutex);
+    }
+}
+
+static UCS_F_ALWAYS_INLINE void ucp_mt_lock_unlock(ucp_mt_lock_t *lock)
+{
+    if (lock->mt_type == UCP_MT_TYPE_SPINLOCK) {
+        ucs_recursive_spin_unlock(&lock->lock.mt_spinlock);
+    } else if (lock->mt_type == UCP_MT_TYPE_MUTEX) {
+        pthread_mutex_unlock(&lock->lock.mt_mutex);
+    }
+}
+
+#define UCP_THREAD_CS_ENTER(_lock_ptr) ucp_mt_lock_lock(_lock_ptr)
+#define UCP_THREAD_CS_EXIT(_lock_ptr)  ucp_mt_lock_unlock(_lock_ptr)
+
+#define UCP_THREAD_CS_ASYNC_ENTER(_lock_ptr) \
 do { \
-    if ((_lock_ptr)->mt_type == UCP_MT_TYPE_SPINLOCK) { \
-        ucs_recursive_spin_lock(&((_lock_ptr)->lock.mt_spinlock)); \
-    } else if ((_lock_ptr)->mt_type == UCP_MT_TYPE_MUTEX) { \
-        pthread_mutex_lock(&((_lock_ptr)->lock.mt_mutex)); \
+    if ((_lock_ptr)->mt_type == UCP_MT_TYPE_WORKER_ASYNC) { \
+        UCS_ASYNC_BLOCK((_lock_ptr)->lock.mt_worker_async); \
+    } else { \
+        ucp_mt_lock_lock(_lock_ptr); \
     } \
-} while (0)
-#define UCP_THREAD_CS_EXIT(_lock_ptr) \
+} while (0)
+
+#define UCP_THREAD_CS_ASYNC_EXIT(_lock_ptr) \
 do { \
-    if ((_lock_ptr)->mt_type == UCP_MT_TYPE_SPINLOCK) { \
-        ucs_recursive_spin_unlock(&((_lock_ptr)->lock.mt_spinlock)); \
-    } else if ((_lock_ptr)->mt_type == UCP_MT_TYPE_MUTEX) { \
-        pthread_mutex_unlock(&((_lock_ptr)->lock.mt_mutex)); \
+    if ((_lock_ptr)->mt_type == UCP_MT_TYPE_WORKER_ASYNC) { \
+        UCS_ASYNC_UNBLOCK((_lock_ptr)->lock.mt_worker_async); \
+    } else { \
+        ucp_mt_lock_unlock(_lock_ptr); \
     } \
-} while (0)
+} while (0)
 
 #endif
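
Taken together, UCP_THREAD_CS_ENTER/EXIT still cover the context-owned spinlock and mutex, while the new UCP_THREAD_CS_ASYNC_ENTER/EXIT additionally handle UCP_MT_TYPE_WORKER_ASYNC by blocking the worker's async context (UCS_ASYNC_BLOCK/UNBLOCK) and otherwise fall back to the plain lock. A standalone model of this dispatch pattern, with assumed names and a plain pthread mutex standing in for the worker's async context (not UCX code):

    #include <pthread.h>
    #include <stdio.h>

    typedef enum {
        MODEL_MT_NONE = 0,
        MODEL_MT_MUTEX,
        MODEL_MT_WORKER_ASYNC
    } model_mt_type_t;

    typedef struct {
        model_mt_type_t type;
        union {
            pthread_mutex_t  mutex;        /* owned by the "context" */
            pthread_mutex_t *worker_async; /* borrowed from the "worker" */
        } lock;
    } model_mt_lock_t;

    static void model_cs_async_enter(model_mt_lock_t *l)
    {
        if (l->type == MODEL_MT_WORKER_ASYNC) {
            pthread_mutex_lock(l->lock.worker_async);  /* ~ UCS_ASYNC_BLOCK */
        } else if (l->type == MODEL_MT_MUTEX) {
            pthread_mutex_lock(&l->lock.mutex);        /* context-owned lock */
        }                                              /* MODEL_MT_NONE: no-op */
    }

    static void model_cs_async_exit(model_mt_lock_t *l)
    {
        if (l->type == MODEL_MT_WORKER_ASYNC) {
            pthread_mutex_unlock(l->lock.worker_async); /* ~ UCS_ASYNC_UNBLOCK */
        } else if (l->type == MODEL_MT_MUTEX) {
            pthread_mutex_unlock(&l->lock.mutex);
        }
    }

    int main(void)
    {
        pthread_mutex_t worker_async = PTHREAD_MUTEX_INITIALIZER;
        model_mt_lock_t ctx_lock;

        /* What ucp_context_set_worker_async() does conceptually: retag the
         * context lock so critical sections serialize against the worker. */
        ctx_lock.type              = MODEL_MT_WORKER_ASYNC;
        ctx_lock.lock.worker_async = &worker_async;

        model_cs_async_enter(&ctx_lock);
        printf("critical section serialized against the worker lock\n");
        model_cs_async_exit(&ctx_lock);
        return 0;
    }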

src/ucp/core/ucp_worker.c

Lines changed: 10 additions & 0 deletions
@@ -2569,6 +2569,10 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
         goto err_free_tm_offload_stats;
     }
 
+    if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
+        ucp_context_set_worker_async(context, &worker->async);
+    }
+
     /* Create the underlying UCT worker */
     status = uct_worker_create(&worker->async, uct_thread_mode, &worker->uct);
     if (status != UCS_OK) {
@@ -2668,6 +2672,9 @@
 err_destroy_uct_worker:
     uct_worker_destroy(worker->uct);
 err_destroy_async:
+    if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
+        ucp_context_set_worker_async(context, NULL);
+    }
     ucs_async_context_cleanup(&worker->async);
 err_free_tm_offload_stats:
     UCS_STATS_NODE_FREE(worker->tm_offload_stats);
@@ -2923,6 +2930,9 @@ void ucp_worker_destroy(ucp_worker_h worker)
     ucs_conn_match_cleanup(&worker->conn_match_ctx);
     ucp_worker_wakeup_cleanup(worker);
     uct_worker_destroy(worker->uct);
+    if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
+        ucp_context_set_worker_async(worker->context, NULL);
+    }
     ucs_async_context_cleanup(&worker->async);
     UCS_STATS_NODE_FREE(worker->tm_offload_stats);
     UCS_STATS_NODE_FREE(worker->stats);
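
Note that the context lock is only retagged when the worker actually carries UCP_WORKER_FLAG_THREAD_MULTI, which depends on the thread mode requested at worker creation and on UCX being built with MT support. If needed, an application can verify the granted mode through the public API; a small illustrative snippet (not part of this commit):

    ucp_worker_attr_t attr = {0};

    attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
    ucp_worker_query(worker, &attr);
    if (attr.thread_mode != UCS_THREAD_MODE_MULTI) {
        /* Worker is not mt-shared; the context keeps its previous lock type. */
    }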

test/gtest/ucp/test_ucp_rma_mt.cc

Lines changed: 9 additions & 1 deletion
@@ -218,6 +218,11 @@ UCS_TEST_P(test_ucp_rma_mt, rkey_pack) {
 #if _OPENMP && ENABLE_MT
 #pragma omp parallel for
     for (int i = 0; i < mt_num_threads(); i++) {
+        int worker_index = 0;
+        if (get_variant_thread_type() == MULTI_THREAD_CONTEXT) {
+            worker_index = i;
+        }
+
         if (i % 2 == 0) {
             void *rkey;
             size_t rkey_size;
@@ -226,13 +231,16 @@ UCS_TEST_P(test_ucp_rma_mt, rkey_pack) {
         } else {
             ucs_sys_dev_distance_t sys_dev = {};
             ucp_request req                = {};
-            req.send.ep = sender().ep();
+            req.send.ep = sender().ep(worker_index);
             req.send.state.dt_iter.type.contig.memh   = memh;
             req.send.state.dt_iter.type.contig.buffer = data;
             req.send.state.dt_iter.length             = sizeof(data);
 
             uint8_t rkey[1024];
+            ucp_worker_h worker = sender().worker(worker_index);
+            UCP_WORKER_THREAD_CS_ENTER_CONDITIONAL(worker);
             ucp_proto_request_pack_rkey(&req, memh->md_map, 0, &sys_dev, rkey);
+            UCP_WORKER_THREAD_CS_EXIT_CONDITIONAL(worker);
         }
     }
 #endif
