Skip to content

Commit c02e336

Browse files
authored
Merge pull request #950 from CEED/natalie/sync-backend
Use backend functions for SyncArray in CUDA and HIP
2 parents c97ca96 + f48ed27 commit c02e336

File tree

3 files changed

+79
-75
lines changed

3 files changed

+79
-75
lines changed

backends/cuda-ref/ceed-cuda-vector.c

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,30 @@
1313
#include <string.h>
1414
#include "ceed-cuda-ref.h"
1515

16+
17+
//------------------------------------------------------------------------------
18+
// Check if host/device sync is needed
19+
//------------------------------------------------------------------------------
20+
static inline int CeedVectorNeedSync_Cuda(const CeedVector vec,
21+
CeedMemType mem_type, bool *need_sync) {
22+
int ierr;
23+
CeedVector_Cuda *impl;
24+
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
25+
26+
bool has_valid_array = false;
27+
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
28+
switch (mem_type) {
29+
case CEED_MEM_HOST:
30+
*need_sync = has_valid_array && !impl->h_array;
31+
break;
32+
case CEED_MEM_DEVICE:
33+
*need_sync = has_valid_array && !impl->d_array;
34+
break;
35+
}
36+
37+
return CEED_ERROR_SUCCESS;
38+
}
39+
1640
//------------------------------------------------------------------------------
1741
// Sync host to device
1842
//------------------------------------------------------------------------------
@@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
88112
//------------------------------------------------------------------------------
89113
// Sync arrays
90114
//------------------------------------------------------------------------------
91-
static inline int CeedVectorSync_Cuda(const CeedVector vec,
92-
CeedMemType mem_type) {
115+
static int CeedVectorSyncArray_Cuda(const CeedVector vec,
116+
CeedMemType mem_type) {
117+
int ierr;
118+
// Check whether device/host sync is needed
119+
bool need_sync = false;
120+
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync);
121+
CeedChkBackend(ierr);
122+
if (!need_sync)
123+
return CEED_ERROR_SUCCESS;
124+
93125
switch (mem_type) {
94126
case CEED_MEM_HOST: return CeedVectorSyncD2H_Cuda(vec);
95127
case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Cuda(vec);
@@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec,
167199
return CEED_ERROR_SUCCESS;
168200
}
169201

170-
//------------------------------------------------------------------------------
171-
// Check if is any array of given type
172-
//------------------------------------------------------------------------------
173-
static inline int CeedVectorNeedSync_Cuda(const CeedVector vec,
174-
CeedMemType mem_type, bool *need_sync) {
175-
int ierr;
176-
CeedVector_Cuda *impl;
177-
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
178-
179-
bool has_valid_array = false;
180-
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
181-
switch (mem_type) {
182-
case CEED_MEM_HOST:
183-
*need_sync = has_valid_array && !impl->h_array;
184-
break;
185-
case CEED_MEM_DEVICE:
186-
*need_sync = has_valid_array && !impl->d_array;
187-
break;
188-
}
189-
190-
return CEED_ERROR_SUCCESS;
191-
}
192-
193202
//------------------------------------------------------------------------------
194203
// Set array from host
195204
//------------------------------------------------------------------------------
@@ -368,11 +377,7 @@ static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type,
368377
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
369378

370379
// Sync array to requested mem_type
371-
bool need_sync = false;
372-
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr);
373-
if (need_sync) {
374-
ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr);
375-
}
380+
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);
376381

377382
// Update pointer
378383
switch (mem_type) {
@@ -403,14 +408,8 @@ static int CeedVectorGetArrayCore_Cuda(const CeedVector vec,
403408
CeedVector_Cuda *impl;
404409
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
405410

406-
bool need_sync = false, has_array_of_type = true;
407-
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr);
408-
ierr = CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type);
409-
CeedChkBackend(ierr);
410-
if (need_sync) {
411-
// Sync array to requested mem_type
412-
ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr);
413-
}
411+
// Sync array to requested mem_type
412+
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);
414413

415414
// Update pointer
416415
switch (mem_type) {
@@ -763,6 +762,8 @@ int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
763762
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
764763
(int (*)())(CeedVectorSetValue_Cuda));
765764
CeedChkBackend(ierr);
765+
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray",
766+
CeedVectorSyncArray_Cuda); CeedChkBackend(ierr);
766767
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
767768
CeedVectorGetArray_Cuda); CeedChkBackend(ierr);
768769
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",

backends/hip-ref/ceed-hip-ref-vector.c

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,30 @@
1313
#include <string.h>
1414
#include "ceed-hip-ref.h"
1515

16+
17+
//------------------------------------------------------------------------------
18+
// Check if host/device sync is needed
19+
//------------------------------------------------------------------------------
20+
static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
21+
CeedMemType mem_type, bool *need_sync) {
22+
int ierr;
23+
CeedVector_Hip *impl;
24+
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
25+
26+
bool has_valid_array = false;
27+
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
28+
switch (mem_type) {
29+
case CEED_MEM_HOST:
30+
*need_sync = has_valid_array && !impl->h_array;
31+
break;
32+
case CEED_MEM_DEVICE:
33+
*need_sync = has_valid_array && !impl->d_array;
34+
break;
35+
}
36+
37+
return CEED_ERROR_SUCCESS;
38+
}
39+
1640
//------------------------------------------------------------------------------
1741
// Sync host to device
1842
//------------------------------------------------------------------------------
@@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
88112
//------------------------------------------------------------------------------
89113
// Sync arrays
90114
//------------------------------------------------------------------------------
91-
static inline int CeedVectorSync_Hip(const CeedVector vec,
92-
CeedMemType mem_type) {
115+
static int CeedVectorSyncArray_Hip(const CeedVector vec,
116+
CeedMemType mem_type) {
117+
int ierr;
118+
// Check whether device/host sync is needed
119+
bool need_sync = false;
120+
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync);
121+
CeedChkBackend(ierr);
122+
if (!need_sync)
123+
return CEED_ERROR_SUCCESS;
124+
93125
switch (mem_type) {
94126
case CEED_MEM_HOST: return CeedVectorSyncD2H_Hip(vec);
95127
case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Hip(vec);
@@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec,
167199
return CEED_ERROR_SUCCESS;
168200
}
169201

170-
//------------------------------------------------------------------------------
171-
// Sync array of given type
172-
//------------------------------------------------------------------------------
173-
static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
174-
CeedMemType mem_type, bool *need_sync) {
175-
int ierr;
176-
CeedVector_Hip *impl;
177-
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
178-
179-
bool has_valid_array = false;
180-
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
181-
switch (mem_type) {
182-
case CEED_MEM_HOST:
183-
*need_sync = has_valid_array && !impl->h_array;
184-
break;
185-
case CEED_MEM_DEVICE:
186-
*need_sync = has_valid_array && !impl->d_array;
187-
break;
188-
}
189-
190-
return CEED_ERROR_SUCCESS;
191-
}
192-
193202
//------------------------------------------------------------------------------
194203
// Set array from host
195204
//------------------------------------------------------------------------------
@@ -363,11 +372,7 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type,
363372
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
364373

365374
// Sync array to requested mem_type
366-
bool need_sync = false;
367-
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr);
368-
if (need_sync) {
369-
ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr);
370-
}
375+
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);
371376

372377
// Update pointer
373378
switch (mem_type) {
@@ -398,13 +403,8 @@ static int CeedVectorGetArrayCore_Hip(const CeedVector vec,
398403
CeedVector_Hip *impl;
399404
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
400405

401-
bool need_sync = false;
402-
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr);
403-
CeedChkBackend(ierr);
404-
if (need_sync) {
405-
// Sync array to requested mem_type
406-
ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr);
407-
}
406+
// Sync array to requested mem_type
407+
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);
408408

409409
// Update pointer
410410
switch (mem_type) {
@@ -758,6 +758,8 @@ int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
758758
CeedVectorTakeArray_Hip); CeedChkBackend(ierr);
759759
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
760760
(int (*)())(CeedVectorSetValue_Hip)); CeedChkBackend(ierr);
761+
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray",
762+
CeedVectorSyncArray_Hip); CeedChkBackend(ierr);
761763
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
762764
CeedVectorGetArray_Hip); CeedChkBackend(ierr);
763765
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",

interface/ceed.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,7 @@ int CeedInit(const char *resource, Ceed *ceed) {
843843
CEED_FTABLE_ENTRY(CeedVector, SetArray),
844844
CEED_FTABLE_ENTRY(CeedVector, TakeArray),
845845
CEED_FTABLE_ENTRY(CeedVector, SetValue),
846+
CEED_FTABLE_ENTRY(CeedVector, SyncArray),
846847
CEED_FTABLE_ENTRY(CeedVector, GetArray),
847848
CEED_FTABLE_ENTRY(CeedVector, GetArrayRead),
848849
CEED_FTABLE_ENTRY(CeedVector, GetArrayWrite),

0 commit comments

Comments
 (0)