Skip to content

Commit 3781b54

Browse files
author
Sylvie Liberman
committed
fix
1 parent a08bd7e commit 3781b54

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

kernels/bmm/matmul.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ void cpu_gemm(float* a, float* b, float* c, int B, int M, int N, int K) {
143143
#include "pyutils/torch_helpers.cuh"
144144

145145
torch::Tensor batch_matmul(torch::Tensor A, torch::Tensor B) {
146+
CHECK_INPUT(A);
147+
CHECK_INPUT(B);
146148
TORCH_CHECK(A.size(0) == B.size(0), "Batch size mismatch");
147149
TORCH_CHECK(A.size(2) == B.size(2), "Inner dimensions mismatch");
148150
uint batch = A.size(0), M = A.size(1), K = A.size(2), N = B.size(1);
@@ -160,7 +162,7 @@ torch::Tensor batch_matmul(torch::Tensor A, torch::Tensor B) {
160162

161163
dim3 grid(mmt::grid(batch, M, N, K));
162164
dim3 block(kittens::prototype::detail::NUM_THREADS_v<mmt>);
163-
165+
cudaFuncSetAttribute(prototype::lcf::kernel<mmt>, cudaFuncAttributeMaxDynamicSharedMemorySize, MAX_SHARED_MEMORY-1024);
164166
prototype::lcf::kernel<mmt><<<grid, block, MAX_SHARED_MEMORY-1024>>>(G);
165167
return C;
166168
}

kernels/bmm/python_test.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
import torch
22
import thunderkittens
33

4-
b, N, K, M = 2, 4096, 4096, 4096
4+
b, N, K, M = 2, 4096, 4096, 64
55
A = torch.rand([b, N, K], dtype=torch.bfloat16, device="cuda")
6-
A_ = A.clone()
76
B = torch.rand([b, M, K], dtype=torch.bfloat16, device="cuda")
8-
B_ = B.clone()
97
c = thunderkittens.batch_matmul(A, B)
108
torch.cuda.synchronize()
11-
d = thunderkittens.batch_matmul(c, B)
12-
torch.cuda.synchronize()
13-
assert A.equal(A_)
9+
# d = thunderkittens.batch_matmul(c, B)
10+
# torch.cuda.synchronize()
1411
ref = A@(B.transpose(-2, -1))
1512
print(c[0, 0,0])
1613
assert ref.allclose(c)

0 commit comments

Comments
 (0)