Commit 8ee76c4

Author: yujinbiao
refactor hicache
1 parent 702b5ac commit 8ee76c4

File tree: 4 files changed (+105 / -115 lines)

Lines changed: 104 additions & 93 deletions
@@ -1,117 +1,128 @@
 import torch
+import time
+import tempfile
+import numpy as np
 import torch.distributed as dist
+from os.path import join
 from .radix_cache import RadixCache, TreeNode, match
 from typing import Tuple, Dict, Set, List
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.utils.log_utils import init_logger
 from threading import Lock
 from enum import Enum
+from .shared_arr import SharedArray
 from kvcache.python.jit import PyLocalCacheService
-import time

 logger = init_logger(__name__)

+def wait_until_ready(task, timeout=10.0, check_interval=0.01):
+    start_time = time.time()
+    while not task.ready():
+        time.sleep(check_interval)
+        if time.time() - start_time > timeout:
+            logger.error("Current kv cache task not ready in time")
+            return False
+    return True

-class HiRadixCache(RadixCache):
-    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length):
-        super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
-        logger.info("Initializing HiRadixCache")
-        self.rank_in_node = rank_in_node
-        try:
-            # TODO: determine by model type && dp, tp
-            store_once = True  # Deepseek -> True, Llama -> False
-            self.do_store = store_once and self.rank_in_node == 0
-            self.is_hi_radix_cache = True
-            all_buffers = self.mem_manager.kv_buffer
-            all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)
-            self.py_cache_service = (
-                PyLocalCacheService(
-                    file="cache/cache_file",
-                    storage_size=128 * (1024 ** 3),
-                    num_shard=32,
-                    kvcache_tensor=all_buffers,
-                    num_worker=32,
-                )
-                if self.do_store
-                else None
-            )
-            self.working_tasks = {}
-        except Exception as e:
-            logger.error(f"error alloc hi cache buffer {e}, fallback to normal radix cache")
-            self.hi_cache_kv_buffer = None
-            self.is_hi_radix_cache = False
+class LocalCacheManager:

-    def insert_disk(self, req_id, key, value):
-        if not self.do_store:
-            return
-        if req_id in self.working_tasks:
-            self.abort_req_store_task(req_id)
-        self.working_tasks[req_id] = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w")
-        logger.info(f"Created store task for req {req_id}.")
+    def __init__(self, unique_name: str, rank_in_node: int, mem_manager):
+        tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}")
+        self.cache_file = join(tmp_dir, "cache_file")
+        all_buffers = mem_manager.kv_buffer
+        all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)

-    def abort_req_store_task(self, req_id):
-        if not self.do_store or req_id not in self.working_tasks:
-            return
-        if self.working_tasks[req_id].ready():
-            logger.info(f"Calling abort for req {req_id}, but is finished.")
-            return
-        logger.info(f"Aborting req {req_id} unfinished.")
-        self.py_cache_service.az5(self.working_tasks[req_id])
+        self.py_cache_service = PyLocalCacheService(
+            file=self.cache_file,
+            storage_size=128 * (1024 ** 3),  # 128GB
+            num_shard=32,
+            kvcache_tensor=all_buffers,
+            num_worker=8
+        )

-    def match_prefix(self, key, update_refs=False):
-        assert len(key) != 0
-        ans_value_list = []
-        pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node)
-        if self.do_store:
-            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
-            max_len = self._query_hi_cache(key)  # x64
-            logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
-            pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0
-        dist.broadcast(pull_hi_cache_tensor, src=0)
-        pull_hi_cache = False
+    def insert(self, tokens, kv_page_indexer, start_pos=0):
+        t = self.py_cache_service.create(
+            tokens=tokens,
+            kv_page_indexer=kv_page_indexer,
+            mode="w",
+            start_pos=start_pos)
+        res = wait_until_ready(t)
+        if not res:
+            self.py_cache_service.az5(t)

-        if pull_hi_cache_tensor[0] == 0:
-            ans_value_list = []
-            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-        elif pull_hi_cache_tensor[0] > 0:
-            pull_hi_cache = True
-            max_len = pull_hi_cache_tensor[0]
-            try:
-                self.free_radix_cache_to_get_enough_token(max_len)
-            except:
-                logger.info(f"Unable to free on rank {self.rank_in_node}")
-                pull_hi_cache_tensor[0] = 0
-                pull_hi_cache = False
-                ans_value_list = []
-                tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-        if pull_hi_cache:
-            buffers = self.mem_manager.alloc(max_len)
-            if self.do_store:
-                read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
-                while not read_task.ready():
-                    time.sleep(0.05)
-            dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0)
-            logger.info(f"HiCache pulled one cache with len = {max_len}")
-            self._insert_helper(self.root_node, key, buffers)
-            ans_value_list = []
-            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-        if tree_node != self.root_node:
-            if len(ans_value_list) != 0:
-                value = torch.concat(ans_value_list)
-            else:
-                assert False, "can not run to here"
-            return tree_node, len(value), value
-        else:
-            self.dec_node_ref_counter(self.root_node)
-            return None, 0, None
+    def read(self, tokens, kv_page_indexer, start_pos=0):
+        t = self.py_cache_service.create(
+            tokens=tokens,
+            kv_page_indexer=kv_page_indexer,
+            mode="r",
+            start_pos=start_pos)
+        res = wait_until_ready(t)
+        return res

-    def _query_hi_cache(self, key) -> bool:
-        query_result = self.py_cache_service.query(key)
-        # query_result is a list of bool, find out the max len true continuous from start
+    def query(self, tokens):
+        query_result = self.py_cache_service.query(tokens)
         max_len = 0
         for result in query_result:
             if result:
                 max_len += 1
             else:
                 break
-        return max_len * self.py_cache_service.tokens_per_block
+        return max_len * self.block_size
+
+    @property
+    def block_size(self,):
+        return self.py_cache_service.tokens_per_block
+
+class HiRadixCache(RadixCache):
+    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager):
+        super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
+        self.rank_in_node = rank_in_node
+        self.local_cache_manager = LocalCacheManager(
+            unique_name,
+            rank_in_node,
+            mem_manager,
+        )
+        self.is_hi_radix_cache = True
+        self.disk_cache_match_count = SharedArray(f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64)
+        self.disk_cache_match_count.arr[0] = 0
+        self.total_match_count = SharedArray(f"{unique_name}_total_match_count_{rank_in_node}", (1,), dtype=np.int64)
+        self.total_match_count.arr[0] = 0
+        self.disk_cache_match_ratio = SharedArray(f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32)
+        self.disk_cache_match_ratio.arr[0] = 0.0
+        logger.info(f"Initializing HiRadixCache {rank_in_node}")
+
+    def insert(self, key, value=None):
+        share_len = super().insert(key, value)
+        if share_len == 0:
+            return 0
+        self.local_cache_manager.insert(key, value)
+        return share_len
+
+    def match_prefix(self, key, update_refs=False):
+        assert len(key) != 0
+        self.total_match_count.arr[0] += 1
+        ans_value_list = []
+        ans_value = None
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
+        if tree_node.node_prefix_total_len != 0:
+            ans_value = torch.concat(ans_value_list)
+        max_len = 0
+        if tree_node.node_prefix_total_len < len(key):
+            max_len = self.local_cache_manager.query(key)
+        if max_len > tree_node.node_prefix_total_len:
+            pull_len = max_len - tree_node.node_prefix_total_len
+            self.disk_cache_match_count.arr[0] += 1
+            self.disk_cache_match_ratio.arr[0] = self.disk_cache_match_count.arr[0] / self.total_match_count.arr[0]
+            self.free_radix_cache_to_get_enough_token(pull_len)
+            buffers = self.mem_manager.alloc(pull_len)
+            start_pos = 0
+            if ans_value is not None:
+                buffers = torch.concat([ans_value, buffers])
+                start_pos = (tree_node.node_prefix_total_len - 1) // self.local_cache_manager.block_size * self.local_cache_manager.block_size
+            logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk")
+            res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos)
+            if res:
+                super().insert(key[:max_len], buffers)
+            else:
+                self.mem_manager.free(buffers[tree_node.node_prefix_total_len:])
+        return super().match_prefix(key, update_refs=update_refs)
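
The subtle part of the new match_prefix is the offset arithmetic: when part of the prefix is already resident in the GPU radix tree, only the remainder is pulled from disk, and the read starts at the last block boundary at or below the GPU-matched length. A minimal standalone sketch of that arithmetic follows (not part of this commit; the helper name plan_disk_pull and the example numbers are illustrative):

def plan_disk_pull(gpu_matched_len: int, disk_matched_len: int, block_size: int):
    """Return (pull_len, start_pos) for a disk read that tops up a partial GPU match.

    pull_len  -- how many new KV slots must be allocated on the GPU
    start_pos -- block-aligned token offset where the disk read begins, mirroring
                 (gpu_matched_len - 1) // block_size * block_size from match_prefix
    """
    if disk_matched_len <= gpu_matched_len:
        return 0, 0  # nothing extra on disk, no pull needed
    pull_len = disk_matched_len - gpu_matched_len
    start_pos = 0
    if gpu_matched_len > 0:
        # re-read from the start of the block containing the last GPU-matched token
        start_pos = (gpu_matched_len - 1) // block_size * block_size
    return pull_len, start_pos

# Example: 100 tokens matched on GPU, 256 found on disk, 64-token blocks
# -> allocate 156 new slots and start reading at token offset 64.
print(plan_disk_pull(100, 256, 64))  # (156, 64)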

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 0 additions & 2 deletions
@@ -109,8 +109,6 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis
                     self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
                     req.shared_kv_node = None

-        if self.radix_cache.is_hi_radix_cache:
-            self.radix_cache.abort_req_store_task(req.req_id)

    def _save_promptcache_kvbuffer(self):
        """

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 1 addition & 19 deletions
@@ -121,8 +121,7 @@ def init_model(self, kvargs):
                 get_unique_server_name(),
                 self.model.mem_manager.size,
                 self.rank_in_node,
-                mem_manager=self.model.mem_manager,
-                max_seq_length=kvargs.get("max_seq_length", 1024 * 5),
+                mem_manager=self.model.mem_manager
             )
             if self.use_dynamic_prompt_cache and self.use_hi_dynamic_prompt_cache
             else RadixCache(
@@ -347,23 +346,6 @@ def _overlap_req_init_and_filter(

        return

-    def _overlap_store_prefill_reqs(self, run_reqs: List[InferReq]):
-        if run_reqs:
-            with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-                if self.use_hi_dynamic_prompt_cache and self.radix_cache is not None:
-                    for req in run_reqs:
-                        if req.cur_output_len > 1:
-                            continue
-                        key = torch.tensor(
-                            req.get_input_token_ids()[0 : req.cur_kv_len], dtype=torch.int64, device="cpu"
-                        )
-                        value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()
-                        self.radix_cache.insert_disk(req.req_id, key, value)
-
-            torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
-
-        return
-
     # Some reusable general-purpose helper functions
     def _post_init_reqs(self, uninit_reqs: List[InferReq]):
         """

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 0 additions & 1 deletion
@@ -43,7 +43,6 @@ def decode(self):
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
             )
-            self._overlap_store_prefill_reqs(run_reqs=run_reqs)
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
