|
| 1 | +import torch |
| 2 | + |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | + |
| 6 | + |
| 7 | +@triton.jit |
| 8 | +def _kv_trans_kernel( |
| 9 | + input_mems_ptr, |
| 10 | + input_stride_0, |
| 11 | + input_stride_1, |
| 12 | + input_stride_2, |
| 13 | + input_token_idx_ptr, |
| 14 | + input_dp_idx_ptr, |
| 15 | + output_ptr, |
| 16 | + output_stride_0, |
| 17 | + output_stride_1, |
| 18 | + output_stride_2, |
| 19 | + output_token_idx_ptr, |
| 20 | + token_num: int, |
| 21 | + head_num: int, |
| 22 | + head_dim: int, |
| 23 | + grid_count: int, |
| 24 | + BLOCK_SIZE: tl.constexpr, |
| 25 | + NUM_STAGES: tl.constexpr, |
| 26 | +): |
| 27 | + input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) |
| 28 | + input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) |
| 29 | + output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64) |
| 30 | + output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64) |
| 31 | + |
| 32 | + head_num_dim = head_num * head_dim |
| 33 | + tid = tl.program_id(0) |
| 34 | + |
| 35 | + offs = tl.arange(0, BLOCK_SIZE) |
| 36 | + while tid < token_num: |
| 37 | + dp_index = tl.load(input_dp_idx_ptr + tid) |
| 38 | + input_token_idx = tl.load(input_token_idx_ptr + tid) |
| 39 | + output_token_idx = tl.load(output_token_idx_ptr + tid) |
| 40 | + input_ptr = tl.load(input_mems_ptr + dp_index).to(tl.pointer_type(output_ptr.dtype.element_ty)) |
| 41 | + for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES): |
| 42 | + cur_offs = block_idx * BLOCK_SIZE + offs |
| 43 | + in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim) |
| 44 | + tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim) |
| 45 | + |
| 46 | + tid += grid_count |
| 47 | + |
| 48 | + return |
| 49 | + |
| 50 | + |
| 51 | +def kv_trans_v2( |
| 52 | + input_mems: torch.Tensor, |
| 53 | + input_idx: torch.Tensor, |
| 54 | + input_dp_idx: torch.Tensor, |
| 55 | + output: torch.Tensor, |
| 56 | + output_idx: torch.Tensor, |
| 57 | + dp_size_in_node: int, |
| 58 | +): |
| 59 | + """ |
| 60 | + input_memes 是一个 torch.uint64 的tensor, 其内部存储了当前使用的对应的mem_manager对象中kv cache的首指针。 |
| 61 | + """ |
| 62 | + assert input_mems.is_contiguous() |
| 63 | + assert output.is_contiguous() |
| 64 | + assert len(input_mems.shape) == 1 |
| 65 | + assert len(input_mems) == dp_size_in_node |
| 66 | + assert len(output.shape) == 3 |
| 67 | + assert len(input_idx) == len(output_idx) |
| 68 | + assert len(input_idx) == len(input_dp_idx) |
| 69 | + |
| 70 | + _, head_num, head_dim = output.shape |
| 71 | + token_num = len(input_idx) |
| 72 | + # 用较少的资源来做数据传输,防止占用过多的 sm 计算单元 |
| 73 | + grid_count = 20 |
| 74 | + BLOCK_SIZE = 256 |
| 75 | + NUM_STAGES = 3 |
| 76 | + grid = (grid_count,) |
| 77 | + |
| 78 | + _kv_trans_kernel[grid]( |
| 79 | + input_mems, |
| 80 | + *output.stride(), |
| 81 | + input_idx, |
| 82 | + input_dp_idx, |
| 83 | + output, |
| 84 | + *output.stride(), |
| 85 | + output_idx, |
| 86 | + token_num=token_num, |
| 87 | + head_num=head_num, |
| 88 | + head_dim=head_dim, |
| 89 | + grid_count=grid_count, |
| 90 | + BLOCK_SIZE=BLOCK_SIZE, |
| 91 | + NUM_STAGES=NUM_STAGES, |
| 92 | + num_warps=1, |
| 93 | + ) |
| 94 | + return |
0 commit comments