
Commit 5386af0

Support agent reference (#546)
1 parent 64ff9da commit 5386af0

9 files changed: +213 −28 lines changed


docs/source/GetStarted/界面训练推理.md

Lines changed: 3 additions & 2 deletions
@@ -8,7 +8,8 @@ swift web-ui
 
 web-ui takes no command-line arguments; everything that can be controlled is in the interface. However, a few environment variables are available:
 
-> WEBUI_SHARE=1 controls whether gradio runs in share mode
+> WEBUI_SHARE=1/0, default 0; controls whether gradio runs in share mode
 > SWIFT_UI_LANG=en/zh controls the web-ui interface language
-> WEBUI_SERVER the server_name parameter, i.e. the web-ui host ip; 0.0.0.0 means all ips can access, 127.0.0.1 allows only local access
+> WEBUI_SERVER the server_name parameter, i.e. the web-ui host ip; 0.0.0.0 means all ips can access, 127.0.0.1 allows only local access
 > WEBUI_PORT the port number of the web-ui
+> USE_INFERENCE=1/0, default 0; controls whether the gradio inference page loads the model for direct inference or uses a deployment (USE_INFERENCE=0)
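
For reference, a minimal sketch (an assumption, not part of this commit) of launching the web-ui from Python with these variables set; the port value and the `subprocess` invocation of `swift web-ui` are illustrative only:

```python
# Minimal sketch (assumption): set the documented environment variables, then launch the web-ui.
import os
import subprocess

os.environ['WEBUI_SHARE'] = '0'           # no public gradio share link
os.environ['SWIFT_UI_LANG'] = 'en'        # English interface
os.environ['WEBUI_SERVER'] = '127.0.0.1'  # local access only
os.environ['WEBUI_PORT'] = '7860'         # hypothetical port choice
os.environ['USE_INFERENCE'] = '1'         # load the model directly on the inference page

subprocess.run(['swift', 'web-ui'], check=True)
```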

docs/source/LLM/Agent微调最佳实践.md

Lines changed: 144 additions & 0 deletions
@@ -267,6 +267,150 @@ Begin!
 
 ![image-20240201133359457](resources/image-20240201133359457.png)
 
+### Using the Agent on the command line
+
+Agent inference on the command line currently requires specifying `--eval_human true`, because when this argument is false the dataset content is read instead, and the API call result that follows `Observation:` cannot be passed in manually.
+
+```shell
+swift infer --model_type chatglm3-6b-32k --eval_human true --stop_words Observation: --infer_backend pt
+```
+
+After running the command, change the system field:
+
+```shell
+# single-line system
+<<< reset-system
+<<< Answer the following questions as best you can. You have access to the following APIs:\n1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]\n\n2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []\n\n3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []\n\n4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []\n\nUse the following format:\n\nThought: you should always think about what to do\nAction: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\nBegin!
+```
+
+If you need to enter the input across multiple lines, use the following commands (multi-line input ends with a # character):
+
+```shell
+# multi-line system
+<<< multi-line#
+<<<[M] reset-system#
+<<<[MS] Answer the following questions as best you can. You have access to the following APIs:
+1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]
+
+2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []
+
+3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []
+
+4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []
+
+Use the following format:
+
+Thought: you should always think about what to do
+Action: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+Begin!#
+```
+
+Now you can carry out agent Q&A:
+
+```shell
+<<< 输入图片是/tmp/1.jpg,协助判断图片中是否存在着火点
+Thought: I need to use the fire\_recognition API to analyze the input image and determine if there are any signs of fire.
+
+Action: Use the fire\_recognition API to analyze the input image.
+
+Action Input: /tmp/1.jpg
+
+Observation:
+<<< [{'coordinate': [101.1, 200.9], 'on_fire': True}]
+Thought: The fire\_recognition API has returned a result indicating that there is fire in the input image.
+
+Final Answer: There is fire in the input image.
+```
+
+As you can see, the model has returned an analysis of the API call result. The user can continue asking questions in a multi-turn agent scenario. You can also specify `--infer_backend vllm` and `--stream true` to use vllm and streaming inference.
+
+### Using the Agent in deployment
+
+Since deployment does not support history management, the user needs to splice the agent's API call results into the messages themselves. A runnable code example in OpenAI format is given below.
+
+Server side:
+
+```shell
+swift deploy --model_type chatglm3-6b-32k --stop_words Observation:
+```
+
+Client side:
+
+```python
+from openai import OpenAI
+client = OpenAI(
+    api_key='EMPTY',
+    base_url='http://localhost:8000/v1',
+)
+model_type = client.models.list().data[0].id
+print(f'model_type: {model_type}')
+
+system = """Answer the following questions as best you can. You have access to the following APIs:
+1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{\"name\": \"image\", \"description\": \"The input image to recognize fire\", \"required\": \"True\"}]
+
+2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []
+
+3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []
+
+4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []
+
+Use the following format:
+
+Thought: you should always think about what to do
+Action: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+Begin!"""
+messages = [{
+    'role': 'system',
+    'content': system
+}, {
+    'role': 'user',
+    'content': '输入图片是/tmp/1.jpg,协助判断图片中是否存在着火点'
+}]
+resp = client.chat.completions.create(
+    model=model_type,
+    messages=messages,
+    stop=['Observation:'],
+    seed=42)
+response = resp.choices[0].message.content
+print(f'response: {response}')
+
+# streaming
+messages.append({'role': 'assistant', 'content': response + "\n[{'coordinate': [101.1, 200.9], 'on_fire': True}]"})
+print(messages)
+stream_resp = client.chat.completions.create(
+    model=model_type,
+    messages=messages,
+    stop=['Observation:'],
+    stream=True,
+    seed=42)
+
+print('response: ', end='')
+for chunk in stream_resp:
+    print(chunk.choices[0].delta.content, end='', flush=True)
+print()
+## Output:
+# model_type: chatglm3-6b-32k
+# response: Thought: I need to check if there is fire in the image
+# Action: Use fire\_recognition API
+# Action Input: /tmp/2.jpg
+# Observation:
+# [{'role': 'system', 'content': 'Answer the following questions as best you can. You have access to the following APIs:\n1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]\n\n2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building\'s administraters. Parameters: []\n\n3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []\n\n4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []\n\nUse the following format:\n\nThought: you should always think about what to do\nAction: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\nBegin!'}, {'role': 'user', 'content': '输入图片是/tmp/2.jpg,协助判断图片中是否存在着火点'}, {'role': 'assistant', 'content': "Thought: I need to check if there is fire in the image\nAction: Use fire\\_recognition API\nAction Input: /tmp/2.jpg\nObservation:\n[{'coordinate': [101.1, 200.9], 'on_fire': True}]"}]
+# response:
+# Final Answer: There is fire in the image at coordinates [101.1, 200.9]
+```
+
+
+
 ## Summary
 
 With the agent training capability supported by SWIFT, we fine-tuned the qwen-7b-chat model using ms-agent and ms-bench. As you can see, the fine-tuned model retains its general knowledge Q&A ability and, when APIs are added to the system field, can call them correctly and complete the task. Note that:

swift/llm/deploy.py

Lines changed: 9 additions & 4 deletions
@@ -178,7 +178,10 @@ async def _generate_full():
         for output in result.outputs:
             choice = ChatCompletionResponseChoice(
                 index=output.index,
-                message=ChatMessage(role='assistant', content=output.text),
+                message=ChatMessage(
+                    role='assistant',
+                    content=template.tokenizer.decode(
+                        output.token_ids, True)),
                 finish_reason=output.finish_reason,
             )
             choices.append(choice)
@@ -193,7 +196,7 @@ async def _generate_full():
         for output in result.outputs:
             choice = CompletionResponseChoice(
                 index=output.index,
-                text=output.text,
+                text=template.tokenizer.decode(output.token_ids, True),
                 finish_reason=output.finish_reason,
             )
             choices.append(choice)
@@ -219,7 +222,8 @@ async def _generate_stream():
         if isinstance(request, ChatCompletionRequest):
             choices = []
             for output in result.outputs:
-                delta_text = output.text[print_idx_list[output.index]:]
+                text = template.tokenizer.decode(output.token_ids, True)
+                delta_text = text[print_idx_list[output.index]:]
                 print_idx_list[output.index] += len(delta_text)
                 choice = ChatCompletionResponseStreamChoice(
                     index=output.index,
@@ -236,7 +240,8 @@ async def _generate_stream():
         else:
             choices = []
             for output in result.outputs:
-                delta_text = output.text[print_idx_list[output.index]:]
+                text = template.tokenizer.decode(output.token_ids, True)
+                delta_text = text[print_idx_list[output.index]:]
                 print_idx_list[output.index] += len(delta_text)
                 choice = CompletionResponseStreamChoice(
                     index=output.index,
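
As context for the change above, here is a small illustrative sketch (not from this repository; the 'gpt2' tokenizer is only a stand-in) of the decode call the new code relies on, where the second positional argument of `tokenizer.decode` is `skip_special_tokens`:

```python
# Illustrative sketch (assumption): re-decode generated token ids rather than reusing a
# pre-rendered text field, mirroring template.tokenizer.decode(output.token_ids, True) above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # hypothetical tokenizer choice
token_ids = tokenizer.encode('Thought: check the image\nObservation:')

# Second positional argument is skip_special_tokens.
text = tokenizer.decode(token_ids, skip_special_tokens=True)
print(text)
```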

swift/llm/infer.py

Lines changed: 3 additions & 0 deletions
@@ -315,6 +315,7 @@ def llm_infer(args: InferArguments) -> None:
         if not template.support_multi_round:
             history = []
         infer_kwargs = {}
+
         read_media_file(infer_kwargs, args.infer_media_type)
         if args.infer_backend == 'vllm':
             request_list = [{
@@ -340,6 +341,8 @@ def llm_infer(args: InferArguments) -> None:
             new_history = resp_list[0]['history']
             print(response)
         else:
+            if args.stop_words:
+                infer_kwargs['stop_words'] = args.stop_words
             template_info = TEMPLATE_MAPPING[args.template_type]
             support_stream = template_info.get('support_stream', True)
             if args.stream and support_stream:
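
To illustrate what a stop-word list means on this path, a hedged sketch (an assumption, not the repository's implementation; `truncate_at_stop_word` is a made-up helper) of cutting generated text at the first stop word while keeping the marker itself, as the `Observation:` examples above do:

```python
# Hedged sketch (assumption): stop generation output at the first stop word,
# keeping the marker so the agent prompt format stays intact.
from typing import List, Optional

def truncate_at_stop_word(text: str, stop_words: Optional[List[str]]) -> str:
    for word in stop_words or []:
        idx = text.find(word)
        if idx != -1:
            return text[:idx + len(word)]
    return text

print(truncate_at_stop_word(
    'Action Input: /tmp/1.jpg\nObservation: (hallucinated result)',
    ['Observation:']))
# -> Action Input: /tmp/1.jpg\nObservation:
```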

swift/llm/utils/argument.py

Lines changed: 1 addition & 0 deletions
@@ -607,6 +607,7 @@ class InferArguments:
     top_p: float = 0.7
     repetition_penalty: float = 1.
     num_beams: int = 1
+    stop_words: List[str] = None
 
     # other
     use_flash_attn: Optional[bool] = None

swift/llm/utils/utils.py

Lines changed: 4 additions & 4 deletions
@@ -468,7 +468,7 @@ def inference_stream(model: PreTrainedModel,
 
     # agent support
     is_observation = history[-1][-1].endswith(
-        'Observation:') if history else False
+        'Observation:') if history and history[-1][-1] else False
     if is_observation:
         history[-1][-1] = history[-1][-1] + query
         act_length = len(history[-1][-1])
@@ -576,7 +576,7 @@ def inference_stream(model: PreTrainedModel,
     ) and response[-len(template.suffix[-1]):] == template.suffix[-1]:
         response = response[:-len(template.suffix[-1])]
     if not is_observation:
-        history[-1] = (query, response)
+        history[-1] = [query, response]
     else:
         history[-1][-1] = history[-1][-1][:act_length] + response
     yield response, history
@@ -607,7 +607,7 @@ def inference(model: PreTrainedModel,
     history = deepcopy(history)
 
     is_observation = history[-1][-1].endswith(
-        'Observation:') if history else False
+        'Observation:') if history and history[-1][-1] else False
     if is_observation:
         history[-1][-1] = history[-1][-1] + query
         query = None
@@ -701,7 +701,7 @@ def inference(model: PreTrainedModel,
     ) and response[-len(template.suffix[-1]):] == template.suffix[-1]:
         response = response[:-len(template.suffix[-1])]
     if not is_observation:
-        history.append((query, response))
+        history.append([query, response])
     else:
         history[-1][-1] = history[-1][-1] + response
     return response, history
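
A self-contained sketch (an assumption; `merge_observation` is a hypothetical helper, not part of the codebase) of the `Observation:` handling these hunks implement: when the last assistant message ends with `Observation:`, the incoming query is treated as the tool result and spliced into that message instead of opening a new turn:

```python
# Minimal sketch (assumption) of the agent "Observation:" handling above.
from typing import List, Optional, Tuple

def merge_observation(history: List[List[str]],
                      query: str) -> Tuple[Optional[str], List[List[str]], bool]:
    """Mimic the diff's behaviour: splice the tool result into the last reply."""
    is_observation = bool(history) and bool(history[-1][-1]) \
        and history[-1][-1].endswith('Observation:')
    if is_observation:
        history[-1][-1] = history[-1][-1] + query  # append tool result to last reply
        query = None                               # no new user turn is encoded
    return query, history, is_observation

history = [['check /tmp/1.jpg for fire',
            'Thought: call fire_recognition\nAction Input: /tmp/1.jpg\nObservation:']]
query, history, is_obs = merge_observation(history, " [{'on_fire': True}]")
print(is_obs)           # True
print(history[-1][-1])  # ...Observation: [{'on_fire': True}]
```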

swift/llm/utils/vllm_utils.py

Lines changed: 31 additions & 5 deletions
@@ -6,7 +6,7 @@
 
 import torch
 import vllm
-from modelscope import GenerationConfig, snapshot_download
+from modelscope import GenerationConfig
 from packaging import version
 from torch import dtype as Dtype
 from tqdm import tqdm
@@ -16,7 +16,7 @@
 
 from swift.utils import get_logger, seed_everything
 from .argument import InferArguments
-from .model import MODEL_MAPPING, get_model_tokenizer
+from .model import get_model_tokenizer
 from .template import Template, get_template
 from .utils import _is_chinese_char
 
@@ -210,10 +210,22 @@ def inference_stream_vllm(
             str) and template.suffix[-1] not in generation_config.stop:
         generation_config.stop.append(template.suffix[-1])
 
+    request_temp = []
     for i, request in enumerate(request_list):
         history = request.get('history', None)
         if history is None:
             history = []
+
+        # agent support
+        is_observation = history[-1][-1].endswith(
+            'Observation:') if history and history[-1][-1] else False
+        act_length = None
+        if is_observation:
+            history[-1][-1] = history[-1][-1] + request['query']
+            act_length = len(history[-1][-1])
+            request['query'] = None
+        request_temp.append((is_observation, act_length))
+
         request['history'] = history
         inputs = template.encode(request)[0]
         input_ids = inputs['input_ids']
@@ -240,9 +252,13 @@ def inference_stream_vllm(
             safe_response = response[:print_idx_list[i]]
             query = request['query']
             history = request['history']
-            if resp_list[i] is None:
+            if resp_list[i] is None and not request_temp[i][0]:
                 history.append(None)
-            history[-1] = (query, safe_response)
+            if not request_temp[i][0]:
+                history[-1] = [query, safe_response]
+            else:
+                history[-1][
+                    -1] = history[-1][-1][:request_temp[i][1]] + safe_response
             resp_list[i] = {'response': safe_response, 'history': history}
             if output.finished:
                 prog_bar.update()
@@ -284,6 +300,12 @@ def inference_vllm(llm_engine: LLMEngine,
         history = request.get('history', None)
         if history is None:
             history = []
+
+        is_observation = history[-1][-1].endswith(
+            'Observation:') if history and history[-1][-1] else False
+        if is_observation:
+            history[-1][-1] = history[-1][-1] + request['query']
+            request['query'] = None
         request['history'] = history
         inputs = template.encode(request)[0]
         input_ids = inputs['input_ids']
@@ -308,7 +330,10 @@ def inference_vllm(llm_engine: LLMEngine,
         response = tokenizer.decode(output.outputs[0].token_ids, True)
         query = request['query']
         history = request['history']
-        history.append((query, response))
+        if not is_observation:
+            history.append([query, response])
+        else:
+            history[-1][-1] = history[-1][-1] + response
         resp_list[i] = {'response': response, 'history': history}
         if verbose:
             print(
@@ -355,6 +380,7 @@ def prepare_vllm_engine_template(
         temperature=args.temperature,
         top_k=args.top_k,
         top_p=args.top_p,
+        stop=args.stop_words,
        repetition_penalty=args.repetition_penalty,
         num_beams=args.num_beams)
     logger.info(f'generation_config: {generation_config}')
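
For reference, a brief sketch (an assumption, not from this commit) of the equivalent stop-word setting on vLLM's own `SamplingParams`; the commit wires `args.stop_words` into swift's generation config, which plays the same role for the vllm backend:

```python
# Brief sketch (assumption): vLLM's SamplingParams accepts a list of stop strings.
from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature=0.3,
    top_p=0.7,
    stop=['Observation:'],  # halt generation when the tool-call marker appears
)
print(sampling_params)
```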
