
Commit

add gather cache kernel
Signed-off-by: Lucas Wilkinson <[email protected]>
LucasWilkinson committed Feb 5, 2025
1 parent c375826 commit bf6a400
Showing 6 changed files with 222 additions and 2 deletions.
4 changes: 4 additions & 0 deletions csrc/cache.h
@@ -39,3 +39,7 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);

void gather_cache(torch::Tensor const& src_cache, torch::Tensor const& dst,
torch::Tensor const& block_table,
torch::Tensor const& cu_seq_lens, int64_t batch_size);
125 changes: 125 additions & 0 deletions csrc/cache_kernels.cu
@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"

@@ -570,3 +571,127 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
}

namespace vllm {

// grid (batch, num_splits)
template <typename scalar_t>
__global__ void gather_cache(
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = cu_seq_lens[bid];
const int32_t seq_end = cu_seq_lens[bid + 1];
const int32_t seq_len = seq_end - seq_start;
const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);

const int32_t split_start = split * split_blocks;
const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);

const bool is_active_split = (split_start < tot_blocks);
const bool is_last_split = (split_end == tot_blocks);

if (!is_active_split) return;

int32_t full_blocks_end = split_end;
int32_t partial_block_size = 0;

block_table += bid * block_table_stride;
dst += seq_start * dst_entry_stride;

if (is_last_split) {
partial_block_size = seq_len % block_size;
if (partial_block_size) full_blocks_end -= 1;
}

auto copy_entry = [&](const scalar_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
_dst[i] = _src[i];
};

for (int pid = split_start; pid < full_blocks_end; ++pid) {
auto block_id = block_table[pid];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
for (int eid = 0; eid < block_size; ++eid) {
copy_entry(block_start_ptr + eid * cache_entry_stride,
block_dst_ptr + eid * dst_entry_stride);
}
}

if (partial_block_size) {
auto block_id = block_table[full_blocks_end];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
for (int eid = 0; eid < partial_block_size; ++eid) {
copy_entry(block_start_ptr + eid * cache_entry_stride,
block_dst_ptr + eid * dst_entry_stride);
}
}
}

} // namespace vllm

#define CALL_GATHER_CACHE(CPY_DTYPE) \
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride);

// Gather sequences from the cache into the dst tensor based on block_table.
// cu_seq_lens holds the cumulative sequence lengths of the batch and also
// serves as the start offsets in dst for each sequence in the batch.
void gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int32_t block_size = src_cache.size(1);
int32_t entry_size = src_cache.flatten(2, -1).size(2);

TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
"cu_seq_lens must be int32");

int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);

  // Small batches get more splits per sequence so the grid stays large enough
  // to keep the GPU busy; large batches already provide enough thread blocks
  // on their own.
  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;

dim3 grid(batch_size, num_splits);
dim3 block(1024);

TORCH_CHECK(src_cache.dtype() == dst.dtype(),
"src_cache and dst must have the same dtype");

const int dtype_bits = src_cache.element_size() * 8;
if (dtype_bits == 32) {
CALL_GATHER_CACHE(uint32_t);
} else if (dtype_bits == 16) {
CALL_GATHER_CACHE(uint16_t);
} else if (dtype_bits == 8) {
CALL_GATHER_CACHE(uint8_t);
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
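
For readers following the indexing, a rough PyTorch reference of what the kernel computes may help. This is only an illustrative sketch under the assumption of a 2D destination of shape [TOT_TOKENS, entry_size]; the name `gather_cache_ref` is hypothetical, and the split/`num_splits` parallelization above is a launch detail that the sketch ignores.

```python
import torch


def gather_cache_ref(src_cache: torch.Tensor,    # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
                     dst: torch.Tensor,          # [TOT_TOKENS, entry_size]
                     block_table: torch.Tensor,  # [BATCH, BLOCK_INDICES], int32
                     cu_seq_lens: torch.Tensor,  # [BATCH + 1], int32
                     batch_size: int) -> None:
    block_size = src_cache.size(1)
    for b in range(batch_size):
        seq_start = int(cu_seq_lens[b])
        seq_len = int(cu_seq_lens[b + 1]) - seq_start
        tot_blocks = (seq_len + block_size - 1) // block_size
        for pid in range(tot_blocks):
            # The last block of a sequence may be partially filled.
            n = min(block_size, seq_len - pid * block_size)
            block_id = int(block_table[b, pid])
            rows = src_cache[block_id, :n].reshape(n, -1)
            dst[seq_start + pid * block_size:
                seq_start + pid * block_size + n] = rows
    # dst is filled in place; nothing is returned.
```

The CUDA kernel above does the same per-sequence copy, but spreads a sequence's blocks across blockIdx.y splits and copies each entry with a full thread block.
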
10 changes: 10 additions & 0 deletions csrc/cuda_utils.h
@@ -13,3 +13,13 @@
int64_t get_device_attribute(int64_t attribute, int64_t device_id);

int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

namespace cuda_utils {

template <typename T>
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
ceil_div(T a, T b) {
return (a + b - 1) / b;
}

}; // namespace cuda_utils
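
A quick Python analogue of the integer ceiling division introduced here, just to make the arithmetic concrete; `ceil_div` below is the same (a + b - 1) / b expression as the C++ template.

```python
def ceil_div(a: int, b: int) -> int:
    # Same (a + b - 1) / b trick as the C++ template above.
    return (a + b - 1) // b


# As used in gather_cache: a 7-token sequence with block_size=2 spans 4 blocks,
# and 4 blocks over 16 splits means each active split handles 1 block.
assert ceil_div(7, 2) == 4
assert ceil_div(4, 16) == 1
assert ceil_div(8, 2) == 4  # exact multiples need no extra block
```
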
6 changes: 6 additions & 0 deletions csrc/torch_bindings.cpp
@@ -488,6 +488,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

// Gather cache blocks from src_cache to dst.
cache_ops.def(
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size) -> ()");
cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
72 changes: 70 additions & 2 deletions tests/kernels/test_cache.py
@@ -682,8 +682,6 @@ def test_swap_blocks_mla(
torch.ops._C_cache_ops.swap_blocks,
(src_cache, dst_cache, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(kv_lora_rank == KV_LORA_RANKS[0]
and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
)

ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
@@ -694,3 +692,73 @@
dst_cache[dst].cpu(),
msg=f"Block {src} from src should have been swapped to block "
f"{dst} in dst_cache.")


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [2])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [4])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("align_cache", [False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, align_cache, device):
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

seq_len_tensor = torch.randint(1,
max_seq_len + 1, (batch_size, ),
device=device)

total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)

for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm

dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)

expected_batches = []
for b in range(batch_size):
s = seq_len_tensor[b]
tot = tot_blocks_tensor[b]
blocks = block_table[b, :tot].tolist()

gathered_rows = []
for i in range(tot - 1):
gathered_rows.append(src_cache[blocks[i]])
remaining = s - (tot - 1) * block_size
gathered_rows.append(src_cache[blocks[-1], :remaining, :])

batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)

opcheck(
torch.ops._C_cache_ops.gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)

ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected)
7 changes: 7 additions & 0 deletions vllm/_custom_ops.py
@@ -1054,6 +1054,13 @@ def convert_fp8(output: torch.Tensor,
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def gather_cache(src_cache: torch.Tensor, dst: torch.Tensor,
block_table: torch.Tensor, cu_seq_lens: torch.Tensor,
batch_size: int) -> None:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size)


def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)

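
A minimal end-to-end call through this wrapper could look like the sketch below. The sizes and tensors are hypothetical and chosen only for illustration; note that dst is written in place, which is why the op schema marks it as `Tensor! dst`.

```python
import torch

from vllm import _custom_ops as ops

# Hypothetical sizes, for illustration only.
num_blocks, block_size, entry_size = 128, 16, 576
seq_lens = torch.tensor([5, 33], dtype=torch.int32, device="cuda")
batch_size = seq_lens.numel()

src_cache = torch.randn(num_blocks, block_size, entry_size,
                        dtype=torch.float16, device="cuda")
# One permutation of block indices per sequence in the batch.
block_table = torch.stack([
    torch.randperm(num_blocks, device="cuda").to(torch.int32)
    for _ in range(batch_size)
])
cu_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seq_lens[1:] = seq_lens.cumsum(dim=0).to(torch.int32)

dst = torch.empty(int(cu_seq_lens[-1]), entry_size,
                  dtype=src_cache.dtype, device="cuda")
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
```
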