Skip to content

Commit fe66e41

Browse files
[FE] Error when trying to use atomics with any type smaller than 16b (triton-lang#6717)
So far we have been checking against a pre-defined list of unsupported types, missing that fp8 is generally also not supported in atomic operations.
1 parent 12f8ae7 commit fe66e41

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

python/test/unit/language/test_core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,6 +1889,20 @@ def kernel(inp, out_max, out_min):
18891889
torch.testing.assert_close(out_max, inp, atol=0, rtol=0)
18901890

18911891

1892+
@pytest.mark.parametrize("dtype_str", ["float8_e4m3fn", "bfloat16", "int8", "int16", "uint8", "uint16"])
def test_atomic_unsupported_type(dtype_str, device):
    """Check that atomic_add rejects dtypes it does not support.

    A trivial kernel loads one element and atomically adds it; for every
    parametrized dtype the frontend is expected to raise a TritonError
    instead of compiling the kernel.
    """

    @triton.jit
    def kernel(src, dst):
        # Single-element load followed by an atomic add on an unsupported type.
        val = tl.load(src)
        tl.atomic_add(dst, val)

    # Resolve the torch dtype once and build matching 1-element buffers.
    dtype = getattr(torch, dtype_str)
    src = torch.zeros((1, ), device=device, dtype=dtype)
    dst = torch.zeros((1, ), device=device, dtype=dtype)
    with pytest.raises(triton.TritonError):
        kernel[(1, )](src, dst)
1904+
1905+
18921906
# ---------------
18931907
# test cast
18941908
# ---------------

python/triton/language/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1381,7 +1381,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
13811381
element_ty = ptr.type.scalar.element_ty
13821382
if element_ty is tl.float16 and op != 'add':
13831383
raise ValueError("atomic_" + op + " does not support fp16")
1384-
if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
1384+
if element_ty in [tl.int16, tl.uint16, tl.bfloat16] or element_ty.primitive_bitwidth < 16:
13851385
raise ValueError("atomic_" + op + " does not support " + str(element_ty))
13861386
if ptr.type.is_block():
13871387
if mask is not None:

0 commit comments

Comments
 (0)