Skip to content

Commit 7f85fba

Browse files
nsarka and Nicholas Sarkauskas authored
Asymmetric memory (openucx#1000)
* CORE: Implement weak asymmetric mem with gtests
* CORE: Fix asymmetric bug

---------

Co-authored-by: Nicholas Sarkauskas <[email protected]>
1 parent 2ada313 commit 7f85fba

File tree

8 files changed

+761
-20
lines changed

8 files changed

+761
-20
lines changed

src/coll_score/ucc_coll_score_map.c

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,12 @@ static ucc_status_t ucc_coll_score_map_lookup(ucc_score_map_t *map,
8787
ucc_list_link_t *list;
8888
ucc_msg_range_t *r;
8989

90-
if (mt == UCC_MEMORY_TYPE_ASYMMETRIC) {
91-
/* TODO */
92-
ucc_debug("asymmetric memory type is not supported");
93-
return UCC_ERR_NOT_SUPPORTED;
94-
} else if (mt == UCC_MEMORY_TYPE_NOT_APPLY) {
90+
if (mt == UCC_MEMORY_TYPE_NOT_APPLY) {
9591
/* Temporary solution: for Barrier, Fanin, Fanout - use
9692
"host" range list */
9793
mt = UCC_MEMORY_TYPE_HOST;
9894
}
95+
ucc_assert(ucc_coll_args_is_mem_symmetric(&bargs->args, map->team_rank));
9996
if (msgsize == UCC_MSG_SIZE_INVALID || msgsize == UCC_MSG_SIZE_ASYMMETRIC) {
10097
/* These algorithms require global communication to get the same msgsize estimation.
10198
Can't use msg ranges. Use msize 0 (assuming the range list should only contain 1

src/components/base/ucc_base_iface.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,20 @@ enum {
159159
UCC_BASE_CARGS_MAX_FRAG_COUNT = UCC_BIT(0)
160160
};
161161

162+
/* Bookkeeping for a rooted collective whose src/dst memory types differ.
   ucc_coll_args_init_asymmetric_buffer() stores the caller's original buffer
   descriptor here and swaps a freshly allocated scratch buffer into the
   collective arguments in its place. */
typedef struct ucc_buffer_info_asymmetric_memtype {
163+
union {
164+
ucc_coll_buffer_info_t info; /* saved descriptor for REDUCE/GATHER (dst) and SCATTER (src) */
165+
ucc_coll_buffer_info_v_t info_v; /* saved descriptor for GATHERV (dst) and SCATTERV (src) */
166+
} old_asymmetric_buffer; /* the buffer the user originally passed in */
167+
ucc_mc_buffer_header_t *scratch; /* replacement buffer from ucc_mc_alloc(); NULL when no asymmetric handling is active */
168+
} ucc_buffer_info_asymmetric_memtype_t;
169+
162170
typedef struct ucc_base_coll_args {
163-
uint64_t mask;
164-
ucc_coll_args_t args;
165-
ucc_team_t *team;
166-
size_t max_frag_count;
171+
uint64_t mask;
172+
ucc_coll_args_t args;
173+
ucc_team_t *team;
174+
size_t max_frag_count;
175+
ucc_buffer_info_asymmetric_memtype_t asymmetric_save_info;
167176
} ucc_base_coll_args_t;
168177

169178
typedef ucc_status_t (*ucc_base_coll_init_fn_t)(ucc_base_coll_args_t *coll_args,

src/core/ucc_coll.c

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,19 +230,31 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
230230
UCC_COPY_PARAM_BY_FIELD(&op_args.args, coll_args, UCC_COLL_ARGS_FIELD_FLAGS,
231231
flags);
232232

233+
if (!ucc_coll_args_is_mem_symmetric(&op_args.args, team->rank) &&
234+
ucc_coll_args_is_rooted(op_args.args.coll_type)) {
235+
status = ucc_coll_args_init_asymmetric_buffer(&op_args.args, team,
236+
&op_args.asymmetric_save_info);
237+
if (ucc_unlikely(status != UCC_OK)) {
238+
ucc_error("handling asymmetric memory failed");
239+
return status;
240+
}
241+
} else {
242+
op_args.asymmetric_save_info.scratch = NULL;
243+
}
244+
233245
status = ucc_coll_init(team->score_map, &op_args, &task);
234246
if (UCC_ERR_NOT_SUPPORTED == status) {
235247
ucc_debug("failed to init collective: not supported");
236-
return status;
248+
goto free_scratch;
237249
} else if (ucc_unlikely(status < 0)) {
238250
ucc_error("failed to init collective: %s", ucc_status_string(status));
239-
return status;
251+
goto free_scratch;
240252
}
241253

242254
task->flags |= UCC_COLL_TASK_FLAG_TOP_LEVEL;
243255
if (task->flags & UCC_COLL_TASK_FLAG_EXECUTOR) {
244256
task->flags |= UCC_COLL_TASK_FLAG_EXECUTOR_STOP;
245-
coll_mem_type = ucc_coll_args_mem_type(coll_args, team->rank);
257+
coll_mem_type = ucc_coll_args_mem_type(&op_args.args, team->rank);
246258
switch(coll_mem_type) {
247259
case UCC_MEMORY_TYPE_CUDA:
248260
case UCC_MEMORY_TYPE_CUDA_MANAGED:
@@ -251,7 +263,7 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
251263
case UCC_MEMORY_TYPE_ROCM:
252264
coll_ee_type = UCC_EE_ROCM_STREAM;
253265
break;
254-
case UCC_MEMORY_TYPE_HOST:
266+
case UCC_MEMORY_TYPE_HOST:
255267
coll_ee_type = UCC_EE_CPU_THREAD;
256268
break;
257269
default:
@@ -299,6 +311,10 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
299311

300312
coll_finalize:
301313
task->finalize(task);
314+
free_scratch:
315+
if (op_args.asymmetric_save_info.scratch != NULL) {
316+
ucc_mc_free(op_args.asymmetric_save_info.scratch);
317+
}
302318
return status;
303319
}
304320

@@ -341,6 +357,17 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_post, (request),
341357
}
342358
}
343359

360+
if (task->bargs.asymmetric_save_info.scratch != NULL &&
361+
(task->bargs.args.coll_type == UCC_COLL_TYPE_SCATTER ||
362+
task->bargs.args.coll_type == UCC_COLL_TYPE_SCATTERV)) {
363+
status = ucc_copy_asymmetric_buffer(task);
364+
if (status != UCC_OK) {
365+
ucc_error("failure copying in asymmetric buffer: %s",
366+
ucc_status_string(status));
367+
return status;
368+
}
369+
}
370+
344371
COLL_POST_STATUS_CHECK(task);
345372
if (UCC_COLL_TIMEOUT_REQUIRED(task)) {
346373
task->start_time = ucc_get_time();
@@ -402,6 +429,13 @@ ucc_status_t ucc_collective_finalize_internal(ucc_coll_task_t *task)
402429
return UCC_ERR_INVALID_PARAM;
403430
}
404431

432+
if (task->bargs.asymmetric_save_info.scratch) {
433+
st = ucc_coll_args_free_asymmetric_buffer(task);
434+
if (ucc_unlikely(st != UCC_OK)) {
435+
ucc_error("error freeing asymmetric buf: %s", ucc_status_string(st));
436+
}
437+
}
438+
405439
if (task->executor) {
406440
st = ucc_ee_executor_finalize(task->executor);
407441
if (ucc_unlikely(st != UCC_OK)) {

src/schedule/ucc_schedule.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "utils/ucc_coll_utils.h"
1515
#include "components/base/ucc_base_iface.h"
1616
#include "components/ec/ucc_ec.h"
17+
#include "components/mc/ucc_mc.h"
1718

1819
#define MAX_LISTENERS 4
1920

@@ -185,6 +186,16 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
185186
with schedules are not released during a callback (if set). */
186187

187188
if (ucc_likely(status == UCC_OK)) {
189+
ucc_buffer_info_asymmetric_memtype_t *save = &task->bargs.asymmetric_save_info;
190+
if (save->scratch &&
191+
task->bargs.args.coll_type != UCC_COLL_TYPE_SCATTERV &&
192+
task->bargs.args.coll_type != UCC_COLL_TYPE_SCATTER) {
193+
status = ucc_copy_asymmetric_buffer(task);
194+
if (status != UCC_OK) {
195+
ucc_error("failure copying out asymmetric buffer: %s",
196+
ucc_status_string(status));
197+
}
198+
}
188199
status = ucc_event_manager_notify(task, UCC_EVENT_COMPLETED);
189200
} else {
190201
/* error in task status */

src/utils/ucc_coll_utils.c

Lines changed: 176 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ ucc_memory_type_t ucc_mem_type_from_str(const char *str)
5252
return UCC_MEMORY_TYPE_LAST;
5353
}
5454

55-
static inline int
55+
int
5656
ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args,
5757
ucc_rank_t rank)
5858
{
@@ -94,6 +94,180 @@ ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args,
9494
return 0;
9595
}
9696

97+
98+
/* If this is the root and the src/dst buffers are asymmetric, one buffer needs
99+
to have a new allocation to make the mem types match. If that buffer was the
100+
dst buffer, copy the result back into the old dst on task completion */
101+
ucc_status_t
102+
ucc_coll_args_init_asymmetric_buffer(ucc_coll_args_t *args,
103+
ucc_team_h team,
104+
ucc_buffer_info_asymmetric_memtype_t *save_info)
105+
{
106+
ucc_status_t status = UCC_OK;
107+
108+
if (UCC_IS_INPLACE(*args)) {
109+
return UCC_ERR_INVALID_PARAM;
110+
}
111+
switch (args->coll_type) {
112+
case UCC_COLL_TYPE_REDUCE:
113+
case UCC_COLL_TYPE_GATHER:
114+
{
115+
ucc_memory_type_t mem_type = args->src.info.mem_type;
116+
if (args->coll_type == UCC_COLL_TYPE_SCATTERV) {
117+
mem_type = args->src.info_v.mem_type;
118+
}
119+
memcpy(&save_info->old_asymmetric_buffer.info,
120+
&args->dst.info, sizeof(ucc_coll_buffer_info_t));
121+
status = ucc_mc_alloc(&save_info->scratch,
122+
ucc_dt_size(args->dst.info.datatype) *
123+
args->dst.info.count,
124+
mem_type);
125+
if (ucc_unlikely(UCC_OK != status)) {
126+
ucc_error("failed to allocate replacement "
127+
"memory for asymmetric buffer");
128+
return status;
129+
}
130+
args->dst.info.buffer = save_info->scratch->addr;
131+
args->dst.info.mem_type = mem_type;
132+
return UCC_OK;
133+
}
134+
case UCC_COLL_TYPE_GATHERV:
135+
{
136+
memcpy(&save_info->old_asymmetric_buffer.info_v,
137+
&args->dst.info_v, sizeof(ucc_coll_buffer_info_v_t));
138+
status = ucc_mc_alloc(&save_info->scratch,
139+
ucc_coll_args_get_v_buffer_size(args,
140+
args->dst.info_v.counts,
141+
args->dst.info_v.displacements,
142+
team->size),
143+
args->src.info.mem_type);
144+
if (ucc_unlikely(UCC_OK != status)) {
145+
ucc_error("failed to allocate replacement "
146+
"memory for asymmetric buffer");
147+
return status;
148+
}
149+
args->dst.info_v.buffer = save_info->scratch->addr;
150+
args->dst.info_v.mem_type = args->src.info.mem_type;
151+
return UCC_OK;
152+
}
153+
case UCC_COLL_TYPE_SCATTER:
154+
{
155+
ucc_memory_type_t mem_type = args->dst.info.mem_type;
156+
memcpy(&save_info->old_asymmetric_buffer.info,
157+
&args->src.info, sizeof(ucc_coll_buffer_info_t));
158+
status = ucc_mc_alloc(&save_info->scratch,
159+
ucc_dt_size(args->src.info.datatype) * args->src.info.count,
160+
mem_type);
161+
if (ucc_unlikely(UCC_OK != status)) {
162+
ucc_error("failed to allocate replacement "
163+
"memory for asymmetric buffer");
164+
return status;
165+
}
166+
args->src.info.buffer = save_info->scratch->addr;
167+
args->src.info.mem_type = mem_type;
168+
return UCC_OK;
169+
}
170+
case UCC_COLL_TYPE_SCATTERV:
171+
{
172+
ucc_memory_type_t mem_type = args->dst.info.mem_type;
173+
memcpy(&save_info->old_asymmetric_buffer.info_v,
174+
&args->src.info_v, sizeof(ucc_coll_buffer_info_v_t));
175+
status = ucc_mc_alloc(&save_info->scratch,
176+
ucc_coll_args_get_v_buffer_size(args,
177+
args->src.info_v.counts,
178+
args->src.info_v.displacements,
179+
team->size),
180+
mem_type);
181+
if (ucc_unlikely(UCC_OK != status)) {
182+
ucc_error("failed to allocate replacement "
183+
"memory for asymmetric buffer");
184+
return status;
185+
}
186+
args->src.info_v.buffer = save_info->scratch->addr;
187+
args->src.info_v.mem_type = mem_type;
188+
return UCC_OK;
189+
}
190+
default:
191+
break;
192+
}
193+
return UCC_ERR_INVALID_PARAM;
194+
}
195+
196+
ucc_status_t
197+
ucc_coll_args_free_asymmetric_buffer(ucc_coll_task_t *task)
198+
{
199+
ucc_status_t status = UCC_OK;
200+
ucc_buffer_info_asymmetric_memtype_t *save = &task->bargs.asymmetric_save_info;
201+
202+
if (UCC_IS_INPLACE(task->bargs.args)) {
203+
return UCC_ERR_INVALID_PARAM;
204+
}
205+
206+
if (save->scratch == NULL) {
207+
ucc_error("failure trying to free NULL asymmetric buffer");
208+
}
209+
210+
status = ucc_mc_free(save->scratch);
211+
if (ucc_unlikely(status != UCC_OK)) {
212+
ucc_error("error freeing scratch asymmetric buffer: %s",
213+
ucc_status_string(status));
214+
}
215+
save->scratch = NULL;
216+
217+
return status;
218+
}
219+
220+
/* Transfer data between the caller's original buffer (saved in
   task->bargs.asymmetric_save_info) and the scratch buffer that replaced it.

   SCATTER / SCATTERV: copy in  (original src -> scratch).
   Everything else:    copy out (scratch -> original dst); GATHERV uses the
                       vector descriptor, all other types the plain one.

   Returns the ucc_mc_memcpy() status; failures are also logged. */
ucc_status_t ucc_copy_asymmetric_buffer(ucc_coll_task_t *task)
{
    ucc_buffer_info_asymmetric_memtype_t *saved     = &task->bargs.asymmetric_save_info;
    ucc_coll_args_t                      *args      = &task->bargs.args;
    ucc_rank_t                            team_size = task->team->params.size;
    ucc_status_t                          rc;
    size_t                                nbytes;

    switch (args->coll_type) {
    case UCC_COLL_TYPE_SCATTERV:
        /* copy in */
        nbytes = ucc_coll_args_get_v_buffer_size(
            args, saved->old_asymmetric_buffer.info_v.counts,
            saved->old_asymmetric_buffer.info_v.displacements, team_size);
        rc = ucc_mc_memcpy(saved->scratch->addr,
                           saved->old_asymmetric_buffer.info_v.buffer, nbytes,
                           saved->scratch->mt,
                           saved->old_asymmetric_buffer.info_v.mem_type);
        break;
    case UCC_COLL_TYPE_SCATTER:
        /* copy in */
        nbytes = ucc_dt_size(saved->old_asymmetric_buffer.info.datatype) *
                 saved->old_asymmetric_buffer.info.count;
        rc = ucc_mc_memcpy(saved->scratch->addr,
                           saved->old_asymmetric_buffer.info.buffer, nbytes,
                           saved->scratch->mt,
                           saved->old_asymmetric_buffer.info.mem_type);
        break;
    case UCC_COLL_TYPE_GATHERV:
        /* copy out */
        nbytes = ucc_coll_args_get_v_buffer_size(
            args, saved->old_asymmetric_buffer.info_v.counts,
            saved->old_asymmetric_buffer.info_v.displacements, team_size);
        rc = ucc_mc_memcpy(saved->old_asymmetric_buffer.info_v.buffer,
                           saved->scratch->addr, nbytes,
                           saved->old_asymmetric_buffer.info_v.mem_type,
                           saved->scratch->mt);
        break;
    default:
        /* copy out */
        nbytes = ucc_dt_size(saved->old_asymmetric_buffer.info.datatype) *
                 saved->old_asymmetric_buffer.info.count;
        rc = ucc_mc_memcpy(saved->old_asymmetric_buffer.info.buffer,
                           saved->scratch->addr, nbytes,
                           saved->old_asymmetric_buffer.info.mem_type,
                           saved->scratch->mt);
        break;
    }

    if (ucc_unlikely(rc != UCC_OK)) {
        ucc_error("error copying back to old asymmetric buffer: %s",
                  ucc_status_string(rc));
    }
    return rc;
}
270+
97271
int ucc_coll_args_is_predefined_dt(const ucc_coll_args_t *args, ucc_rank_t rank)
98272
{
99273
switch (args->coll_type) {
@@ -163,9 +337,6 @@ ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
163337
{
164338
ucc_rank_t root = args->root;
165339

166-
if (!ucc_coll_args_is_mem_symmetric(args, rank)) {
167-
return UCC_MEMORY_TYPE_ASYMMETRIC;
168-
}
169340
switch (args->coll_type) {
170341
case UCC_COLL_TYPE_BARRIER:
171342
case UCC_COLL_TYPE_FANIN:
@@ -180,7 +351,6 @@ ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
180351
return args->dst.info.mem_type;
181352
case UCC_COLL_TYPE_ALLGATHERV:
182353
case UCC_COLL_TYPE_REDUCE_SCATTERV:
183-
return args->dst.info_v.mem_type;
184354
case UCC_COLL_TYPE_ALLTOALLV:
185355
return args->dst.info_v.mem_type;
186356
case UCC_COLL_TYPE_REDUCE:
@@ -323,7 +493,7 @@ ucc_ep_map_t ucc_ep_map_from_array_64(uint64_t **array, ucc_rank_t size,
323493
need_free, 1);
324494
}
325495

326-
static inline int ucc_coll_args_is_rooted(ucc_coll_type_t ct)
496+
int ucc_coll_args_is_rooted(ucc_coll_type_t ct)
327497
{
328498
if (ct == UCC_COLL_TYPE_REDUCE || ct == UCC_COLL_TYPE_BCAST ||
329499
ct == UCC_COLL_TYPE_GATHER || ct == UCC_COLL_TYPE_SCATTER ||

0 commit comments

Comments
 (0)