Skip to content

Commit bb6cb06

Browse files
committed
fix
1 parent 5ca0eb7 commit bb6cb06

File tree

2 files changed

+21
-20
lines changed

2 files changed

+21
-20
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -402,6 +402,7 @@ jobs:
402402
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
403403
)
404404
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
405+
--ignore=./python/runtime --ignore=./python/transform \
405406
./python
406407
407408
# Apple Metal tests

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -111,11 +111,11 @@ def ref_program(A, B):
111111
@pytest.mark.parametrize(
112112
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
113113
[
114-
(512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
115-
(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
116-
(512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
117-
(512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
118-
(128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
114+
(512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 128, 32, 2, 128),
115+
(512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 128, 32, 2, 128),
116+
(512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 128, 32, 2, 128),
117+
(512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 128, 32, 2, 128),
118+
(128, 8, 32, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128),
119119
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
120120
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
121121
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
@@ -244,11 +244,11 @@ def ref_program(A, B):
244244
@pytest.mark.parametrize(
245245
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
246246
[
247-
(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
248-
(512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
249-
(512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
250-
(512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
251-
(128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
247+
(512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
248+
(512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
249+
(512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
250+
(512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
251+
(128, 8, 32, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128),
252252
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
253253
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
254254
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
@@ -376,11 +376,11 @@ def ref_program(A, B):
376376
@pytest.mark.parametrize(
377377
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
378378
[
379-
(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
380-
(512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
381-
(512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
382-
(512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
383-
(128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
379+
(512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
380+
(512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
381+
(512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
382+
(512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
383+
(128, 8, 32, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128),
384384
(128, 128, 32, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
385385
(128, 128, 32, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
386386
(128, 128, 32, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
@@ -514,12 +514,12 @@ def ref_program(A, B):
514514
@pytest.mark.parametrize(
515515
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
516516
[
517-
(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
518-
(512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
519-
(512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
520-
(512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
517+
(512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
518+
(512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
519+
(512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
520+
(512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
521521
(512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float, 128, 256, 32, 2, 128),
522-
(128, 8, 128, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 2, 128),
522+
(128, 8, 128, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 2, 128),
523523
(128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 32, 2, 128),
524524
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
525525
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),

0 commit comments

Comments
 (0)