
Commit bbd2c4e

Fix some bugs on ci and open rocm ci test
1 parent 869f021 commit bbd2c4e

31 files changed: +365 -89 lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -383,7 +383,7 @@ jobs:
           pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
         )
         "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
-          ./python/amd
+          ./python

       # Apple Metal tests
       - name: Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})

src/op/logical.cc

Lines changed: 4 additions & 2 deletions
@@ -42,14 +42,16 @@ TVM_REGISTER_OP("tl.any_of")
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure))
     .set_attr<TScriptPrinterName>("TScriptPrinterName", "any_of")
-    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", any_of_op);
+    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", any_of_op)
+    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", any_of_op);

 TVM_REGISTER_OP("tl.all_of")
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure))
     .set_attr<TScriptPrinterName>("TScriptPrinterName", "all_of")
-    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", all_of_op);
+    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", all_of_op)
+    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", all_of_op);

 } // namespace tl
 } // namespace tvm
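The new "hip.FLowerIntrinsic" entries register the same lowering functions for the ROCm backend that were already used for CUDA, presumably so that generated HIP kernels can call the tl::Any / tl::All helpers added to src/tl_templates/hip/common.h in this same commit. A minimal sketch of what a lowered call might look like in generated HIP code; the kernel name and predicate buffer are purely illustrative:

#include "common.h" // provides tl::Any / tl::All (added by this commit)

// Illustrative only: reduce a per-thread predicate buffer of length `size`.
__global__ void any_all_example(bool *pred, int *out, int size) {
  bool any = tl::Any(pred, size); // true if any pred[i] is non-zero
  bool all = tl::All(pred, size); // true only if every pred[i] is non-zero
  if (threadIdx.x == 0) {
    out[0] = any ? 1 : 0;
    out[1] = all ? 1 : 0;
  }
}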

src/tl_templates/hip/atomic.h

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+#pragma once
+
+#include <hip/hip_runtime.h>
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+template <typename T1, typename T2>
+__forceinline__ __device__ void AtomicAdd(
+    T1 *address, T2 val, int memory_order = 0) {
+  atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+// Overload for when the first argument is a value instead of a pointer
+template <typename T1, typename T2>
+__forceinline__ __device__ void AtomicAdd(
+    T1 &address, T2 val, int memory_order = 0) {
+  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+template <typename T1, typename T2>
+__forceinline__ __device__ T1 AtomicAddRet(
+    T1 &ref, T2 val, int memory_order = 0) {
+  return atomicAdd(&ref, static_cast<T1>(val));
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+template <typename T1, typename T2>
+__forceinline__ __device__ void AtomicMax(
+    T1 *address, T2 val, int memory_order = 0) {
+  atomicMax(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+// Overload for when the first argument is a value instead of a pointer
+template <typename T1, typename T2>
+__forceinline__ __device__ void AtomicMax(
+    T1 &address, T2 val, int memory_order = 0) {
+  atomicMax(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+template <typename T1, typename T2>
+__forceinline__ __device__ void AtomicMin(
+    T1 *address, T2 val, int memory_order = 0) {
+  atomicMin(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+// Overload for when the first argument is a value instead of a pointer
+template <typename T1, typename T2>
+__forceinline__ __device__ void AtomicMin(
+    T1 &address, T2 val, int memory_order = 0) {
+  atomicMin(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
+}
+
+__forceinline__ __device__ void AtomicAddx2(
+    float *ref, float *val, int memory_order = 0) {
+  float2 add_val = *reinterpret_cast<float2 *>(val);
+  atomicAdd(ref + 0, add_val.x);
+  atomicAdd(ref + 1, add_val.y);
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+__forceinline__ __device__ float2 AtomicAddx2Ret(
+    float *ref, float *val, int memory_order = 0) {
+  float2 add_val = *reinterpret_cast<float2 *>(val);
+  float2 ret;
+  ret.x = atomicAdd(ref + 0, add_val.x);
+  ret.y = atomicAdd(ref + 1, add_val.y);
+  return ret;
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+__forceinline__ __device__ void AtomicAddx4(
+    float *ref, float *val, int memory_order = 0) {
+  float4 add_val = *reinterpret_cast<float4 *>(val);
+  atomicAdd(ref + 0, add_val.x);
+  atomicAdd(ref + 1, add_val.y);
+  atomicAdd(ref + 2, add_val.z);
+  atomicAdd(ref + 3, add_val.w);
+}
+
+// Add an extra unused input to accommodate the additional 'memory_order'
+// argument during lowering.
+__forceinline__ __device__ float4 AtomicAddx4Ret(
+    float *ref, float *val, int memory_order = 0) {
+  float4 add_val = *reinterpret_cast<float4 *>(val);
+  float4 ret;
+  ret.x = atomicAdd(ref + 0, add_val.x);
+  ret.y = atomicAdd(ref + 1, add_val.y);
+  ret.z = atomicAdd(ref + 2, add_val.z);
+  ret.w = atomicAdd(ref + 3, add_val.w);
+  return ret;
+}
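These wrappers appear to mirror the CUDA-side helpers: they forward to HIP's atomicAdd / atomicMax / atomicMin and simply discard the trailing memory_order argument that the lowering pass appends. A minimal usage sketch with a hypothetical kernel name and buffer layout, assuming float data (which HIP's hardware atomicAdd supports):

#include "atomic.h"

// Illustrative kernel: every thread accumulates its contribution into out[0]
// and out[1], and adds a float4 into out[4..7] via the vectorized helper.
__global__ void atomic_example(float *out, const float *contrib) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  // Pointer form; the memory_order argument is accepted and ignored.
  AtomicAdd(&out[0], contrib[i], /*memory_order=*/0);
  // Reference overload: adds into out[1], taking the destination by reference.
  AtomicAdd(out[1], contrib[i]);
  // Vectorized form: four consecutive atomicAdds on out[4..7].
  float4 v = make_float4(contrib[i], contrib[i], contrib[i], contrib[i]);
  AtomicAddx4(&out[4], reinterpret_cast<float *>(&v));
}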

src/tl_templates/hip/common.h

Lines changed: 99 additions & 10 deletions
@@ -5,6 +5,7 @@
 #include <hip/hip_fp16.h>
 #include <hip/hip_runtime.h>
 #include <rocwmma/rocwmma.hpp>
+#include "atomic.h"

 #define HIPRT_INF_F __int_as_float(0x7f800000)
 #define HIPRT_NEGINF_F __int_as_float(0xff800000)
@@ -105,18 +106,106 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
   return (v1 << 16) | v0;
 }

-template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
-  atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
+namespace tl {
+
+// Any
+template <typename T> TL_DEVICE bool Any(T *a, int size) {
+  for (int i = 0; i < size; i++) {
+    if (a[i]) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// All
+template <typename T> TL_DEVICE bool All(T *a, int size) {
+  for (int i = 0; i < size; i++) {
+    if (!a[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Shuffle functions for HIP
+template <typename T>
+TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
+  return __shfl_xor_sync(mask, val, laneMask);
+}
+
+template <typename T>
+TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
+  return __shfl_down_sync(mask, val, delta);
+}
+
+template <typename T>
+TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
+  return __shfl_up_sync(mask, val, delta);
+}
+
+template <typename T>
+TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
+  return __shfl_sync(mask, val, srcLane);
+}
+
+// Specializations for half_t (float16_t)
+template <>
+TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
+  float f = static_cast<float>(val);
+  float r = __shfl_xor_sync(mask, f, laneMask);
+  return half_t(r);
+}
+
+template <>
+TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_down_sync(mask, f, delta);
+  return half_t(r);
+}
+
+template <>
+TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_up_sync(mask, f, delta);
+  return half_t(r);
+}
+
+template <>
+TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
+  float f = static_cast<float>(val);
+  float r = __shfl_sync(mask, f, srcLane);
+  return half_t(r);
+}
+
+// Specializations for bfloat16_t
+template <>
+TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val,
+                                   int laneMask) {
+  float f = static_cast<float>(val);
+  float r = __shfl_xor_sync(mask, f, laneMask);
+  return bfloat16_t(r);
 }

-// Overload for when the first argument is a value instead of a pointer
-template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
-  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
+template <>
+TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_down_sync(mask, f, delta);
+  return bfloat16_t(r);
 }

-template <typename T1, typename T2>
-TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val) {
-  return atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
+template <>
+TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_up_sync(mask, f, delta);
+  return bfloat16_t(r);
 }
+
+template <>
+TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
+  float f = static_cast<float>(val);
+  float r = __shfl_sync(mask, f, srcLane);
+  return bfloat16_t(r);
+}
+
+} // namespace tl
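The specializations route half_t and bfloat16_t through float, presumably because the raw __shfl_* builtins do not handle those types directly; callers get one uniform interface regardless of element type. A minimal sketch of a wavefront sum over half_t using the wrapper; the 64-lane width is an assumption (AMD wave64) and the function name is illustrative:

#include "common.h"

// Illustrative wavefront reduction: after the loop, lane 0 holds the sum of
// the half_t values contributed by all 64 lanes.
__device__ half_t wave_sum_half(half_t v) {
  constexpr unsigned kFullMask = 0xffffffff; // same full mask reduce.h now uses
  for (int offset = 64 / 2; offset > 0; offset >>= 1) {
    // The half_t specialization converts to float, shuffles, and converts back,
    // so this call site looks identical to the float or int case.
    half_t other = tl::shfl_down_sync(kFullMask, v, offset);
    v = half_t(static_cast<float>(v) + static_cast<float>(other));
  }
  return v;
}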

src/tl_templates/hip/debug.h

Lines changed: 25 additions & 0 deletions
@@ -47,6 +47,17 @@ __device__ void debug_print_var<unsigned int>(const char *msg,
          (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
 }

+// Specialization for unsigned short type
+template <>
+__device__ void debug_print_var<half_t>(const char *msg, half_t var) {
+  const char *safe_msg = msg;
+  float value = static_cast<float>(var);
+  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
+         "dtype=half_t value=%f\n",
+         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
+         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
+}
+
 // Specialization for float type
 template <> __device__ void debug_print_var<float>(const char *msg, float var) {
   const char *safe_msg = msg;
@@ -133,6 +144,20 @@ debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,
          index, value);
 }

+// Specialization for bool type
+template <>
+__device__ void debug_print_buffer_value<bool>(const char *msg,
+                                               const char *buf_name, int index,
+                                               bool var) {
+  const char *safe_msg = msg;
+  const char *safe_buf_name = buf_name;
+  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
+         "index=%d, dtype=bool value=%s\n",
+         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
+         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
+         index, var ? "true" : "false");
+}
+
 // Specialization for integer type
 template <>
 __device__ void debug_print_buffer_value<int>(const char *msg,
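The new specializations follow the existing debug_print pattern: convert the value into something printf can format (float for half_t, a "true"/"false" string for bool). A minimal sketch of how device code might call them; the kernel, messages, and buffer name are illustrative, not the actual output of TileLang's codegen:

#include "debug.h"

// Illustrative call sites for the two new specializations.
__global__ void debug_example(half_t *x, bool *flags) {
  // Prints the half value converted to float.
  debug_print_var<half_t>("x[0]", x[0]);
  // Prints "true"/"false" for the bool element at index 3 of buffer "flags".
  debug_print_buffer_value<bool>("check", "flags", 3, flags[3]);
}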

src/tl_templates/hip/reduce.h

Lines changed: 4 additions & 2 deletions
@@ -73,7 +73,8 @@ struct SharedReduceWarp {
     }

     for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
-      T other = __shfl_down(partial, offset, kWarpSize);
+      constexpr uint32_t mask = 0xffffffff;
+      T other = tl::shfl_down_sync(mask, partial, offset, kWarpSize);
       partial = Reducer()(partial, other);
     }

@@ -104,7 +105,8 @@ struct AllReduce {
       __syncthreads();
       x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
     } else {
-      x = Reducer()(x, __shfl_xor(x, offset));
+      constexpr uint32_t mask = 0xffffffff;
+      x = Reducer()(x, tl::shfl_xor_sync(mask, x, offset));
     }
     if constexpr (offset == scale) {
       return x;
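Both reduction paths now go through the tl:: shuffle wrappers with an explicit full mask instead of the maskless __shfl_down / __shfl_xor builtins, which also makes the half_t / bfloat16_t specializations from common.h usable here. A standalone sketch of the xor-butterfly step that AllReduce relies on, assuming a 32-lane width and a plain sum reducer (both illustrative):

#include "common.h"

// Illustrative stand-in for the Reducer functor used by AllReduce.
struct SumOp {
  template <typename T> __device__ T operator()(T a, T b) { return a + b; }
};

// After log2(32) xor-exchanges, every lane holds the reduction of all 32 lanes.
template <typename T> __device__ T butterfly_allreduce(T x) {
  constexpr uint32_t mask = 0xffffffff;
  for (int offset = 16; offset > 0; offset >>= 1) {
    x = SumOp()(x, tl::shfl_xor_sync(mask, x, offset));
  }
  return x;
}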

testing/python/autotune/test_tilelang_autotune.py

Lines changed: 2 additions & 0 deletions
@@ -260,11 +260,13 @@ def main(
     return autotuner.run(warmup=3, rep=20)


+@tilelang.testing.requires_cuda
 def test_autotune_get_configs():
     get_configs(1024, 1024, 1024, with_roller=True)
     get_configs(1024, 1024, 1024, with_roller=False)


+@tilelang.testing.requires_cuda
 def test_autotune_matmul():
     matmul(1024, 1024, 1024, with_roller=True)
     matmul(1024, 1024, 1024, with_roller=False)

testing/python/carver/test_tilelang_carver_recommend_hints.py

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ def run_fmha_recommend_hints(
     assert len(hints) > 0, "Hints length should be greater than 0"


+@tilelang.testing.requires_cuda
 def test_fmha_recommend_hints():
     run_fmha_recommend_hints(4, 32, 512, 512, 128, "float16", "float16", "float16")
     run_fmha_recommend_hints(4, 32, 512, 512, 128, "int8", "int32", "int32")

testing/python/components/test_storage_rewrite_detect_inplace.py

Lines changed: 6 additions & 2 deletions
@@ -1,6 +1,9 @@
 import tilelang
 import tilelang.testing
 from tilelang import language as T
+from tilelang.utils.target import check_hip_availability
+
+_IS_HIP_AVAILABLE = check_hip_availability()


 @tilelang.jit
@@ -54,8 +57,9 @@ def test_storage_rewrite_detect_inplace_toggle():
     script_off = _get_device_kernel_script(detect_inplace=False)
     script_on = _get_device_kernel_script(detect_inplace=True)

-    assert script_off.count("read = (read * 2);") == 0
-    assert script_on.count("read = (read * 2);") > 0
+    pattern = "read[0] = (read[0] * 2);" if _IS_HIP_AVAILABLE else "read = (read * 2);"
+    assert script_off.count(pattern) == 0
+    assert script_on.count(pattern) > 0


 if __name__ == "__main__":

testing/python/debug/test_device_assert.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@ def program():
         tid = T.get_thread_binding()
         T.device_assert(tid > 0, "Assertion Trigger !")

-    jit_kernel = tilelang.compile(program, target="cuda")
+    jit_kernel = tilelang.compile(program, target="auto")
     profiler = jit_kernel.get_profiler()
     profiler.run_once()

@@ -25,7 +25,7 @@ def program():
         tid = T.get_thread_binding()
         T.device_assert(tid == tid)

-    jit_kernel = tilelang.compile(program, target="cuda")
+    jit_kernel = tilelang.compile(program, target="auto")
     profiler = jit_kernel.get_profiler()
     profiler.run_once()

0 commit comments