Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions kernels/gemm/educational_b200/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Change this!
# Selects which educational kernel to build: level_01.cu .. level_09.cu.
LEVEL := 01

###############################################
###############################################
###############################################

# Target GPU; consumed by the shared build rules in common.mk.
GPU := B200
# Source file, output binary, and run command all derive from LEVEL.
SRC := level_$(LEVEL).cu
OUT := level_$(LEVEL).out
CMD := ./level_$(LEVEL).out
# NOTE(review): presumably selects the standalone (non-library) build path
# in common.mk — confirm against ../../common.mk.
CONFIG := standalone
include ../../common.mk
15 changes: 15 additions & 0 deletions kernels/gemm/educational_b200/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# ThunderKittens Educational GEMM Kernels (Blackwell)

This folder builds up the B200 GEMM piece-by-piece. It is only for educational purposes.

Change the `LEVEL` field in the `Makefile` to `01` - `09`, then `make clean && make run`.

- Level 01: Simple for loop (float) -- this is faster than the bf16 version because CUDA cores convert bf16 operands to float before computing anyway, so using bf16 here only adds conversion overhead
- Level 02: Simple for loop (bf16)
- Level 03: Use shared memory
- Level 04: Use tensor cores (WMMA)
- Level 05: Use TMA for global<->shared memory transfers (+ WMMA)
- Level 06: Use tensor cores (tcgen05 MMA) with TMA
- Level 07: Use pipelined warp specialization (TMA loader + MMA issuer)
- Level 08: Use epilogue pipelining
- Level 09: Use 2-CTA cluster and warpgroup-level parallelism
124 changes: 124 additions & 0 deletions kernels/gemm/educational_b200/launch.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <chrono>

#include "../common.cuh"

// Benchmarks the level's bf16 `matmul` against the device-side reference GEMM
// for an N x N problem. Inputs are generated in fp32 on the host, converted to
// bf16 for the device, and results are converted back to fp32 for comparison.
// Returns 0 on success, -1 on a CUDA error.
int run_benchmark(size_t N) {
    cudaError_t cudaStatus;
    std::cout << "-------------------- N=" << N << " --------------------\n";

    float *h_A = new float[N * N];
    float *h_B = new float[N * N];
    float *h_C = new float[N * N];
    std::cout << "Allocated host memory" << std::endl;

    // Fixed seed so every level sees identical inputs across runs.
    // (Removed an unused std::random_device — the seed is hard-coded anyway.)
    std::mt19937 gen(42);
    std::uniform_real_distribution<> dis(-0.5, 0.5);

    for (size_t i = 0; i < N * N; ++i) h_A[i] = dis(gen);
    for (size_t i = 0; i < N * N; ++i) h_B[i] = dis(gen);
    std::cout << "Initialized matrices" << std::endl;

    __nv_bfloat16 *d_A, *d_B, *d_C, *d_C_ref;
    cudaMalloc(&d_A, N * N * sizeof(__nv_bfloat16));
    cudaMalloc(&d_B, N * N * sizeof(__nv_bfloat16));
    cudaMalloc(&d_C, N * N * sizeof(__nv_bfloat16));
    cudaMalloc(&d_C_ref, N * N * sizeof(__nv_bfloat16));
    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(cudaStatus) << std::endl;
        return -1;
    }
    std::cout << "Allocated device memory" << std::endl;

    // Convert inputs to bf16 on the host, then upload.
    __nv_bfloat16 *h_A_bf16 = new __nv_bfloat16[N * N];
    __nv_bfloat16 *h_B_bf16 = new __nv_bfloat16[N * N];
    for (size_t i = 0; i < N * N; ++i) h_A_bf16[i] = __float2bfloat16(h_A[i]);
    for (size_t i = 0; i < N * N; ++i) h_B_bf16[i] = __float2bfloat16(h_B[i]);
    cudaMemcpy(d_A, h_A_bf16, N * N * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B_bf16, N * N * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
    std::cout << "Copied matrices to device" << std::endl;

    reference_gemm<__nv_bfloat16, __nv_bfloat16, false>(d_C_ref, d_A, d_B, N, N, N);
    cudaDeviceSynchronize();
    std::cout << "Computed reference GEMM on device" << std::endl;
    printf("\n");

    // Warmup launches (excluded from timing).
    for (int i = 0; i < 2; i++) {
        matmul(d_A, d_B, d_C, N);
    }
    cudaDeviceSynchronize();
    auto start = std::chrono::high_resolution_clock::now();
    constexpr int ITERS = 1;
    for (int i = 0; i < ITERS; i++) {
        matmul(d_A, d_B, d_C, N);
    }
    cudaDeviceSynchronize();

    auto end = std::chrono::high_resolution_clock::now();

    std::chrono::duration<double> diff = end - start;
    double useconds = diff.count() * 1e6 / ITERS;

    // 2*N^3 FLOPs for an N x N x N GEMM; TFLOP/s = FLOPs / us / 1e6.
    double flops = double(2.0) * N * N * N;
    double tflops = (flops / useconds) / 1e6;
    std::cout << "Avg Kernel execution time: " << useconds << " us\n";
    std::cout << "Achieved performance: " << tflops << " TFLOPs\n";

    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(cudaStatus) << std::endl;
        return -1;
    }

    __nv_bfloat16 *h_C_bf16 = new __nv_bfloat16[N * N];
    __nv_bfloat16 *h_C_ref_bf16 = new __nv_bfloat16[N * N];
    cudaMemcpy(h_C_bf16, d_C, N * N * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_C_ref_bf16, d_C_ref, N * N * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
    std::cout << "Copied result back to host" << std::endl;

    float *h_C_ref = new float[N * N];
    for (size_t i = 0; i < N * N; ++i) h_C[i] = __bfloat162float(h_C_bf16[i]);
    for (size_t i = 0; i < N * N; ++i) h_C_ref[i] = __bfloat162float(h_C_ref_bf16[i]);
    std::cout << "Converted result back to float" << std::endl;

    // Loose 0.2 tolerance: both sides are bf16-rounded, so exact match is not expected.
    float max_error = 0.0f;
    int error_count = 0;
    for (size_t i = 0; i < N * N; ++i) {
        float error = std::abs(h_C[i] - h_C_ref[i]);
        if (error > 0.2f) {
            if (error_count < 20) std::cout << "Error at row " << i / N << " col " << i % N << ": " << h_C[i] << " != " << h_C_ref[i] << " (ref)" << std::endl;
            else if (error_count == 20) std::cout << "Too many errors to show them all.\n"; // fixed off-by-one (was == 21)
            error_count++;
        }
        max_error = std::max(max_error, error);
    }

    std::cout << "Max error: " << max_error << std::endl;
    std::cout << "Error count: " << error_count << std::endl;
    std::cout << "Total count: " << N * N << std::endl; // no int() cast: N*N can exceed INT_MAX

    delete[] h_A;
    delete[] h_B;
    delete[] h_C;
    delete[] h_C_ref;
    delete[] h_A_bf16;
    delete[] h_B_bf16;
    delete[] h_C_bf16;
    delete[] h_C_ref_bf16;
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
    cudaFree(d_C_ref);

    return 0;
}

// Entry point: benchmark one 4096x4096 GEMM.
// Propagates run_benchmark's status so a CUDA error yields a nonzero exit code
// (the original discarded the -1 and always exited 0).
int main() {
    return run_benchmark(4096) == 0 ? 0 : 1;
}
121 changes: 121 additions & 0 deletions kernels/gemm/educational_b200/level_01.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#include <cuda_runtime.h>
#include <iostream>
#include <random>
#include <chrono>

using my_dtype = float;

// Naive fp32 GEMM: each thread computes one element of C = A * B for
// row-major N x N matrices. Expects a 2D grid of 2D blocks; the guard
// handles the ragged edge when N is not a multiple of the block size.
__global__ void kernel(my_dtype* A, my_dtype* B, my_dtype* C, int N) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= N || col >= N) return;

    my_dtype acc = 0.0f;
    for (int k = 0; k < N; k++)
        acc += A[row * N + k] * B[k * N + col];
    C[row * N + col] = acc;
}

// Thread-block tile edge: 32x32 = 1024 threads per block.
// Was a mutable `int` global with external linkage; made `static constexpr`
// for consistency with level_02/level_03 and to allow constant folding.
static constexpr int BLOCK_SIZE = 32;

// Host launcher: ceil-div grid so any N is fully covered
// (the kernel bounds-checks the ragged edge).
void matmul(my_dtype* A, my_dtype* B, my_dtype* C, int N) {
    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    dim3 blocks((N + (BLOCK_SIZE-1)) / BLOCK_SIZE, (N + (BLOCK_SIZE-1)) / BLOCK_SIZE);
    kernel<<<blocks, threads>>>(A, B, C, N);
}

// CPU reference GEMM: c = a * b for square row-major N x N matrices.
// The two outer loops parallelize across cores when built with OpenMP.
void cpu_gemm(float* a, float* b, float* c, int N) {
    #pragma omp parallel for collapse(2)
    for (int row = 0; row < N; row++) {
        for (int col = 0; col < N; col++) {
            const float* a_row = a + row * N;
            float acc = 0.0f;
            for (int kk = 0; kk < N; kk++)
                acc += a_row[kk] * b[kk * N + col];
            c[row * N + col] = acc;
        }
    }
}

// Benchmarks the naive fp32 GPU GEMM against the CPU reference for an
// N x N problem. Returns 0 on success, -1 on a CUDA error.
int run_benchmark(size_t N) {
    cudaError_t cudaStatus;
    std::cout << "-------------------- N=" << N << " --------------------\n";

    float *h_A = new float[N * N];
    float *h_B = new float[N * N];
    float *h_C = new float[N * N];
    float *h_C_ref = new float[N * N];

    // Fixed seed -> reproducible inputs across runs.
    std::mt19937 gen(42);
    std::uniform_real_distribution<> dis(-0.5, 0.5);
    for (size_t i = 0; i < N * N; ++i) h_A[i] = dis(gen);
    for (size_t i = 0; i < N * N; ++i) h_B[i] = dis(gen);

    cpu_gemm(h_A, h_B, h_C_ref, N);

    float *d_A, *d_B, *d_C;
    cudaMalloc(&d_A, N * N * sizeof(float));
    cudaMalloc(&d_B, N * N * sizeof(float));
    cudaMalloc(&d_C, N * N * sizeof(float));
    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(cudaStatus) << std::endl;
        return -1;
    }

    // sizeof(float) instead of the magic byte count "4".
    cudaMemcpy(d_A, h_A, N * N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B, N * N * sizeof(float), cudaMemcpyHostToDevice);

    // Warmup launches (excluded from timing).
    for (int i = 0; i < 2; i++) matmul(d_A, d_B, d_C, N);

    cudaDeviceSynchronize();
    auto start = std::chrono::high_resolution_clock::now();
    constexpr int ITERS = 10;
    for (int i = 0; i < ITERS; i++) matmul(d_A, d_B, d_C, N);
    cudaDeviceSynchronize();
    auto end = std::chrono::high_resolution_clock::now();

    // 2*N^3 FLOPs for an N x N x N GEMM; TFLOP/s = FLOPs / us / 1e6.
    std::chrono::duration<double> diff = end - start;
    double useconds = diff.count() * 1e6 / ITERS;
    double flops = double(2.0) * N * N * N;
    double tflops = (flops / useconds) / 1e6;
    std::cout << "Avg Kernel execution time: " << useconds << " us\n";
    std::cout << "Achieved performance: " << tflops << " TFLOPs\n";

    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(cudaStatus) << std::endl;
        return -1;
    }

    cudaMemcpy(h_C, d_C, N * N * sizeof(float), cudaMemcpyDeviceToHost);

    // Both sides compute in fp32, so a tight 0.01 tolerance is appropriate.
    int error_count = 0;
    float max_error = 0.0f;
    for (size_t i = 0; i < N * N; ++i) {
        float error = std::abs(h_C[i] - h_C_ref[i]);
        if (error > .01f) {
            if (error_count < 20) std::cout << "Error at row " << i / N << " col " << i % N << ": " << h_C[i] << " != " << h_C_ref[i] << " (ref)" << std::endl;
            else if (error_count == 20) std::cout << "Too many errors to show them all.\n"; // fixed off-by-one (was == 21)
            error_count++;
        }
        max_error = std::max(max_error, error);
    }
    std::cout << "Max error: " << max_error << std::endl;
    std::cout << "Error count: " << error_count << std::endl;

    delete[] h_A;
    delete[] h_B;
    delete[] h_C;
    delete[] h_C_ref;
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
    return 0;
}

// Entry point: benchmark one 1024x1024 GEMM.
// Propagates run_benchmark's status so a CUDA error yields a nonzero exit code
// (the original discarded the -1 and always exited 0).
int main() {
    return run_benchmark(1024) == 0 ? 0 : 1;
}
25 changes: 25 additions & 0 deletions kernels/gemm/educational_b200/level_02.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <cuda_runtime.h>
#include <cuda_bf16.h>

static constexpr int BLOCK_SIZE = 32;

// Naive bf16 GEMM: one thread per element of C = A * B (row-major N x N).
// Each product is formed in bf16 (operator* on __nv_bfloat16), then widened
// to fp32 for accumulation; the sum is rounded back to bf16 on store.
__global__ void kernel(__nv_bfloat16* A, __nv_bfloat16* B, __nv_bfloat16* C, int N) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= N || col >= N) return;

    float acc = 0.0f;
    for (int k = 0; k < N; k++) {
        __nv_bfloat16 prod = A[row * N + k] * B[k * N + col];
        acc += __bfloat162float(prod);
    }
    C[row * N + col] = __float2bfloat16(acc);
}

// Host launcher: 32x32 thread tiles over a ceil-div grid covering the
// N x N output (the kernel bounds-checks the ragged edge).
void matmul(__nv_bfloat16* A, __nv_bfloat16* B, __nv_bfloat16* C, int N) {
    const int grid_dim = (N + BLOCK_SIZE - 1) / BLOCK_SIZE;
    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    dim3 blocks(grid_dim, grid_dim);
    kernel<<<blocks, threads>>>(A, B, C, N);
}

#include "launch.cu"
33 changes: 33 additions & 0 deletions kernels/gemm/educational_b200/level_03.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <cuda_runtime.h>
#include <cuda_bf16.h>

static constexpr int BLOCK_SIZE = 32;

// Shared-memory tiled bf16 GEMM: each block computes a BLOCK_SIZE x BLOCK_SIZE
// tile of C = A * B (row-major N x N), staging tiles of A and B through shared
// memory. Products are formed in bf16 and accumulated in fp32.
//
// Fix: the original had no bounds checks at all, so any N that is not a
// multiple of BLOCK_SIZE caused out-of-bounds global loads and stores (the
// launcher rounds the grid up). Out-of-range tile elements are now zero-filled
// (a zero contribution to the dot product) and the final store is guarded.
// Results for N divisible by BLOCK_SIZE are unchanged.
__global__ void kernel(__nv_bfloat16* A, __nv_bfloat16* B, __nv_bfloat16* C, int N) {
    __shared__ __nv_bfloat16 As[BLOCK_SIZE][BLOCK_SIZE], Bs[BLOCK_SIZE][BLOCK_SIZE];
    int tx = threadIdx.x, bx = blockIdx.x;
    int by = blockIdx.y, ty = threadIdx.y;
    int row = by * BLOCK_SIZE + ty;
    int col = bx * BLOCK_SIZE + tx;

    float sum = 0.0f;
    for (int tile = 0; tile < (N + BLOCK_SIZE - 1) / BLOCK_SIZE; ++tile) {
        int a_col = tile * BLOCK_SIZE + tx; // column of A this thread loads
        int b_row = tile * BLOCK_SIZE + ty; // row of B this thread loads
        // Every thread writes its slot (zero-filled when out of range) so the
        // barriers below are reached uniformly by the whole block.
        As[ty][tx] = (row < N && a_col < N) ? A[row * N + a_col] : __float2bfloat16(0.0f);
        Bs[ty][tx] = (b_row < N && col < N) ? B[b_row * N + col] : __float2bfloat16(0.0f);
        __syncthreads();
        #pragma unroll
        for (int k = 0; k < BLOCK_SIZE; ++k) {
            sum += __bfloat162float(As[ty][k] * Bs[k][tx]);
        }
        __syncthreads(); // don't overwrite tiles other threads are still reading
    }
    if (row < N && col < N)
        C[row * N + col] = __float2bfloat16(sum);
}

// Host launcher for the tiled kernel: one BLOCK_SIZE x BLOCK_SIZE thread
// block per output tile, grid rounded up to cover the whole N x N matrix.
void matmul(__nv_bfloat16* A, __nv_bfloat16* B, __nv_bfloat16* C, int N) {
    const int tiles_per_side = (N + BLOCK_SIZE - 1) / BLOCK_SIZE;
    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    dim3 blocks(tiles_per_side, tiles_per_side);
    kernel<<<blocks, threads>>>(A, B, C, N);
}

#include "launch.cu"
Loading