Skip to content

Commit a4848c6

Browse files
committed
add a python test for bmm
1 parent 7a08f68 commit a4848c6

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

kernels/bmm/python_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
import thunderkittens
3+
4+
b, N, K, M = 2, 4096, 4096, 64
5+
A = torch.rand([b, N, K], dtype=torch.bfloat16, device="cuda")
6+
B = torch.rand([b, M, K], dtype=torch.bfloat16, device="cuda")
7+
c = thunderkittens.batch_matmul(A, B)
8+
torch.cuda.synchronize()
9+
# d = thunderkittens.batch_matmul(c, B)
10+
# torch.cuda.synchronize()
11+
ref = A@(B.transpose(-2, -1))
12+
print(c[0, 0,0])
13+
assert ref.allclose(c)
14+

0 commit comments

Comments
 (0)