Support agent reference (#546)
tastelikefeet authored Mar 13, 2024
1 parent 64ff9da commit 5386af0
Showing 9 changed files with 213 additions and 28 deletions.
5 changes: 3 additions & 2 deletions docs/source/GetStarted/界面训练推理.md
@@ -8,7 +8,8 @@ swift web-ui

web-ui takes no command-line arguments; everything configurable is controlled in the interface. A few environment variables are available, though:

> WEBUI_SHARE=1 controls whether the gradio app is shared
> WEBUI_SHARE=1/0, default 0: controls whether the gradio app is shared
> SWIFT_UI_LANG=en/zh controls the language of the web-ui interface
> WEBUI_SERVER: the server_name parameter, i.e. the web-ui host IP; 0.0.0.0 allows access from any IP, 127.0.0.1 allows access only from the local machine
> WEBUI_PORT: the port of the web-ui
> USE_INFERENCE=1/0, default 0: controls whether the gradio inference page loads the model for direct inference or uses deployment (USE_INFERENCE=0)
144 changes: 144 additions & 0 deletions docs/source/LLM/Agent微调最佳实践.md
@@ -267,6 +267,150 @@ Begin!

![image-20240201133359457](resources/image-20240201133359457.png)

### Using the Agent on the Command Line

Agent inference on the command line currently requires `--eval_human true`: when this argument is false, queries are read from the dataset, so the API call result that follows `Observation:` cannot be supplied manually.

```shell
swift infer --model_type chatglm3-6b-32k --eval_human true --stop_words Observation: --infer_backend pt
```

Once the command is running, change the system field:

```shell
# single-line system
<<< reset-system
<<< Answer the following questions as best you can. You have access to the following APIs:\n1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]\n\n2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []\n\n3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []\n\n4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []\n\nUse the following format:\n\nThought: you should always think about what to do\nAction: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\nBegin!
```

If you prefer to enter it across multiple lines, use the commands below (multi-line input is terminated with `#`):

```shell
# multi-line system
<<< multi-line#
<<<[M] reset-system#
<<<[MS] Answer the following questions as best you can. You have access to the following APIs:
1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]

2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []
3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []
4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []
Use the following format:
Thought: you should always think about what to do
Action: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!#
```

You can now carry out Agent question answering:
```shell
<<< 输入图片是/tmp/1.jpg,协助判断图片中是否存在着火点
Thought: I need to use the fire\_recognition API to analyze the input image and determine if there are any signs of fire.
Action: Use the fire\_recognition API to analyze the input image.
Action Input: /tmp/1.jpg
Observation:
<<< [{'coordinate': [101.1, 200.9], 'on_fire': True}]
Thought: The fire\_recognition API has returned a result indicating that there is fire in the input image.
Final Answer: There is fire in the input image.
```

As you can see, the model has returned an analysis of the API call result. You can continue asking questions for a multi-turn Agent session, and you can also pass `--infer_backend vllm` and `--stream true` to use vLLM with streaming inference.

### Using the Agent with Deployment

Since deployment does not manage history, stitching the Agent's API call results back into the conversation is up to the user. Below is a runnable example in the OpenAI format.

Server:
```shell
swift deploy --model_type chatglm3-6b-32k --stop_words Observation:
```

Client:
```python
from openai import OpenAI
client = OpenAI(
    api_key='EMPTY',
    base_url='http://localhost:8000/v1',
)
model_type = client.models.list().data[0].id
print(f'model_type: {model_type}')
system = """Answer the following questions as best you can. You have access to the following APIs:
1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{\"name\": \"image\", \"description\": \"The input image to recognize fire\", \"required\": \"True\"}]
2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []

3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []

4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []

Use the following format:

Thought: you should always think about what to do
Action: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!"""
messages = [{
    'role': 'system',
    'content': system
}, {
    'role': 'user',
    'content': '输入图片是/tmp/1.jpg,协助判断图片中是否存在着火点'
}]
resp = client.chat.completions.create(
    model=model_type,
    messages=messages,
    stop=['Observation:'],
    seed=42)
response = resp.choices[0].message.content
print(f'response: {response}')
# Streaming request
messages.append({'role': 'assistant', 'content': response + "\n[{'coordinate': [101.1, 200.9], 'on_fire': True}]"})
print(messages)
stream_resp = client.chat.completions.create(
    model=model_type,
    messages=messages,
    stop=['Observation:'],
    stream=True,
    seed=42)
print('response: ', end='')
for chunk in stream_resp:
    print(chunk.choices[0].delta.content, end='', flush=True)
print()
## Output:
# model_type: chatglm3-6b-32k
# response: Thought: I need to check if there is fire in the image
# Action: Use fire\_recognition API
# Action Input: /tmp/2.jpg
# Observation:
# [{'role': 'system', 'content': 'Answer the following questions as best you can. You have access to the following APIs:\n1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]\n\n2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building\'s administraters. Parameters: []\n\n3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []\n\n4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []\n\nUse the following format:\n\nThought: you should always think about what to do\nAction: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\nBegin!'}, {'role': 'user', 'content': '输入图片是/tmp/2.jpg,协助判断图片中是否存在着火点'}, {'role': 'assistant', 'content': "Thought: I need to check if there is fire in the image\nAction: Use fire\\_recognition API\nAction Input: /tmp/2.jpg\nObservation:\n[{'coordinate': [101.1, 200.9], 'on_fire': True}]"}]
# response:
# Final Answer: There is fire in the image at coordinates [101.1, 200.9]
```
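
Because the server keeps no conversation state, continuing the dialogue only requires appending each turn to `messages` on the client side. The sketch below continues the snippet above; the assistant text it appends is copied from the example output, and the follow-up question is hypothetical:

```python
# Continuation sketch (hypothetical follow-up turn), reusing `client`,
# `model_type` and `messages` from the snippet above: extend the last
# assistant message with the streamed Final Answer, append the next user
# question, then call the API again with the same stop word.
messages[-1]['content'] += '\nFinal Answer: There is fire in the image at coordinates [101.1, 200.9]'
messages.append({
    'role': 'user',
    'content': 'Please call the fire_alert API to warn the building.'
})
next_resp = client.chat.completions.create(
    model=model_type,
    messages=messages,
    stop=['Observation:'],
    seed=42)
print(next_resp.choices[0].message.content)
```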

## Summary

With the Agent training capability supported by SWIFT, we fine-tuned qwen-7b-chat on ms-agent and ms-bench. The fine-tuned model keeps its general knowledge Q&A ability and, when APIs are added to the system field, can call them correctly and complete the task. Note that:
13 changes: 9 additions & 4 deletions swift/llm/deploy.py
@@ -178,7 +178,10 @@ async def _generate_full():
for output in result.outputs:
choice = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role='assistant', content=output.text),
message=ChatMessage(
role='assistant',
content=template.tokenizer.decode(
output.token_ids, True)),
finish_reason=output.finish_reason,
)
choices.append(choice)
@@ -193,7 +196,7 @@ async def _generate_full():
for output in result.outputs:
choice = CompletionResponseChoice(
index=output.index,
text=output.text,
text=template.tokenizer.decode(output.token_ids, True),
finish_reason=output.finish_reason,
)
choices.append(choice)
@@ -219,7 +222,8 @@ async def _generate_stream():
if isinstance(request, ChatCompletionRequest):
choices = []
for output in result.outputs:
delta_text = output.text[print_idx_list[output.index]:]
text = template.tokenizer.decode(output.token_ids, True)
delta_text = text[print_idx_list[output.index]:]
print_idx_list[output.index] += len(delta_text)
choice = ChatCompletionResponseStreamChoice(
index=output.index,
@@ -236,7 +240,8 @@
else:
choices = []
for output in result.outputs:
delta_text = output.text[print_idx_list[output.index]:]
text = template.tokenizer.decode(output.token_ids, True)
delta_text = text[print_idx_list[output.index]:]
print_idx_list[output.index] += len(delta_text)
choice = CompletionResponseStreamChoice(
index=output.index,
3 changes: 3 additions & 0 deletions swift/llm/infer.py
@@ -315,6 +315,7 @@ def llm_infer(args: InferArguments) -> None:
if not template.support_multi_round:
history = []
infer_kwargs = {}

read_media_file(infer_kwargs, args.infer_media_type)
if args.infer_backend == 'vllm':
request_list = [{
@@ -340,6 +341,8 @@
new_history = resp_list[0]['history']
print(response)
else:
if args.stop_words:
infer_kwargs['stop_words'] = args.stop_words
template_info = TEMPLATE_MAPPING[args.template_type]
support_stream = template_info.get('support_stream', True)
if args.stream and support_stream:
1 change: 1 addition & 0 deletions swift/llm/utils/argument.py
@@ -607,6 +607,7 @@ class InferArguments:
top_p: float = 0.7
repetition_penalty: float = 1.
num_beams: int = 1
stop_words: List[str] = None

# other
use_flash_attn: Optional[bool] = None
8 changes: 4 additions & 4 deletions swift/llm/utils/utils.py
@@ -468,7 +468,7 @@ def inference_stream(model: PreTrainedModel,

# agent support
is_observation = history[-1][-1].endswith(
'Observation:') if history else False
'Observation:') if history and history[-1][-1] else False
if is_observation:
history[-1][-1] = history[-1][-1] + query
act_length = len(history[-1][-1])
@@ -576,7 +576,7 @@ def inference_stream(model: PreTrainedModel,
) and response[-len(template.suffix[-1]):] == template.suffix[-1]:
response = response[:-len(template.suffix[-1])]
if not is_observation:
history[-1] = (query, response)
history[-1] = [query, response]
else:
history[-1][-1] = history[-1][-1][:act_length] + response
yield response, history
@@ -607,7 +607,7 @@ def inference(model: PreTrainedModel,
history = deepcopy(history)

is_observation = history[-1][-1].endswith(
'Observation:') if history else False
'Observation:') if history and history[-1][-1] else False
if is_observation:
history[-1][-1] = history[-1][-1] + query
query = None
@@ -701,7 +701,7 @@ def inference(model: PreTrainedModel,
) and response[-len(template.suffix[-1]):] == template.suffix[-1]:
response = response[:-len(template.suffix[-1])]
if not is_observation:
history.append((query, response))
history.append([query, response])
else:
history[-1][-1] = history[-1][-1] + response
return response, history
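
For reference, the agent handling in `inference`/`inference_stream` above (and in the vLLM code below) follows one convention: when the last assistant response ends with `Observation:`, the incoming query is treated as the tool's return value and appended to that response rather than opening a new turn, which is why history entries become mutable lists. A minimal standalone sketch of that rule, with hypothetical names:

```python
from typing import List, Optional, Tuple


def merge_observation(history: List[List[str]],
                      query: str) -> Tuple[List[List[str]], Optional[str]]:
    """Hypothetical helper (not part of swift) illustrating the rule: if the
    last response ends with 'Observation:', treat `query` as the tool result
    and append it to that response; otherwise pass the query through."""
    is_observation = bool(history) and bool(history[-1][-1]) \
        and history[-1][-1].endswith('Observation:')
    if is_observation:
        # Entries must be mutable lists (not tuples) so they can be extended.
        history[-1][-1] += query
        return history, None  # the query is consumed; no new user turn starts
    return history, query


history = [['Is there fire in /tmp/1.jpg?',
            'Thought: check the image\nAction Input: /tmp/1.jpg\nObservation:']]
history, query = merge_observation(history, "[{'on_fire': True}]")
assert query is None and history[-1][-1].endswith("'on_fire': True}]")
```
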
36 changes: 31 additions & 5 deletions swift/llm/utils/vllm_utils.py
@@ -6,7 +6,7 @@

import torch
import vllm
from modelscope import GenerationConfig, snapshot_download
from modelscope import GenerationConfig
from packaging import version
from torch import dtype as Dtype
from tqdm import tqdm
@@ -16,7 +16,7 @@

from swift.utils import get_logger, seed_everything
from .argument import InferArguments
from .model import MODEL_MAPPING, get_model_tokenizer
from .model import get_model_tokenizer
from .template import Template, get_template
from .utils import _is_chinese_char

@@ -210,10 +210,22 @@ def inference_stream_vllm(
str) and template.suffix[-1] not in generation_config.stop:
generation_config.stop.append(template.suffix[-1])

request_temp = []
for i, request in enumerate(request_list):
history = request.get('history', None)
if history is None:
history = []

# agent support
is_observation = history[-1][-1].endswith(
'Observation:') if history and history[-1][-1] else False
act_length = None
if is_observation:
history[-1][-1] = history[-1][-1] + request['query']
act_length = len(history[-1][-1])
request['query'] = None
request_temp.append((is_observation, act_length))

request['history'] = history
inputs = template.encode(request)[0]
input_ids = inputs['input_ids']
@@ -240,9 +252,13 @@
safe_response = response[:print_idx_list[i]]
query = request['query']
history = request['history']
if resp_list[i] is None:
if resp_list[i] is None and not request_temp[i][0]:
history.append(None)
history[-1] = (query, safe_response)
if not request_temp[i][0]:
history[-1] = [query, safe_response]
else:
history[-1][
-1] = history[-1][-1][:request_temp[i][1]] + safe_response
resp_list[i] = {'response': safe_response, 'history': history}
if output.finished:
prog_bar.update()
@@ -284,6 +300,12 @@ def inference_vllm(llm_engine: LLMEngine,
history = request.get('history', None)
if history is None:
history = []

is_observation = history[-1][-1].endswith(
'Observation:') if history and history[-1][-1] else False
if is_observation:
history[-1][-1] = history[-1][-1] + request['query']
request['query'] = None
request['history'] = history
inputs = template.encode(request)[0]
input_ids = inputs['input_ids']
@@ -308,7 +330,10 @@
response = tokenizer.decode(output.outputs[0].token_ids, True)
query = request['query']
history = request['history']
history.append((query, response))
if not is_observation:
history.append([query, response])
else:
history[-1][-1] = history[-1][-1] + response
resp_list[i] = {'response': response, 'history': history}
if verbose:
print(
@@ -355,6 +380,7 @@ def prepare_vllm_engine_template(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
stop=args.stop_words,
repetition_penalty=args.repetition_penalty,
num_beams=args.num_beams)
logger.info(f'generation_config: {generation_config}')