Skip to content

Commit fa4876f

Browse files
committed
wip
1 parent 1e137bb commit fa4876f

File tree

9 files changed

+12
-8
lines changed

9 files changed

+12
-8
lines changed

unified-runtime/source/adapters/cuda/command_buffer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
6262
bool IsInOrder)
6363
: handle_base(), Context(Context), Device(Device), IsUpdatable(IsUpdatable),
6464
IsInOrder(IsInOrder), CudaGraph{nullptr}, CudaGraphExec{nullptr},
65-
RefCount{1}, NextSyncPoint{0} {
65+
NextSyncPoint{0} {
6666
urContextRetain(Context);
6767
}
6868

unified-runtime/source/adapters/cuda/command_buffer.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base {
186186
CUgraph CudaGraph;
187187
// Cuda Graph Exec handle
188188
CUgraphExec CudaGraphExec = nullptr;
189-
189+
// Atomic variable counting the number of references to this command_buffer
190+
// using std::atomic prevents data races when incrementing/decrementing.
190191
ur::RefCount RefCount;
191192

192193
// Ordered map of sync_points to ur_events, so that we can find the last

unified-runtime/source/adapters/cuda/context.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
8585
UR_APIEXPORT ur_result_t UR_APICALL
8686
urContextRelease(ur_context_handle_t hContext) {
8787
if (hContext->RefCount.release()) {
88-
return UR_RESULT_SUCCESS;
88+
hContext->invokeExtendedDeleters();
89+
delete hContext;
8990
}
90-
hContext->invokeExtendedDeleters();
91-
delete hContext;
9291

9392
return UR_RESULT_SUCCESS;
9493
}

unified-runtime/source/adapters/cuda/context.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base {
9797
umf_memory_pool_handle_t MemoryPoolHost = nullptr;
9898

9999
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
100-
: handle_base(), Devices{Devs, Devs + NumDevices}, RefCount(0) {
100+
: handle_base(), Devices{Devs, Devs + NumDevices} {
101101
// Create UMF CUDA memory provider for the host memory
102102
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
103103
// it is guaranteed to exist).

unified-runtime/source/adapters/cuda/event.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
265265
}
266266

267267
// decrement ref count. If it is 0, delete the event.
268-
if (hEvent->release()) {
268+
if (hEvent->RefCount.release()) {
269269
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
270270
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
271271
try {

unified-runtime/source/adapters/cuda/event.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "common.hpp"
1616
#include "common/ur_ref_count.hpp"
1717
#include "queue.hpp"
18+
#include "common/ur_ref_count.hpp"
1819

1920
/// UR Event mapping to CUevent
2021
struct ur_event_handle_t_ : ur::cuda::handle_base {

unified-runtime/source/adapters/cuda/memory.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ struct ur_mem_handle_t_ : ur::cuda::handle_base {
315315
// Context where the memory object is accessible
316316
ur_context_handle_t Context;
317317

318+
/// Reference count of this handle
318319
ur::RefCount RefCount;
319320

320321
// Original mem flags passed

unified-runtime/source/adapters/cuda/queue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
107107
}
108108

109109
UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) {
110-
assert(hQueue->RefCount.getCount() > 0);
110+
assert(hQueue->RefCount.getCount());
111111

112112
hQueue->RefCount.retain();
113113
return UR_RESULT_SUCCESS;

unified-runtime/source/common/cuda-hip/stream_queue.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include <mutex>
1717
#include <vector>
1818

19+
#include "common/ur_ref_count.hpp"
20+
1921
using ur_stream_guard = std::unique_lock<std::mutex>;
2022

2123
/// Generic implementation of an out-of-order UR queue based on in-order

0 commit comments

Comments
 (0)