fix: Qwen model not adapted correctly & fix OpenAI streaming output in synchronous mode #287

Closed
wants to merge 3 commits
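In short, this PR makes two changes to lagent/llms/openai.py: the synchronous request in _stream_chat is now sent with stream=True so SSE chunks are received incrementally, and Qwen models are routed through the same OpenAI-format request branch as GPT instead of the DashScope-native payload. Below is a minimal sketch of why stream=True matters, assuming the server speaks OpenAI-style SSE; the endpoint, key, and model name are placeholders, and the streaming() helper in this file may parse lines somewhat differently:

    import json
    import requests

    url = 'https://api.openai.com/v1/chat/completions'   # any OpenAI-compatible endpoint
    header = {'Authorization': 'Bearer YOUR_API_KEY', 'Content-Type': 'application/json'}
    data = {
        'model': 'gpt-3.5-turbo',
        'messages': [{'role': 'user', 'content': 'Hi'}],
        'stream': True,  # ask the server to send SSE chunks
    }

    # Without stream=True, requests buffers the whole SSE body before returning, so a
    # "streaming" generator only starts yielding after the full completion is finished.
    raw_response = requests.post(url, headers=header, data=json.dumps(data), stream=True)
    for line in raw_response.iter_lines(decode_unicode=True):
        if not line or not line.startswith('data:'):
            continue
        decoded = line[len('data:'):].strip()
        if decoded == '[DONE]':
            break
        chunk = json.loads(decoded)
        print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)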
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ __pycache__/

# Distribution / packaging
.Python
.idea
build/
develop-eggs/
dist/
193 changes: 103 additions & 90 deletions lagent/llms/openai.py
@@ -48,21 +48,21 @@ class GPTAPI(BaseAPILLM):
is_api: bool = True

def __init__(
self,
model_type: str = 'gpt-3.5-turbo',
retry: int = 2,
json_mode: bool = False,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = [
dict(role='system', api_role='system'),
dict(role='user', api_role='user'),
dict(role='assistant', api_role='assistant'),
dict(role='environment', api_role='system'),
],
api_base: str = OPENAI_API_BASE,
proxies: Optional[Dict] = None,
**gen_params,
self,
model_type: str = 'gpt-3.5-turbo',
retry: int = 2,
json_mode: bool = False,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = [
dict(role='system', api_role='system'),
dict(role='user', api_role='user'),
dict(role='assistant', api_role='assistant'),
dict(role='environment', api_role='system'),
],
api_base: str = OPENAI_API_BASE,
proxies: Optional[Dict] = None,
**gen_params,
):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
@@ -92,9 +92,9 @@ def __init__(
self.json_mode = json_mode

def chat(
self,
inputs: Union[List[dict], List[List[dict]]],
**gen_params,
self,
inputs: Union[List[dict], List[List[dict]]],
**gen_params,
) -> Union[str, List[str]]:
"""Generate responses given the contexts.

@@ -119,9 +119,9 @@ def chat(
return ret[0] if isinstance(inputs[0], dict) else ret

def stream_chat(
self,
inputs: List[dict],
**gen_params,
self,
inputs: List[dict],
**gen_params,
):
"""Generate responses given the contexts.

@@ -146,10 +146,11 @@ def stream_chat(
# mapping to role that openai supports
messages = self.template_parser(inputs)
for text in self._stream_chat(messages, **gen_params):
if self.model_type.lower().startswith('qwen'):
resp = text
else:
resp += text
# if self.model_type.lower().startswith('qwen'):
# resp = text
# else:
# resp += text
resp += text
if not resp:
continue
# remove stop_words
@@ -262,19 +263,19 @@ def streaming(raw_response):
if decoded[0] == ' ':
decoded = decoded[1:]
else:
print(decoded)
# print(decoded)  # uncomment for debugging
continue
try:
response = json.loads(decoded)
if 'code' in response and response['code'] == -20003:
# Context exceeds maximum length
yield ''
return
if self.model_type.lower().startswith('qwen'):
choice = response['output']['choices'][0]
yield choice['message']['content']
if choice['finish_reason'] == 'stop':
return
# if self.model_type.lower().startswith('qwen'):
# choice = response['output']['choices'][0]
# yield choice['message']['content']
# if choice['finish_reason'] == 'stop':
# return
else:
choice = response['choices'][0]
if choice['finish_reason'] == 'stop':
@@ -316,7 +317,8 @@ def streaming(raw_response):

response = dict()
try:
raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies)
# Pass `stream=True` so the response is consumed incrementally, fixing streaming chat in synchronous mode.
raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies, stream=True)
return streaming(raw_response)
except requests.ConnectionError:
errmsg = 'Got connection error ' + str(traceback.format_exc())
@@ -383,8 +385,8 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty')

# Model-specific processing
data = {}
if model_type.lower().startswith('gpt'):
# data = {}
if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
gen_params.pop('top_k')
@@ -397,14 +399,14 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
if json_mode:
data['response_format'] = {'type': 'json_object'}
elif model_type.lower().startswith('qwen'):
header['X-DashScope-SSE'] = 'enable'
gen_params.pop('skip_special_tokens', None)
gen_params.pop('session_id', None)
if 'frequency_penalty' in gen_params:
gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty')
gen_params['result_format'] = 'message'
data = {'model': model_type, 'input': {'messages': messages}, 'parameters': {**gen_params}}
# elif model_type.lower().startswith('qwen'):
# # header['X-DashScope-SSE'] = 'enable'
# gen_params.pop('skip_special_tokens', None)
# gen_params.pop('session_id', None)
# # if 'frequency_penalty' in gen_params:
# # gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty')
# # gen_params['result_format'] = 'message'
# data = {'model': model_type, 'messages': messages, 'parameters': {**gen_params}}
else:
raise NotImplementedError(f'Model type {model_type} is not supported')

@@ -453,21 +455,21 @@ class AsyncGPTAPI(AsyncBaseAPILLM):
is_api: bool = True

def __init__(
self,
model_type: str = 'gpt-3.5-turbo',
retry: int = 2,
json_mode: bool = False,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = [
dict(role='system', api_role='system'),
dict(role='user', api_role='user'),
dict(role='assistant', api_role='assistant'),
dict(role='environment', api_role='system'),
],
api_base: str = OPENAI_API_BASE,
proxies: Optional[Dict] = None,
**gen_params,
self,
model_type: str = 'gpt-3.5-turbo',
retry: int = 2,
json_mode: bool = False,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = [
dict(role='system', api_role='system'),
dict(role='user', api_role='user'),
dict(role='assistant', api_role='assistant'),
dict(role='environment', api_role='system'),
],
api_base: str = OPENAI_API_BASE,
proxies: Optional[Dict] = None,
**gen_params,
):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
@@ -497,10 +499,10 @@ def __init__(
self.json_mode = json_mode

async def chat(
self,
inputs: Union[List[dict], List[List[dict]]],
session_ids: Union[int, List[int]] = None,
**gen_params,
self,
inputs: Union[List[dict], List[List[dict]]],
session_ids: Union[int, List[int]] = None,
**gen_params,
) -> Union[str, List[str]]:
"""Generate responses given the contexts.

@@ -523,9 +525,9 @@ async def chat(
return ret[0] if isinstance(inputs[0], dict) else ret

async def stream_chat(
self,
inputs: List[dict],
**gen_params,
self,
inputs: List[dict],
**gen_params,
):
"""Generate responses given the contexts.

@@ -550,10 +552,11 @@ async def stream_chat(
# mapping to role that openai supports
messages = self.template_parser(inputs)
async for text in self._stream_chat(messages, **gen_params):
if self.model_type.lower().startswith('qwen'):
resp = text
else:
resp += text
# if self.model_type.lower().startswith('qwen'):
# resp = text
# else:
# resp += text
resp += text
if not resp:
continue
# remove stop_words
@@ -610,7 +613,8 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
try:
async with aiohttp.ClientSession() as session:
async with session.post(
self.url, headers=header, json=data, proxy=self.proxies.get('https', self.proxies.get('http'))
self.url, headers=header, json=data,
proxy=self.proxies.get('https', self.proxies.get('http'))
) as resp:
response = await resp.json()
return response['choices'][0]['message']['content'].strip()
@@ -671,24 +675,24 @@ async def streaming(raw_response):
if decoded[0] == ' ':
decoded = decoded[1:]
else:
print(decoded)
# print(decoded)  # uncomment for debugging
continue
try:
response = json.loads(decoded)
if 'code' in response and response['code'] == -20003:
# Context exceeds maximum length
yield ''
return
if self.model_type.lower().startswith('qwen'):
choice = response['output']['choices'][0]
yield choice['message']['content']
if choice['finish_reason'] == 'stop':
return
else:
choice = response['choices'][0]
if choice['finish_reason'] == 'stop':
return
yield choice['delta'].get('content', '')
# if self.model_type.lower().startswith('qwen'):
# choice = response['output']['choices'][0]
# yield choice['message']['content']
# if choice['finish_reason'] == 'stop':
# return
# else:
choice = response['choices'][0]
if choice['finish_reason'] == 'stop':
return
yield choice['delta'].get('content', '')
except Exception as exc:
msg = f'response {decoded} lead to exception of {str(exc)}'
self.logger.error(msg)
@@ -727,7 +731,8 @@ async def streaming(raw_response):
try:
async with aiohttp.ClientSession() as session:
async with session.post(
self.url, headers=header, json=data, proxy=self.proxies.get('https', self.proxies.get('http'))
self.url, headers=header, json=data,
proxy=self.proxies.get('https', self.proxies.get('http'))
) as raw_response:
async for msg in streaming(raw_response):
yield msg
@@ -798,7 +803,15 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals

# Model-specific processing
data = {}
if model_type.lower().startswith('gpt'):

# For developers: the Qwen (Tongyi Qianwen) large language models expose an API that is
# compatible with the OpenAI output format (except for Tongyi Qianwen Audio). The original
# Qwen adaptation code mistakenly used the DashScope-native format from the official docs.
# This is corrected here by handling the GPT and Qwen API calls in the same branch.

if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
gen_params.pop('top_k')
@@ -812,14 +825,14 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
if json_mode:
data['response_format'] = {'type': 'json_object'}
elif model_type.lower().startswith('qwen'):
header['X-DashScope-SSE'] = 'enable'
gen_params.pop('skip_special_tokens', None)
gen_params.pop('session_id', None)
if 'frequency_penalty' in gen_params:
gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty')
gen_params['result_format'] = 'message'
data = {'model': model_type, 'input': {'messages': messages}, 'parameters': {**gen_params}}
# elif model_type.lower().startswith('qwen'):
# # header['X-DashScope-SSE'] = 'enable'
# gen_params.pop('skip_special_tokens', None)
# gen_params.pop('session_id', None)
# if 'frequency_penalty' in gen_params:
# gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty')
# gen_params['result_format'] = 'message'
# data = {'model': model_type, 'input': {'messages': messages}, 'parameters': {**gen_params}}
elif model_type.lower().startswith('o1'):
data = {'model': model_type, 'messages': messages, 'n': 1}
else:
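With the branches merged, a Qwen model can be driven through GPTAPI exactly like a GPT model, provided api_base points at an OpenAI-compatible chat-completions endpoint. A minimal usage sketch follows; the DashScope compatible-mode URL and the qwen-turbo model name are illustrative assumptions, not taken from this PR, and should be checked against the current DashScope documentation:

    from lagent.llms.openai import GPTAPI

    llm = GPTAPI(
        model_type='qwen-turbo',  # any Qwen chat model served by the compatible endpoint
        key='YOUR_DASHSCOPE_API_KEY',
        api_base='https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions',
    )

    messages = [dict(role='user', content='Hello, Qwen!')]

    # Non-streaming call: goes through the same OpenAI-format request branch as GPT.
    print(llm.chat(messages))

    # Streaming call: now accumulates OpenAI-style delta chunks for Qwen as well,
    # because the underlying request is sent with stream=True.
    for item in llm.stream_chat(messages):
        print(item)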