|
13 | 13 | #include <string.h>
|
14 | 14 | #include "ceed-cuda-ref.h"
|
15 | 15 |
|
| 16 | + |
| 17 | +//------------------------------------------------------------------------------ |
| 18 | +// Check if host/device sync is needed |
| 19 | +//------------------------------------------------------------------------------ |
| 20 | +static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, |
| 21 | + CeedMemType mem_type, bool *need_sync) { |
| 22 | + int ierr; |
| 23 | + CeedVector_Cuda *impl; |
| 24 | + ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); |
| 25 | + |
| 26 | + bool has_valid_array = false; |
| 27 | + ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); |
| 28 | + switch (mem_type) { |
| 29 | + case CEED_MEM_HOST: |
| 30 | + *need_sync = has_valid_array && !impl->h_array; |
| 31 | + break; |
| 32 | + case CEED_MEM_DEVICE: |
| 33 | + *need_sync = has_valid_array && !impl->d_array; |
| 34 | + break; |
| 35 | + } |
| 36 | + |
| 37 | + return CEED_ERROR_SUCCESS; |
| 38 | +} |
| 39 | + |
16 | 40 | //------------------------------------------------------------------------------
|
17 | 41 | // Sync host to device
|
18 | 42 | //------------------------------------------------------------------------------
|
@@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
|
88 | 112 | //------------------------------------------------------------------------------
|
89 | 113 | // Sync arrays
|
90 | 114 | //------------------------------------------------------------------------------
|
91 |
| -static inline int CeedVectorSync_Cuda(const CeedVector vec, |
92 |
| - CeedMemType mem_type) { |
| 115 | +static int CeedVectorSyncArray_Cuda(const CeedVector vec, |
| 116 | + CeedMemType mem_type) { |
| 117 | + int ierr; |
| 118 | + // Check whether device/host sync is needed |
| 119 | + bool need_sync = false; |
| 120 | + ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); |
| 121 | + CeedChkBackend(ierr); |
| 122 | + if (!need_sync) |
| 123 | + return CEED_ERROR_SUCCESS; |
| 124 | + |
93 | 125 | switch (mem_type) {
|
94 | 126 | case CEED_MEM_HOST: return CeedVectorSyncD2H_Cuda(vec);
|
95 | 127 | case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Cuda(vec);
|
@@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec,
|
167 | 199 | return CEED_ERROR_SUCCESS;
|
168 | 200 | }
|
169 | 201 |
|
170 |
| -//------------------------------------------------------------------------------ |
171 |
| -// Check if is any array of given type |
172 |
| -//------------------------------------------------------------------------------ |
173 |
| -static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, |
174 |
| - CeedMemType mem_type, bool *need_sync) { |
175 |
| - int ierr; |
176 |
| - CeedVector_Cuda *impl; |
177 |
| - ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); |
178 |
| - |
179 |
| - bool has_valid_array = false; |
180 |
| - ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); |
181 |
| - switch (mem_type) { |
182 |
| - case CEED_MEM_HOST: |
183 |
| - *need_sync = has_valid_array && !impl->h_array; |
184 |
| - break; |
185 |
| - case CEED_MEM_DEVICE: |
186 |
| - *need_sync = has_valid_array && !impl->d_array; |
187 |
| - break; |
188 |
| - } |
189 |
| - |
190 |
| - return CEED_ERROR_SUCCESS; |
191 |
| -} |
192 |
| - |
193 | 202 | //------------------------------------------------------------------------------
|
194 | 203 | // Set array from host
|
195 | 204 | //------------------------------------------------------------------------------
|
@@ -368,11 +377,7 @@ static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type,
|
368 | 377 | ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
|
369 | 378 |
|
370 | 379 | // Sync array to requested mem_type
|
371 |
| - bool need_sync = false; |
372 |
| - ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr); |
373 |
| - if (need_sync) { |
374 |
| - ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr); |
375 |
| - } |
| 380 | + ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); |
376 | 381 |
|
377 | 382 | // Update pointer
|
378 | 383 | switch (mem_type) {
|
@@ -403,14 +408,8 @@ static int CeedVectorGetArrayCore_Cuda(const CeedVector vec,
|
403 | 408 | CeedVector_Cuda *impl;
|
404 | 409 | ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
|
405 | 410 |
|
406 |
| - bool need_sync = false, has_array_of_type = true; |
407 |
| - ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr); |
408 |
| - ierr = CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type); |
409 |
| - CeedChkBackend(ierr); |
410 |
| - if (need_sync) { |
411 |
| - // Sync array to requested mem_type |
412 |
| - ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr); |
413 |
| - } |
| 411 | + // Sync array to requested mem_type |
| 412 | + ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); |
414 | 413 |
|
415 | 414 | // Update pointer
|
416 | 415 | switch (mem_type) {
|
@@ -763,6 +762,8 @@ int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
|
763 | 762 | ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
|
764 | 763 | (int (*)())(CeedVectorSetValue_Cuda));
|
765 | 764 | CeedChkBackend(ierr);
|
| 765 | + ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", |
| 766 | + CeedVectorSyncArray_Cuda); CeedChkBackend(ierr); |
766 | 767 | ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
|
767 | 768 | CeedVectorGetArray_Cuda); CeedChkBackend(ierr);
|
768 | 769 | ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
|
|
0 commit comments