import torch
+import time
+import tempfile
+import numpy as np
import torch.distributed as dist
+from os.path import join
from .radix_cache import RadixCache, TreeNode, match
from typing import Tuple, Dict, Set, List
from lightllm.common.mem_manager import MemoryManager
from lightllm.utils.log_utils import init_logger
from threading import Lock
from enum import Enum
+from .shared_arr import SharedArray
from kvcache.python.jit import PyLocalCacheService
-import time

logger = init_logger(__name__)

+def wait_until_ready(task, timeout=10.0, check_interval=0.01):
+    start_time = time.time()
+    while not task.ready():
+        time.sleep(check_interval)
+        if time.time() - start_time > timeout:
+            logger.error("Current kv cache task not ready in time")
+            return False
+    return True

-class HiRadixCache(RadixCache):
-    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length):
-        super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
-        logger.info("Initializing HiRadixCache")
-        self.rank_in_node = rank_in_node
-        try:
-            # TODO: determine by model type && dp, tp
-            store_once = True  # Deepseek -> True, Llama -> False
-            self.do_store = store_once and self.rank_in_node == 0
-            self.is_hi_radix_cache = True
-            all_buffers = self.mem_manager.kv_buffer
-            all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)
-            self.py_cache_service = (
-                PyLocalCacheService(
-                    file="cache/cache_file",
-                    storage_size=128 * (1024 ** 3),
-                    num_shard=32,
-                    kvcache_tensor=all_buffers,
-                    num_worker=32,
-                )
-                if self.do_store
-                else None
-            )
-            self.working_tasks = {}
-        except Exception as e:
-            logger.error(f"error alloc hi cache buffer {e}, fallback to normal radix cache")
-            self.hi_cache_kv_buffer = None
-            self.is_hi_radix_cache = False
+class LocalCacheManager:

-    def insert_disk(self, req_id, key, value):
-        if not self.do_store:
-            return
-        if req_id in self.working_tasks:
-            self.abort_req_store_task(req_id)
-        self.working_tasks[req_id] = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w")
-        logger.info(f"Created store task for req {req_id}.")
+    def __init__(self, unique_name: str, rank_in_node: int, mem_manager):
+        tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}")
+        self.cache_file = join(tmp_dir, "cache_file")
+        all_buffers = mem_manager.kv_buffer
+        all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)

-    def abort_req_store_task(self, req_id):
-        if not self.do_store or req_id not in self.working_tasks:
-            return
-        if self.working_tasks[req_id].ready():
-            logger.info(f"Calling abort for req {req_id}, but is finished.")
-            return
-        logger.info(f"Aborting req {req_id} unfinished.")
-        self.py_cache_service.az5(self.working_tasks[req_id])
+        self.py_cache_service = PyLocalCacheService(
+            file=self.cache_file,
+            storage_size=128 * (1024 ** 3),  # 128GB
+            num_shard=32,
+            kvcache_tensor=all_buffers,
+            num_worker=8
+        )

-    def match_prefix(self, key, update_refs=False):
-        assert len(key) != 0
-        ans_value_list = []
-        pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node)
-        if self.do_store:
-            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
-            max_len = self._query_hi_cache(key)  # x64
-            logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
-            pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0
-        dist.broadcast(pull_hi_cache_tensor, src=0)
-        pull_hi_cache = False
+    def insert(self, tokens, kv_page_indexer, start_pos=0):
+        t = self.py_cache_service.create(
+            tokens=tokens,
+            kv_page_indexer=kv_page_indexer,
+            mode="w",
+            start_pos=start_pos)
+        res = wait_until_ready(t)
+        if not res:
+            self.py_cache_service.az5(t)

-        if pull_hi_cache_tensor[0] == 0:
-            ans_value_list = []
-            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-        elif pull_hi_cache_tensor[0] > 0:
-            pull_hi_cache = True
-            max_len = pull_hi_cache_tensor[0]
-            try:
-                self.free_radix_cache_to_get_enough_token(max_len)
-            except:
-                logger.info(f"Unable to free on rank {self.rank_in_node}")
-                pull_hi_cache_tensor[0] = 0
-                pull_hi_cache = False
-                ans_value_list = []
-                tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-        if pull_hi_cache:
-            buffers = self.mem_manager.alloc(max_len)
-            if self.do_store:
-                read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
-                while not read_task.ready():
-                    time.sleep(0.05)
-            dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0)
| - logger.info(f"HiCache pulled one cache with len = {max_len}") |
-            self._insert_helper(self.root_node, key, buffers)
-            ans_value_list = []
-            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-        if tree_node != self.root_node:
-            if len(ans_value_list) != 0:
-                value = torch.concat(ans_value_list)
-            else:
-                assert False, "can not run to here"
-            return tree_node, len(value), value
-        else:
-            self.dec_node_ref_counter(self.root_node)
-            return None, 0, None
+    def read(self, tokens, kv_page_indexer, start_pos=0):
+        t = self.py_cache_service.create(
+            tokens=tokens,
+            kv_page_indexer=kv_page_indexer,
+            mode="r",
+            start_pos=start_pos)
+        res = wait_until_ready(t)
+        return res

-    def _query_hi_cache(self, key) -> bool:
-        query_result = self.py_cache_service.query(key)
-        # query_result is a list of bool, find out the max len true continuous from start
+    def query(self, tokens):
+        query_result = self.py_cache_service.query(tokens)
        max_len = 0
        for result in query_result:
            if result:
                max_len += 1
            else:
                break
-        return max_len * self.py_cache_service.tokens_per_block
+        return max_len * self.block_size
+
+    @property
+    def block_size(self,):
+        return self.py_cache_service.tokens_per_block
+
+class HiRadixCache(RadixCache):
+    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager):
+        super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
+        self.rank_in_node = rank_in_node
+        self.local_cache_manager = LocalCacheManager(
+            unique_name,
+            rank_in_node,
+            mem_manager,
+        )
+        self.is_hi_radix_cache = True
+        self.disk_cache_match_count = SharedArray(f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64)
+        self.disk_cache_match_count.arr[0] = 0
+        self.total_match_count = SharedArray(f"{unique_name}_total_match_count_{rank_in_node}", (1,), dtype=np.int64)
+        self.total_match_count.arr[0] = 0
+        self.disk_cache_match_ratio = SharedArray(f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32)
+        self.disk_cache_match_ratio.arr[0] = 0.0
+        logger.info(f"Initializing HiRadixCache {rank_in_node}")
+
+    def insert(self, key, value=None):
+        share_len = super().insert(key, value)
+        if share_len == 0:
+            return 0
+        self.local_cache_manager.insert(key, value)
+        return share_len
+
+    def match_prefix(self, key, update_refs=False):
+        assert len(key) != 0
+        self.total_match_count.arr[0] += 1
+        ans_value_list = []
+        ans_value = None
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
+        if tree_node.node_prefix_total_len != 0:
+            ans_value = torch.concat(ans_value_list)
+        max_len = 0
+        if tree_node.node_prefix_total_len < len(key):
+            max_len = self.local_cache_manager.query(key)
+        if max_len > tree_node.node_prefix_total_len:
+            pull_len = max_len - tree_node.node_prefix_total_len
+            self.disk_cache_match_count.arr[0] += 1
+            self.disk_cache_match_ratio.arr[0] = self.disk_cache_match_count.arr[0] / self.total_match_count.arr[0]
+            self.free_radix_cache_to_get_enough_token(pull_len)
+            buffers = self.mem_manager.alloc(pull_len)
+            start_pos = 0
+            if ans_value is not None:
+                buffers = torch.concat([ans_value, buffers])
+                start_pos = (tree_node.node_prefix_total_len - 1) // self.local_cache_manager.block_size * self.local_cache_manager.block_size
+            logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk")
+            res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos)
+            if res:
+                super().insert(key[:max_len], buffers)
+            else:
+                self.mem_manager.free(buffers[tree_node.node_prefix_total_len:])
+        return super().match_prefix(key, update_refs=update_refs)
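
The start_pos arithmetic in the new match_prefix rounds the already-matched GPU prefix down to a cache-block boundary, so the disk read restarts at the beginning of the partially covered block rather than in its middle. A minimal sketch of that calculation with made-up numbers (the tokens_per_block of 64 and the token counts below are illustrative assumptions, not values from the patch):

# Standalone sketch of the block-alignment math used in match_prefix above.
# tokens_per_block really comes from PyLocalCacheService; 64 is assumed here.
tokens_per_block = 64
node_prefix_total_len = 150   # tokens already matched in the GPU radix tree
max_len = 320                 # longest contiguous prefix found on disk

pull_len = max_len - node_prefix_total_len                                       # 170 tokens to allocate
start_pos = (node_prefix_total_len - 1) // tokens_per_block * tokens_per_block   # 128

# The read covers tokens [start_pos, max_len): positions 128..149 are re-read
# into the buffer slice already resident on GPU, 150..319 fill the new allocation.
assert pull_len == 170 and start_pos == 128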
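
Note that wait_until_ready, added near the top of the file, only polls a task's ready() flag and gives up after the timeout; it does not cancel the task, which is why LocalCacheManager.insert follows a failed wait with az5 (the cache service's abort call). A small self-contained check of the polling behaviour, using a hypothetical _DummyTask stand-in rather than a real PyLocalCacheService task:

import time

class _DummyTask:
    # Hypothetical stand-in for a kv cache task: becomes ready after a fixed delay.
    def __init__(self, delay_s):
        self._deadline = time.time() + delay_s

    def ready(self):
        return time.time() >= self._deadline

# Becomes ready well inside the default 10 s timeout, so the helper returns True.
assert wait_until_ready(_DummyTask(delay_s=0.05)) is True
# Never becomes ready within the shortened timeout, so the helper returns False.
assert wait_until_ready(_DummyTask(delay_s=1.0), timeout=0.1) is False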