@@ -52,7 +52,7 @@ ucc_memory_type_t ucc_mem_type_from_str(const char *str)
52
52
return UCC_MEMORY_TYPE_LAST ;
53
53
}
54
54
55
- static inline int
55
+ int
56
56
ucc_coll_args_is_mem_symmetric (const ucc_coll_args_t * args ,
57
57
ucc_rank_t rank )
58
58
{
@@ -94,6 +94,180 @@ ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args,
94
94
return 0 ;
95
95
}
96
96
97
+
98
+ /* If this is the root and the src/dst buffers are asymmetric, one buffer needs
99
+ to have a new allocation to make the mem types match. If that buffer was the
100
+ dst buffer, copy the result back into the old dst on task completion */
101
+ ucc_status_t
102
+ ucc_coll_args_init_asymmetric_buffer (ucc_coll_args_t * args ,
103
+ ucc_team_h team ,
104
+ ucc_buffer_info_asymmetric_memtype_t * save_info )
105
+ {
106
+ ucc_status_t status = UCC_OK ;
107
+
108
+ if (UCC_IS_INPLACE (* args )) {
109
+ return UCC_ERR_INVALID_PARAM ;
110
+ }
111
+ switch (args -> coll_type ) {
112
+ case UCC_COLL_TYPE_REDUCE :
113
+ case UCC_COLL_TYPE_GATHER :
114
+ {
115
+ ucc_memory_type_t mem_type = args -> src .info .mem_type ;
116
+ if (args -> coll_type == UCC_COLL_TYPE_SCATTERV ) {
117
+ mem_type = args -> src .info_v .mem_type ;
118
+ }
119
+ memcpy (& save_info -> old_asymmetric_buffer .info ,
120
+ & args -> dst .info , sizeof (ucc_coll_buffer_info_t ));
121
+ status = ucc_mc_alloc (& save_info -> scratch ,
122
+ ucc_dt_size (args -> dst .info .datatype ) *
123
+ args -> dst .info .count ,
124
+ mem_type );
125
+ if (ucc_unlikely (UCC_OK != status )) {
126
+ ucc_error ("failed to allocate replacement "
127
+ "memory for asymmetric buffer" );
128
+ return status ;
129
+ }
130
+ args -> dst .info .buffer = save_info -> scratch -> addr ;
131
+ args -> dst .info .mem_type = mem_type ;
132
+ return UCC_OK ;
133
+ }
134
+ case UCC_COLL_TYPE_GATHERV :
135
+ {
136
+ memcpy (& save_info -> old_asymmetric_buffer .info_v ,
137
+ & args -> dst .info_v , sizeof (ucc_coll_buffer_info_v_t ));
138
+ status = ucc_mc_alloc (& save_info -> scratch ,
139
+ ucc_coll_args_get_v_buffer_size (args ,
140
+ args -> dst .info_v .counts ,
141
+ args -> dst .info_v .displacements ,
142
+ team -> size ),
143
+ args -> src .info .mem_type );
144
+ if (ucc_unlikely (UCC_OK != status )) {
145
+ ucc_error ("failed to allocate replacement "
146
+ "memory for asymmetric buffer" );
147
+ return status ;
148
+ }
149
+ args -> dst .info_v .buffer = save_info -> scratch -> addr ;
150
+ args -> dst .info_v .mem_type = args -> src .info .mem_type ;
151
+ return UCC_OK ;
152
+ }
153
+ case UCC_COLL_TYPE_SCATTER :
154
+ {
155
+ ucc_memory_type_t mem_type = args -> dst .info .mem_type ;
156
+ memcpy (& save_info -> old_asymmetric_buffer .info ,
157
+ & args -> src .info , sizeof (ucc_coll_buffer_info_t ));
158
+ status = ucc_mc_alloc (& save_info -> scratch ,
159
+ ucc_dt_size (args -> src .info .datatype ) * args -> src .info .count ,
160
+ mem_type );
161
+ if (ucc_unlikely (UCC_OK != status )) {
162
+ ucc_error ("failed to allocate replacement "
163
+ "memory for asymmetric buffer" );
164
+ return status ;
165
+ }
166
+ args -> src .info .buffer = save_info -> scratch -> addr ;
167
+ args -> src .info .mem_type = mem_type ;
168
+ return UCC_OK ;
169
+ }
170
+ case UCC_COLL_TYPE_SCATTERV :
171
+ {
172
+ ucc_memory_type_t mem_type = args -> dst .info .mem_type ;
173
+ memcpy (& save_info -> old_asymmetric_buffer .info_v ,
174
+ & args -> src .info_v , sizeof (ucc_coll_buffer_info_v_t ));
175
+ status = ucc_mc_alloc (& save_info -> scratch ,
176
+ ucc_coll_args_get_v_buffer_size (args ,
177
+ args -> src .info_v .counts ,
178
+ args -> src .info_v .displacements ,
179
+ team -> size ),
180
+ mem_type );
181
+ if (ucc_unlikely (UCC_OK != status )) {
182
+ ucc_error ("failed to allocate replacement "
183
+ "memory for asymmetric buffer" );
184
+ return status ;
185
+ }
186
+ args -> src .info_v .buffer = save_info -> scratch -> addr ;
187
+ args -> src .info_v .mem_type = mem_type ;
188
+ return UCC_OK ;
189
+ }
190
+ default :
191
+ break ;
192
+ }
193
+ return UCC_ERR_INVALID_PARAM ;
194
+ }
195
+
196
+ ucc_status_t
197
+ ucc_coll_args_free_asymmetric_buffer (ucc_coll_task_t * task )
198
+ {
199
+ ucc_status_t status = UCC_OK ;
200
+ ucc_buffer_info_asymmetric_memtype_t * save = & task -> bargs .asymmetric_save_info ;
201
+
202
+ if (UCC_IS_INPLACE (task -> bargs .args )) {
203
+ return UCC_ERR_INVALID_PARAM ;
204
+ }
205
+
206
+ if (save -> scratch == NULL ) {
207
+ ucc_error ("failure trying to free NULL asymmetric buffer" );
208
+ }
209
+
210
+ status = ucc_mc_free (save -> scratch );
211
+ if (ucc_unlikely (status != UCC_OK )) {
212
+ ucc_error ("error freeing scratch asymmetric buffer: %s" ,
213
+ ucc_status_string (status ));
214
+ }
215
+ save -> scratch = NULL ;
216
+
217
+ return status ;
218
+ }
219
+
220
+ ucc_status_t ucc_copy_asymmetric_buffer (ucc_coll_task_t * task )
221
+ {
222
+ ucc_status_t status = UCC_OK ;
223
+ ucc_coll_args_t * coll_args = & task -> bargs .args ;
224
+ ucc_buffer_info_asymmetric_memtype_t * save = & task -> bargs .asymmetric_save_info ;
225
+ ucc_rank_t size = task -> team -> params .size ;
226
+
227
+ if (task -> bargs .args .coll_type == UCC_COLL_TYPE_SCATTERV ) {
228
+ // copy in
229
+ status = ucc_mc_memcpy (save -> scratch -> addr ,
230
+ save -> old_asymmetric_buffer .info_v .buffer ,
231
+ ucc_coll_args_get_v_buffer_size (coll_args ,
232
+ save -> old_asymmetric_buffer .info_v .counts ,
233
+ save -> old_asymmetric_buffer .info_v .displacements ,
234
+ size ),
235
+ save -> scratch -> mt ,
236
+ save -> old_asymmetric_buffer .info_v .mem_type );
237
+ } else if (task -> bargs .args .coll_type == UCC_COLL_TYPE_SCATTER ) {
238
+ // copy in
239
+ status = ucc_mc_memcpy (save -> scratch -> addr ,
240
+ save -> old_asymmetric_buffer .info .buffer ,
241
+ ucc_dt_size (save -> old_asymmetric_buffer .info .datatype ) *
242
+ save -> old_asymmetric_buffer .info .count ,
243
+ save -> scratch -> mt ,
244
+ save -> old_asymmetric_buffer .info .mem_type );
245
+ } else if (task -> bargs .args .coll_type == UCC_COLL_TYPE_GATHERV ) {
246
+ // copy out
247
+ status = ucc_mc_memcpy (save -> old_asymmetric_buffer .info_v .buffer ,
248
+ save -> scratch -> addr ,
249
+ ucc_coll_args_get_v_buffer_size (coll_args ,
250
+ save -> old_asymmetric_buffer .info_v .counts ,
251
+ save -> old_asymmetric_buffer .info_v .displacements ,
252
+ size ),
253
+ save -> old_asymmetric_buffer .info_v .mem_type ,
254
+ save -> scratch -> mt );
255
+ } else {
256
+ // copy out
257
+ status = ucc_mc_memcpy (save -> old_asymmetric_buffer .info .buffer ,
258
+ save -> scratch -> addr ,
259
+ ucc_dt_size (save -> old_asymmetric_buffer .info .datatype ) *
260
+ save -> old_asymmetric_buffer .info .count ,
261
+ save -> old_asymmetric_buffer .info .mem_type ,
262
+ save -> scratch -> mt );
263
+ }
264
+ if (ucc_unlikely (status != UCC_OK )) {
265
+ ucc_error ("error copying back to old asymmetric buffer: %s" ,
266
+ ucc_status_string (status ));
267
+ }
268
+ return status ;
269
+ }
270
+
97
271
int ucc_coll_args_is_predefined_dt (const ucc_coll_args_t * args , ucc_rank_t rank )
98
272
{
99
273
switch (args -> coll_type ) {
@@ -163,9 +337,6 @@ ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
163
337
{
164
338
ucc_rank_t root = args -> root ;
165
339
166
- if (!ucc_coll_args_is_mem_symmetric (args , rank )) {
167
- return UCC_MEMORY_TYPE_ASYMMETRIC ;
168
- }
169
340
switch (args -> coll_type ) {
170
341
case UCC_COLL_TYPE_BARRIER :
171
342
case UCC_COLL_TYPE_FANIN :
@@ -180,7 +351,6 @@ ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
180
351
return args -> dst .info .mem_type ;
181
352
case UCC_COLL_TYPE_ALLGATHERV :
182
353
case UCC_COLL_TYPE_REDUCE_SCATTERV :
183
- return args -> dst .info_v .mem_type ;
184
354
case UCC_COLL_TYPE_ALLTOALLV :
185
355
return args -> dst .info_v .mem_type ;
186
356
case UCC_COLL_TYPE_REDUCE :
@@ -323,7 +493,7 @@ ucc_ep_map_t ucc_ep_map_from_array_64(uint64_t **array, ucc_rank_t size,
323
493
need_free , 1 );
324
494
}
325
495
326
- static inline int ucc_coll_args_is_rooted (ucc_coll_type_t ct )
496
+ int ucc_coll_args_is_rooted (ucc_coll_type_t ct )
327
497
{
328
498
if (ct == UCC_COLL_TYPE_REDUCE || ct == UCC_COLL_TYPE_BCAST ||
329
499
ct == UCC_COLL_TYPE_GATHER || ct == UCC_COLL_TYPE_SCATTER ||
0 commit comments