Skip to content

Commit 80d101f

Browse files
authored
add kv trans v2 kernel for dp mode pd (#763)
1 parent f12ba29 commit 80d101f

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import torch
2+
3+
import triton
4+
import triton.language as tl
5+
6+
7+
@triton.jit
def _kv_trans_kernel(
    input_mems_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    input_token_idx_ptr,
    input_dp_idx_ptr,
    output_ptr,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    output_token_idx_ptr,
    token_num: int,
    head_num: int,
    head_dim: int,
    grid_count: int,
    BLOCK_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Copy selected kv-cache token rows from per-dp-rank buffers into ``output_ptr``.

    For each logical token t in [0, token_num):
      - ``input_dp_idx_ptr[t]`` selects an entry of ``input_mems_ptr`` (a table of
        raw device base addresses, one per dp rank),
      - ``input_token_idx_ptr[t]`` is the source row inside that buffer,
      - ``output_token_idx_ptr[t]`` is the destination row in ``output_ptr``.

    Only the token-dimension strides (``*_stride_0``) are used for addressing:
    each row's head_num * head_dim elements are copied as one flat span, so both
    source and destination rows are assumed contiguous in the last two dims.
    """
    # Promote token strides to 64-bit so the pointer-offset arithmetic below
    # cannot overflow on large buffers. (The ``*_stride_1`` casts are for
    # symmetry only; stride 1 and 2 are otherwise unused because rows are
    # copied flat.)
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
    output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)

    head_num_dim = head_num * head_dim  # number of elements in one token row
    tid = tl.program_id(0)

    offs = tl.arange(0, BLOCK_SIZE)
    # Grid-stride loop: the launcher starts only ``grid_count`` programs, and
    # each program handles tokens tid, tid + grid_count, tid + 2*grid_count, ...
    while tid < token_num:
        dp_index = tl.load(input_dp_idx_ptr + tid)
        input_token_idx = tl.load(input_token_idx_ptr + tid)
        output_token_idx = tl.load(output_token_idx_ptr + tid)
        # Reinterpret the stored integer address as a typed pointer with the
        # same element type as the destination tensor.
        input_ptr = tl.load(input_mems_ptr + dp_index).to(tl.pointer_type(output_ptr.dtype.element_ty))
        # Copy one row in BLOCK_SIZE-element chunks; the mask guards the tail
        # chunk when head_num_dim is not a multiple of BLOCK_SIZE.
        for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
            cur_offs = block_idx * BLOCK_SIZE + offs
            in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim)
            tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim)

        tid += grid_count

    return
49+
50+
51+
def kv_trans_v2(
    input_mems: torch.Tensor,
    input_idx: torch.Tensor,
    input_dp_idx: torch.Tensor,
    output: torch.Tensor,
    output_idx: torch.Tensor,
    dp_size_in_node: int,
):
    """Gather kv-cache token rows out of per-dp-rank buffers into ``output``.

    ``input_mems`` is a 1-D ``torch.uint64`` tensor; element ``i`` holds the raw
    device base address of the kv cache kept by dp rank ``i``'s mem_manager.
    For every t, row ``input_idx[t]`` of buffer ``input_dp_idx[t]`` is copied to
    row ``output_idx[t]`` of ``output`` (shape ``[_, head_num, head_dim]``).
    """
    # Layout / size sanity checks.
    assert input_mems.is_contiguous() and input_mems.dim() == 1
    assert len(input_mems) == dp_size_in_node
    assert output.is_contiguous() and output.dim() == 3
    assert len(input_idx) == len(output_idx) == len(input_dp_idx)

    _, head_num, head_dim = output.shape
    token_num = len(input_idx)

    # Use few programs / a single warp so the transfer does not occupy too
    # many SM compute units.
    grid_count = 20
    BLOCK_SIZE = 256
    NUM_STAGES = 3

    # NOTE(review): ``output.stride()`` is also passed in the *input* stride
    # slots. The kernel only consumes stride 0 and copies rows flat, and each
    # source buffer shares output's (head_num, head_dim) row layout — confirm
    # this still holds if the kv-buffer layout ever diverges from the output.
    _kv_trans_kernel[(grid_count,)](
        input_mems,
        *output.stride(),
        input_idx,
        input_dp_idx,
        output,
        *output.stride(),
        output_idx,
        token_num=token_num,
        head_num=head_num,
        head_dim=head_dim,
        grid_count=grid_count,
        BLOCK_SIZE=BLOCK_SIZE,
        NUM_STAGES=NUM_STAGES,
        num_warps=1,
    )
    return
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
import torch
3+
import random
4+
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2
5+
6+
7+
@pytest.mark.parametrize(
    "token_num",
    list(range(5, 10)),
)
def test_kv_trans_v2(token_num):
    """Check kv_trans_v2 against a straightforward PyTorch gather."""
    dp_size_in_node = 8
    head_num = 2
    head_dim = 512
    kv_buffer_token_num = 512

    # One random fp16 kv buffer per dp rank, plus a device-side table of their
    # raw base addresses (what the kernel dereferences).
    mems = [
        torch.randn((kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
        for _ in range(dp_size_in_node)
    ]
    input_mems = torch.tensor([m.data_ptr() for m in mems], dtype=torch.uint64, device="cuda")

    # Random (source row, source rank) pair for every output token.
    input_idx = torch.tensor(
        [random.randint(0, kv_buffer_token_num - 1) for _ in range(token_num)],
        dtype=torch.int32,
        device="cuda",
    )
    input_dp_idx = torch.tensor(
        [random.randint(0, dp_size_in_node - 1) for _ in range(token_num)],
        dtype=torch.int32,
        device="cuda",
    )

    true_output = torch.zeros((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
    test_output = torch.zeros((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
    output_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")

    kv_trans_v2(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node)

    # Reference: plain indexing gather on the host side of the same buffers.
    for dest_token_index, (token_index, dp_index) in enumerate(
        zip(input_idx.cpu().numpy(), input_dp_idx.cpu().numpy())
    ):
        true_output[dest_token_index, :, :] = mems[dp_index][token_index]

    assert torch.equal(true_output, test_output)
    return
38+
39+
40+
if __name__ == "__main__":
    # Allow executing this file directly without invoking pytest from the CLI.
    pytest.main()

0 commit comments

Comments
 (0)