#pragma once

#include <hip/hip_runtime.h>

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicAdd(
    T1 *address, T2 val, int memory_order = 0) {
  atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicAdd(
    T1 &address, T2 val, int memory_order = 0) {
  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
template <typename T1, typename T2>
__forceinline__ __device__ T1 AtomicAddRet(
    T1 &ref, T2 val, int memory_order = 0) {
  return atomicAdd(&ref, static_cast<T1>(val));
}
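
// Usage sketch (illustrative only, compile-time disabled): a hypothetical
// kernel exercising each AtomicAdd entry point above. The trailing
// 'memory_order' argument is unused and defaulted, so callers may omit it.
#if 0
__global__ void ExampleAtomicAdd(float *sum, const float *vals, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    AtomicAdd(sum, vals[i]);                    // pointer overload
    AtomicAdd(sum[0], vals[i]);                 // reference overload
    float old = AtomicAddRet(sum[0], vals[i]);  // returns the prior value
    (void)old;  // all three calls shown only to demonstrate the overload set;
                // each one adds vals[i] to *sum independently
  }
}
#endif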

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicMax(
    T1 *address, T2 val, int memory_order = 0) {
  atomicMax(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicMax(
    T1 &address, T2 val, int memory_order = 0) {
  atomicMax(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicMin(
    T1 *address, T2 val, int memory_order = 0) {
  atomicMin(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicMin(
    T1 &address, T2 val, int memory_order = 0) {
  atomicMin(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}
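
// Usage sketch (illustrative only, compile-time disabled): a hypothetical
// kernel reducing per-thread values into global max/min cells. Integer
// operands are used here because atomicMax/atomicMin support them natively.
#if 0
__global__ void ExampleMinMax(int *global_max, int *global_min,
                              const int *vals, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    AtomicMax(global_max, vals[i]);   // pointer overload
    AtomicMin(*global_min, vals[i]);  // reference overload
  }
}
#endif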

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
__forceinline__ __device__ void AtomicAddx2(
    float *ref, float *val, int memory_order = 0) {
  float2 add_val = *reinterpret_cast<float2 *>(val);
  atomicAdd(ref + 0, add_val.x);
  atomicAdd(ref + 1, add_val.y);
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
__forceinline__ __device__ float2 AtomicAddx2Ret(
    float *ref, float *val, int memory_order = 0) {
  float2 add_val = *reinterpret_cast<float2 *>(val);
  float2 ret;
  ret.x = atomicAdd(ref + 0, add_val.x);
  ret.y = atomicAdd(ref + 1, add_val.y);
  return ret;
}
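
// Usage sketch (illustrative only, compile-time disabled): 'val' must point
// at two consecutive floats aligned for a float2 load. Note that the two
// lanes are added with separate atomicAdd calls, so the pair is not updated
// as a single atomic unit.
#if 0
__global__ void ExampleAddx2(float *acc /* length >= 2 */) {
  float2 contrib = make_float2(1.0f, 2.0f);
  AtomicAddx2(acc, reinterpret_cast<float *>(&contrib));
}
#endif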

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
__forceinline__ __device__ void AtomicAddx4(
    float *ref, float *val, int memory_order = 0) {
  float4 add_val = *reinterpret_cast<float4 *>(val);
  atomicAdd(ref + 0, add_val.x);
  atomicAdd(ref + 1, add_val.y);
  atomicAdd(ref + 2, add_val.z);
  atomicAdd(ref + 3, add_val.w);
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
__forceinline__ __device__ float4 AtomicAddx4Ret(
    float *ref, float *val, int memory_order = 0) {
  float4 add_val = *reinterpret_cast<float4 *>(val);
  float4 ret;
  ret.x = atomicAdd(ref + 0, add_val.x);
  ret.y = atomicAdd(ref + 1, add_val.y);
  ret.z = atomicAdd(ref + 2, add_val.z);
  ret.w = atomicAdd(ref + 3, add_val.w);
  return ret;
}
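
// Usage sketch (illustrative only, compile-time disabled): AtomicAddx4Ret
// adds four packed floats and returns the previous contents element-wise.
// 'val' must point at four floats aligned for a float4 load.
#if 0
__global__ void ExampleAddx4(float *acc /* length >= 4 */, float *prev_out) {
  float4 contrib = make_float4(1.0f, 2.0f, 3.0f, 4.0f);
  float4 prev = AtomicAddx4Ret(acc, reinterpret_cast<float *>(&contrib));
  prev_out[0] = prev.x;  // value of acc[0] before this thread's add
}
#endif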