Default enable_prefix_caching True #3407

Closed
wants to merge 33 commits into from
Commits
750aaa8
add log
lvhan028 Mar 17, 2025
8886124
Merge branch 'main' into improve-tm-prefix-cache
lvhan028 Mar 18, 2025
7b4304a
refactor tm prefix caching
lvhan028 Mar 24, 2025
8be44f8
refactor tm prefix cache
lvhan028 Mar 25, 2025
dfdde01
Merge branch 'dev' into improve-tm-prefix-cache
lvhan028 Mar 25, 2025
fda1e25
fix linting
lvhan028 Mar 25, 2025
a4ffe41
fix linting
lvhan028 Mar 25, 2025
acf4092
combine Get&Create
lvhan028 Mar 27, 2025
a2352d1
update
lvhan028 Mar 27, 2025
1e940df
clear blocks
lvhan028 Mar 27, 2025
533941d
INFO log to DEBUG log
lvhan028 Mar 28, 2025
91d1412
refactor chat.py
lvhan028 Mar 28, 2025
ce08974
unlock the unmatched blocks when id is reused
lvhan028 Mar 28, 2025
3891782
merge main
lvhan028 Mar 31, 2025
9c3ebc8
remove start_flag and end_flag from tm csrc
lvhan028 Mar 31, 2025
d41683a
update output_logits
lvhan028 Apr 1, 2025
70399b4
update
lvhan028 Apr 1, 2025
1b99728
update
lvhan028 Apr 2, 2025
c5a2962
fix api_client
lvhan028 Apr 2, 2025
499b709
remove interactive chat API
lvhan028 Apr 3, 2025
617d317
fix build error on windows platform
lvhan028 Apr 3, 2025
50e56e2
fix chat
lvhan028 Apr 3, 2025
38ea2ae
update generate.ps1
lvhan028 Apr 3, 2025
e1489a5
fix clang-format error
lvhan028 Apr 3, 2025
9d1df28
fix clang-format error
lvhan028 Apr 3, 2025
e2a0c7a
fix vlm chat error
lvhan028 Apr 4, 2025
604b101
merge main
lvhan028 Apr 4, 2025
5e34425
fix get_logits
lvhan028 Apr 4, 2025
1cbdf5a
remove killing from tm csrc
lvhan028 Apr 4, 2025
afd531d
fix clang-format
lvhan028 Apr 6, 2025
14eb22a
enable_prefix_caching defaults to True
lvhan028 Apr 7, 2025
7e13a18
merge pt chat.py and tm chat.py
lvhan028 Apr 8, 2025
22cf302
remove pt chat.py and tm chat.py
lvhan028 Apr 8, 2025
Files changed
470 changes: 0 additions & 470 deletions autotest/interface/restful/test_restful_chat_func.py

Large diffs are not rendered by default.

50 changes: 0 additions & 50 deletions autotest/utils/run_restful_chat.py
@@ -1,7 +1,5 @@
import json
import os
import random
import string
import subprocess
from time import sleep, time

@@ -145,13 +143,6 @@ def run_all_step(config, cases_info, worker_id: str = '', port: int = DEFAULT_PO
with assume:
assert restful_result, msg

with allure.step(case + ' step3 - restful_test - interactive chat'):
active_result, interactive_log, msg = interactive_test(config, case, case_info, model, http_url, worker_id)
allure.attach.file(interactive_log, attachment_type=allure.attachment_type.TEXT)

with assume:
assert active_result, msg


def open_chat_test(config, case, case_info, model, url, worker_id: str = ''):
log_path = config.get('log_path')
@@ -190,47 +181,6 @@ def open_chat_test(config, case, case_info, model, url, worker_id: str = ''):
return result, restful_log, msg


def interactive_test(config, case, case_info, model, url, worker_id: str = ''):
log_path = config.get('log_path')

interactive_log = os.path.join(log_path, 'interactive_' + model + worker_id + '_' + case + '.log')

file = open(interactive_log, 'w')

result = True

api_client = APIClient(url)
file.writelines('available_models:' + ','.join(api_client.available_models) + '\n')

# Randomly generate 6 characters and concatenate them into a string.
characters = string.digits
random_chars = ''.join(random.choice(characters) for i in range(6))

messages = []
msg = ''
for prompt_detail in case_info:
prompt = list(prompt_detail.keys())[0]
new_prompt = {'role': 'user', 'content': prompt}
messages.append(new_prompt)
file.writelines('prompt:' + prompt + '\n')

for output in api_client.chat_interactive_v1(prompt=prompt,
interactive_mode=True,
session_id=random_chars,
top_k=1,
request_output_len=256):
output_content = output.get('text')
file.writelines('output:' + output_content + '\n')

case_result, reason = assert_result(output_content, prompt_detail.values(), model)
file.writelines('result:' + str(case_result) + ',reason:' + reason + '\n')
if not case_result:
msg += reason
result = result & case_result
file.close()
return result, interactive_log, msg


def health_check(url):
try:
api_client = APIClient(url)
6 changes: 3 additions & 3 deletions benchmark/profile_generation.py
@@ -324,7 +324,7 @@ def parse_args():
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group, default=2048)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)
rope_scaling_factor_act = ArgumentHelper.rope_scaling_factor(pt_group)
dtype_act = ArgumentHelper.dtype(pt_group)

@@ -390,7 +390,7 @@ def main():
session_len=session_len,
rope_scaling_factor=args.rope_scaling_factor,
tp=args.tp,
enable_prefix_caching=args.enable_prefix_caching,
enable_prefix_caching=not args.disable_prefix_caching,
dtype=args.dtype,
)
elif args.backend == 'pytorch':
@@ -400,7 +400,7 @@ def main():
session_len=session_len,
tp=args.tp,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
enable_prefix_caching=not args.disable_prefix_caching,
dtype=args.dtype,
)
gen_config = GenerationConfig(top_k=args.top_k,
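Across the benchmark scripts in this PR (here and in `profile_pipeline_api.py` and `profile_throughput.py` below), the opt-in `enable_prefix_caching` argument is replaced by an opt-out `disable_prefix_caching` argument, and the engine configs receive `enable_prefix_caching=not args.disable_prefix_caching`. The sketch below illustrates the presumed argparse semantics; the exact flag spelling and help text registered by `ArgumentHelper.disable_prefix_caching` are assumptions, not copied from the PR:

```python
import argparse

parser = argparse.ArgumentParser()
pt_group = parser.add_argument_group('PyTorch engine arguments')
# Assumed equivalent of ArgumentHelper.disable_prefix_caching(pt_group):
# a store-true opt-out flag, so prefix caching is on unless explicitly disabled.
pt_group.add_argument('--disable-prefix-caching', action='store_true',
                      help='Disable prefix caching (enabled by default)')

args = parser.parse_args([])                            # no flag passed
print(not args.disable_prefix_caching)                  # True  -> prefix caching enabled

args = parser.parse_args(['--disable-prefix-caching'])  # opt out explicitly
print(not args.disable_prefix_caching)                  # False -> prefix caching disabled
```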
6 changes: 3 additions & 3 deletions benchmark/profile_pipeline_api.py
@@ -154,7 +154,7 @@ def parse_args():
session_len_act = ArgumentHelper.session_len(pt_group, default=4096)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)

# turbomind engine args
tb_group = parser.add_argument_group('TurboMind engine argument')
@@ -188,7 +188,7 @@ def main():
quant_policy=args.quant_policy,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters,
enable_prefix_caching=args.enable_prefix_caching,
enable_prefix_caching=not args.disable_prefix_caching,
communicator=args.communicator,
)
elif args.backend == 'pytorch':
@@ -200,7 +200,7 @@ def main():
tp=args.tp,
thread_safe=False,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
enable_prefix_caching=not args.disable_prefix_caching,
)

engine = Engine(args.model_path, engine_config, csv=args.csv)
6 changes: 3 additions & 3 deletions benchmark/profile_throughput.py
@@ -209,7 +209,7 @@ def parse_args():
session_len_act = ArgumentHelper.session_len(pt_group, default=4096)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)
quant_policy_act = ArgumentHelper.quant_policy(pt_group, default=0)
dtype_act = ArgumentHelper.dtype(pt_group)

@@ -248,7 +248,7 @@ def main():
quant_policy=args.quant_policy,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters,
enable_prefix_caching=args.enable_prefix_caching,
enable_prefix_caching=not args.disable_prefix_caching,
dtype=args.dtype,
communicator=args.communicator,
)
@@ -260,7 +260,7 @@ def main():
max_batch_size=args.concurrency,
tp=args.tp,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
enable_prefix_caching=not args.disable_prefix_caching,
quant_policy=args.quant_policy,
dtype=args.dtype,
distributed_executor_backend=args.distributed_executor_backend,
3 changes: 2 additions & 1 deletion builder/windows/generate.ps1
@@ -5,4 +5,5 @@ cmake .. -A x64 -T "v142,cuda=$env:CUDA_PATH" `
-DBUILD_MULTI_GPU=OFF `
-DCMAKE_CUDA_FLAGS="-lineinfo" `
-DUSE_NVTX=ON `
-DBUILD_TEST="$env:BUILD_TEST"
-DBUILD_TEST="$env:BUILD_TEST" `
-DCMAKE_POLICY_VERSION_MINIMUM="3.5"
22 changes: 0 additions & 22 deletions docs/en/llm/api_server.md
@@ -151,28 +151,6 @@ for item in api_client.completions_v1(model=model_name, prompt='hi'):
print(item)
```

As for `/v1/chat/interactive`, the feature is disabled by default. Please enable it by setting `interactive_mode = True`; otherwise, requests fall back to the OpenAI-compatible interface.

Keep in mind that `session_id` identifies a single sequence, and all requests belonging to the same sequence must share the same `session_id`.
For instance, in a sequence with 10 rounds of chat requests, the `session_id` in each request should be the same.

```python
from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient(f'http://{server_ip}:{server_port}')
messages = [
"hi, what's your name?",
"who developed you?",
"Tell me more about your developers",
"Summarize the information we've talked so far"
]
for message in messages:
for item in api_client.chat_interactive_v1(prompt=message,
session_id=1,
interactive_mode=True,
stream=False):
print(item)
```

### Tools

May refer to [api_server_tools](./api_server_tools.md).
22 changes: 0 additions & 22 deletions docs/zh_cn/llm/api_server.md
@@ -169,28 +169,6 @@ for item in api_client.completions_v1(model=model_name, prompt='hi'):
print(item)
```

The `/v1/chat/interactive` endpoint is disabled by default. Please enable it by setting `interactive_mode = True`; otherwise, it falls back to the OpenAI-compatible interface.

In interactive inference, each conversation sequence must have a unique id, and all requests belonging to that sequence must use the same id; this id corresponds to `session_id` in the interface.
For instance, in a sequence with 10 rounds of chat requests, the `session_id` in each request should be the same.

```python
from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient(f'http://{server_ip}:{server_port}')
messages = [
"hi, what's your name?",
"who developed you?",
"Tell me more about your developers",
"Summarize the information we've talked so far"
]
for message in messages:
for item in api_client.chat_interactive_v1(prompt=message,
session_id=1,
interactive_mode=True,
stream=False):
print(item)
```

### Tools

Refer to [api_server_tools](./api_server_tools.md).
90 changes: 90 additions & 0 deletions lmdeploy/cli/chat.py
@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import fire

from lmdeploy import ChatTemplateConfig, GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline
from lmdeploy.archs import autoget_backend


def input_prompt():
"""Input a prompt in the consolo interface."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))


def build_pipe(model_path, backend, **kwargs):
# set enable_prefix_cache
disable_prefix_cache = kwargs.pop('disable_prefix_cache', False)
kwargs.update(enable_prefix_caching=not disable_prefix_cache)
# set engine config
engine_config = None
if backend == 'turbomind':
engine_config = TurbomindEngineConfig()
for key, value in kwargs.items():
if hasattr(TurbomindEngineConfig, key):
setattr(engine_config, key, value)
else:
engine_config = PytorchEngineConfig()
for key, value in kwargs.items():
if hasattr(PytorchEngineConfig, key):
setattr(engine_config, key, value)
if kwargs.get('adapters', None):
from .utils import get_lora_adapters
adapters = get_lora_adapters(kwargs['adapters'])
engine_config.adapters = adapters
# set chat template config
chat_template = kwargs.get('chat_template', None)
chat_template_config = None
if chat_template:
chat_template_config = ChatTemplateConfig(model_name=chat_template)

pipe = pipeline(model_path, backend_config=engine_config, chat_template_config=chat_template_config, **kwargs)
return pipe


def build_gen_config(**kwargs):
gen_config = GenerationConfig(max_new_tokens=1024, top_k=40, top_p=0.8, temperature=0.8, repetition_penalty=1.0)
for key, value in kwargs.items():
if hasattr(GenerationConfig, key):
setattr(gen_config, key, value)
return gen_config


def main(model_path, backend, **kwargs):
if backend != 'pytorch':
# set auto backend mode
backend = autoget_backend(model_path)

pipe = build_pipe(model_path, backend, **kwargs)
gen_config = build_gen_config(**kwargs)

quit = False
while True:
with pipe.session(gen_config) as sess:
while True:
try:
prompt = input_prompt()
except KeyboardInterrupt:
quit = True
break
if prompt == 'end':
break
if prompt == 'exit':
quit = True
break
resps = sess(prompt)
try:
for resp in resps:
print(resp.text, end='', flush=True)
sess.messages.append(dict(role='assistant', content=resp.text))
except KeyboardInterrupt:
sess.stop()
pass
finally:
print()
if quit:
break


if __name__ == '__main__':
fire.Fire(main)
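For reference, a minimal sketch of driving these new helpers programmatically, mirroring the interactive loop in `main` above. The model path and sampling values are placeholders, not values from this PR; `disable_prefix_cache` is popped inside `build_pipe` and mapped to `enable_prefix_caching=not disable_prefix_cache`, so prefix caching stays on by default.

```python
# Sketch only: exercises build_pipe/build_gen_config the same way the new CLI does.
from lmdeploy.cli.chat import build_gen_config, build_pipe

pipe = build_pipe('/path/to/model', 'turbomind')          # placeholder model path
gen_config = build_gen_config(top_k=40, temperature=0.8)  # forwarded onto GenerationConfig

# pipe.session(...) and sess(prompt) follow the usage in main() above.
with pipe.session(gen_config) as sess:
    for resp in sess("hi, what's your name?"):
        print(resp.text, end='', flush=True)
    print()
```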
40 changes: 5 additions & 35 deletions lmdeploy/cli/cli.py
@@ -4,7 +4,7 @@
import os

from ..version import __version__
from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters
from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args


class CLI(object):
@@ -104,7 +104,7 @@ def add_parser_chat():
tp_act = ArgumentHelper.tp(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)

# turbomind args
@@ -218,39 +218,9 @@ def get_gpu_topo():
@staticmethod
def chat(args):
"""Chat with pytorch or turbomind engine."""
from lmdeploy.archs import autoget_backend

chat_template_config = get_chat_template(args.chat_template)

backend = args.backend
if backend != 'pytorch':
# set auto backend mode
backend = autoget_backend(args.model_path)

if backend == 'pytorch':
from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.chat import run_chat

adapters = get_lora_adapters(args.adapters)
engine_config = PytorchEngineConfig(dtype=args.dtype,
tp=args.tp,
session_len=args.session_len,
cache_max_entry_count=args.cache_max_entry_count,
adapters=adapters,
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device,
eager_mode=args.eager_mode,
quant_policy=args.quant_policy)
run_chat(args.model_path, engine_config, chat_template_config=chat_template_config)
else:
from lmdeploy.turbomind.chat import main as run_chat
kwargs = convert_args(args)
kwargs.pop('chat_template')
kwargs.pop('backend')
kwargs.pop('device')
kwargs.pop('eager_mode')
kwargs['chat_template_config'] = chat_template_config
run_chat(**kwargs)
from .chat import main
kwargs = convert_args(args)
main(**kwargs)

@staticmethod
def add_parsers():