From 2b35da2758723e0b223650fd1bf2a40db3ac370b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=82=B1=E6=A2=93=E5=92=B8?=
Date: Wed, 5 Feb 2025 17:38:17 +0800
Subject: [PATCH 1/3] Add OpenAI-style server support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lagent/llms/__init__.py             |   2 +
 lagent/llms/openai_style.py         | 831 ++++++++++++++++++++++++++++
 tests/test_llms/__init__.py         |   0
 tests/test_llms/test_gptstyleapi.py | 182 ++++
 4 files changed, 1015 insertions(+)
 create mode 100644 lagent/llms/openai_style.py
 create mode 100644 tests/test_llms/__init__.py
 create mode 100644 tests/test_llms/test_gptstyleapi.py

diff --git a/lagent/llms/__init__.py b/lagent/llms/__init__.py
index 95679b15..ab69d803 100644
--- a/lagent/llms/__init__.py
+++ b/lagent/llms/__init__.py
@@ -12,6 +12,7 @@
 )
 from .meta_template import INTERNLM2_META
 from .openai import GPTAPI, AsyncGPTAPI
+from .openai_style import GPTStyleAPI, AsyncGPTStyleAPI
 from .sensenova import SensenovaAPI
 from .vllm_wrapper import AsyncVllmModel, VllmModel
 
@@ -22,6 +23,7 @@
     'BaseAPILLM',
     'AsyncGPTAPI',
     'GPTAPI',
+    'GPTStyleAPI', 'AsyncGPTStyleAPI',
     'LMDeployClient',
     'AsyncLMDeployClient',
     'LMDeployPipeline',
diff --git a/lagent/llms/openai_style.py b/lagent/llms/openai_style.py
new file mode 100644
index 00000000..82a81670
--- /dev/null
+++ b/lagent/llms/openai_style.py
@@ -0,0 +1,831 @@
+'''
+Supports OpenAI-style servers.
+Tested and working with:
+    xinference 1.2.0
+    ollama 0.5.5
+    one-api v0.6.10-alpha.6
+    Baichuan (direct connection)
+    lmdeploy 0.7.0
+
+Usage:
+Set api_base, model_name and key (the key can be omitted if the server does not require one):
+    api_base = 'http://192.168.26.213:13000/v1/chat/completions'  # oneapi
+    model_name = "deepseek-r1:1.5b"
+    gpttool = GPTStyleAPI(
+        model_type=model_name,
+        api_base=api_base,
+        key="sk-IXgCTwuoEwxL1CiBE4744688D8094521B70f4aDeE6830c5e",
+    )
+    res = gpttool.chat(inputs=[
+        {
+            "role": "user",
+            "content": "世界第一高峰是"
+        }])
+
+See tests/test_llms/test_gptstyleapi.py for complete examples.
+'''
+import asyncio
+import json
+import os
+import time
+import traceback
+import warnings
+from concurrent.futures import ThreadPoolExecutor
+from logging import getLogger
+from threading import Lock
+from typing import AsyncGenerator, Dict, List, Optional, Union
+
+import aiohttp
+import requests
+
+from ..schema import ModelStatusCode
+from ..utils import filter_suffix
+from .base_api import AsyncBaseAPILLM, BaseAPILLM
+
+warnings.simplefilter('default')
+
+OPENAI_STYLE_API_BASE = 'http://192.168.26.213:13000/v1/chat/completions'
+
+def process_model_params(model_type, messages, gen_params, json_mode):
+    # Model-specific processing
+    data = {}
+    if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'):
+        if 'top_k' in gen_params:
+            warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
+            gen_params.pop('top_k')
+        gen_params.pop('skip_special_tokens', None)
+        gen_params.pop('session_id', None)
+
+        data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
+        if json_mode:
+            data['response_format'] = {'type': 'json_object'}
+    elif model_type.lower().startswith('internlm'):
+        data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
+        if json_mode:
+            data['response_format'] = {'type': 'json_object'}
+    elif model_type.lower().startswith('o1'):
+        data = {'model': model_type, 'messages': messages, 'n': 1}
+    else:
+        data = {'model': model_type, 'messages': messages, **gen_params}
+    return data
+
+class GPTStyleAPI(BaseAPILLM):
+    """Model
wrapper around OpenAI's models. + + Args: + model_type (str): The name of OpenAI's model. + retry (int): Number of retires if the API call fails. Defaults to 2. + key (str or List[str]): OpenAI key(s). In particular, when it + is set to "ENV", the key will be fetched from the environment + variable $OPENAI_API_KEY, as how openai defaults to be. If it's a + list, the keys will be used in round-robin manner. Defaults to + 'ENV'. + org (str or List[str], optional): OpenAI organization(s). If not + specified, OpenAI uses the default organization bound to each API + key. If specified, the orgs will be posted with each request in + round-robin manner. Defaults to None. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + api_base (str): The base url of OpenAI's API. Defaults to + 'https://api.openai.com/v1/chat/completions'. + gen_params: Default generation configuration which could be overridden + on the fly of generation. + """ + + is_api: bool = True + + def __init__( + self, + model_type: str = 'gpt-3.5-turbo', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='system'), + ], + stream: bool = False, + api_base: str = OPENAI_STYLE_API_BASE, + proxies: Optional[Dict] = None, + **gen_params, + ): + if 'top_k' in gen_params: + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) + gen_params.pop('top_k') + super().__init__(model_type=model_type, meta_template=meta_template, retry=retry, **gen_params) + self.gen_params.pop('top_k') + self.logger = getLogger(__name__) + + if isinstance(key, str): + self.keys = [os.getenv('OPENAI_STYLE_API_KEY') if key == 'ENV' else key] + else: + self.keys = key + + # record invalid keys and skip them when requesting API + # - keys have insufficient_quota + self.invalid_keys = set() + + self.key_ctr = 0 + if isinstance(org, str): + self.orgs = [org] + else: + self.orgs = org + self.org_ctr = 0 + self.url = api_base + self.model_type = model_type + self.proxies = proxies + self.json_mode = json_mode + self.stream = stream + + def chat( + self, + inputs: Union[List[dict], List[List[dict]]], + **gen_params, + ) -> Union[str, List[str]]: + """Generate responses given the contexts. + + Args: + inputs (Union[List[dict], List[List[dict]]]): a list of messages + or list of lists of messages + gen_params: additional generation configuration + + Returns: + Union[str, List[str]]: generated string(s) + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = {**self.gen_params, **gen_params,"stream":self.stream} + with ThreadPoolExecutor(max_workers=20) as executor: + tasks = [ + executor.submit(self._chat, messages, **gen_params) + for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) + ] + ret = [task.result() for task in tasks] + return ret[0] if isinstance(inputs[0], dict) else ret + + def stream_chat( + self, + inputs: List[dict], + **gen_params, + ): + """Generate responses given the contexts. 
+ + Args: + inputs (List[dict]): a list of messages + gen_params: additional generation configuration + + Returns: + str: generated string + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = self.update_gen_params(**gen_params) + gen_params['stream'] = True + + resp = '' + finished = False + stop_words = gen_params.get('stop_words') + if stop_words is None: + stop_words = [] + # mapping to role that openai supports + # messages = self.template_parser(inputs) + for text in self._stream_chat(inputs, **gen_params): + resp += text + if not resp: + continue + # remove stop_words + for sw in stop_words: + if sw in resp: + resp = filter_suffix(resp, stop_words) + finished = True + break + yield ModelStatusCode.STREAM_ING, resp, None + if finished: + break + yield ModelStatusCode.END, resp, None + + def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. + """ + assert isinstance(messages, list) + # messages = self.template_parser(messages) + header, data = self.generate_request_data( + model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode + ) + + max_num_retries, errmsg = 0, '' + while max_num_retries < self.retry: + with Lock(): + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + if self.orgs: + with Lock(): + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + response = dict() + try: + raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies) + response = raw_response.json() + if "choices" in response: + return response['choices'][0]['message']['content'].strip() + return response['message']['content'].strip() + except requests.ConnectionError: + errmsg = 'Got connection error ' + str(traceback.format_exc()) + self.logger.error(errmsg) + continue + except requests.JSONDecodeError: + errmsg = 'JsonDecode error, got ' + str(raw_response.content) + self.logger.error(errmsg) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + errmsg = 'Find error message in response: ' + str(response['error']) + self.logger.error(errmsg) + except Exception as error: + errmsg = str(error) + '\n' + str(traceback.format_exc()) + self.logger.error(errmsg) + max_num_retries += 1 + + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}' + ) + + def _stream_chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. 
+ + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. + """ + + def streaming(raw_response): + for chunk in raw_response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\n'): + if chunk: + decoded = chunk.decode('utf-8') + if decoded.startswith('data: [DONE]'): + return + if decoded[:5] == 'data:': + decoded = decoded[5:] + if decoded[0] == ' ': + decoded = decoded[1:] + else: + print(decoded) + continue + try: + response = json.loads(decoded) + if 'code' in response and response['code'] == -20003: + # Context exceeds maximum length + yield '' + return + + choice = response['choices'][0] + if choice['finish_reason'] == 'stop': + return + yield choice['delta'].get('content', '') + except Exception as exc: + msg = f'response {decoded} lead to exception of {str(exc)}' + self.logger.error(msg) + raise Exception(msg) from exc + + assert isinstance(messages, list) + + header, data = self.generate_request_data( + model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode + ) + + max_num_retries, errmsg = 0, '' + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + if self.orgs: + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + response = dict() + try: + raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies, stream=True) + return streaming(raw_response) + except requests.ConnectionError: + errmsg = 'Got connection error ' + str(traceback.format_exc()) + self.logger.error(errmsg) + continue + except requests.JSONDecodeError: + errmsg = 'JsonDecode error, got ' + str(raw_response.content) + self.logger.error(errmsg) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + errmsg = 'Find error message in response: ' + str(response['error']) + self.logger.error(errmsg) + except Exception as error: + errmsg = str(error) + '\n' + str(traceback.format_exc()) + self.logger.error(errmsg) + max_num_retries += 1 + + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}' + ) + + def generate_request_data(self, model_type, messages, gen_params, json_mode=False): + """ + Generates the request data for different model types. + + Args: + model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen'). + messages (list): The list of messages to be sent to the model. + gen_params (dict): The generation parameters. + json_mode (bool): Flag to determine if the response format should be JSON. + + Returns: + tuple: A tuple containing the header and the request data. 
+ """ + # Copy generation parameters to avoid modifying the original dictionary + gen_params = gen_params.copy() + + # Hold out 100 tokens due to potential errors in token calculation + max_tokens = min(gen_params.pop('max_new_tokens'), 4096) + if max_tokens <= 0: + return '', '' + + # Initialize the header + header = { + 'content-type': 'application/json', + } + + # Common parameters processing + gen_params['max_tokens'] = max_tokens + if 'stop_words' in gen_params: + gen_params['stop'] = gen_params.pop('stop_words') + if 'repetition_penalty' in gen_params: + gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty') + + # Model-specific processing + data = process_model_params(model_type, messages, gen_params, json_mode) + + return header, data + + def tokenize(self, prompt: str) -> list: + """Tokenize the input prompt. + + Args: + prompt (str): Input string. + + Returns: + list: token ids + """ + import tiktoken + + self.tiktoken = tiktoken + enc = self.tiktoken.encoding_for_model(self.model_type) + return enc.encode(prompt) + + +class AsyncGPTStyleAPI(AsyncBaseAPILLM): + """Model wrapper around OpenAI's models. + + Args: + model_type (str): The name of OpenAI's model. + retry (int): Number of retires if the API call fails. Defaults to 2. + key (str or List[str]): OpenAI key(s). In particular, when it + is set to "ENV", the key will be fetched from the environment + variable $OPENAI_API_KEY, as how openai defaults to be. If it's a + list, the keys will be used in round-robin manner. Defaults to + 'ENV'. + org (str or List[str], optional): OpenAI organization(s). If not + specified, OpenAI uses the default organization bound to each API + key. If specified, the orgs will be posted with each request in + round-robin manner. Defaults to None. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + api_base (str): The base url of OpenAI's API. Defaults to + 'https://api.openai.com/v1/chat/completions'. + gen_params: Default generation configuration which could be overridden + on the fly of generation. 
+ """ + + is_api: bool = True + + def __init__( + self, + model_type: str = 'gpt-3.5-turbo', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='system'), + ], + api_base: str = OPENAI_STYLE_API_BASE, + proxies: Optional[Dict] = None, + **gen_params, + ): + if 'top_k' in gen_params: + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) + gen_params.pop('top_k') + super().__init__(model_type=model_type, meta_template=meta_template, retry=retry, **gen_params) + self.gen_params.pop('top_k') + self.logger = getLogger(__name__) + + if isinstance(key, str): + self.keys = [os.getenv('OPENAI_STYLE_API_KEY') if key == 'ENV' else key] + else: + self.keys = key + + # record invalid keys and skip them when requesting API + # - keys have insufficient_quota + self.invalid_keys = set() + + self.key_ctr = 0 + if isinstance(org, str): + self.orgs = [org] + else: + self.orgs = org + self.org_ctr = 0 + self.url = api_base + self.model_type = model_type + self.proxies = proxies or {} + self.json_mode = json_mode + + async def chat( + self, + inputs: Union[List[dict], List[List[dict]]], + session_ids: Union[int, List[int]] = None, + **gen_params, + ) -> Union[str, List[str]]: + """Generate responses given the contexts. + + Args: + inputs (Union[List[dict], List[List[dict]]]): a list of messages + or list of lists of messages + gen_params: additional generation configuration + + Returns: + Union[str, List[str]]: generated string(s) + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = {**self.gen_params, **gen_params} + tasks = [ + self._chat(messages, **gen_params) for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) + ] + ret = await asyncio.gather(*tasks) + return ret[0] if isinstance(inputs[0], dict) else ret + + async def stream_chat( + self, + inputs: List[dict], + **gen_params, + ): + """Generate responses given the contexts. + + Args: + inputs (List[dict]): a list of messages + gen_params: additional generation configuration + + Returns: + str: generated string + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = self.update_gen_params(**gen_params) + gen_params['stream'] = True + + resp = '' + finished = False + stop_words = gen_params.get('stop_words') + if stop_words is None: + stop_words = [] + # mapping to role that openai supports + messages = self.template_parser(inputs) + async for text in self._stream_chat(messages, **gen_params): + resp += text + if not resp: + continue + # remove stop_words + for sw in stop_words: + if sw in resp: + resp = filter_suffix(resp, stop_words) + finished = True + break + yield ModelStatusCode.STREAM_ING, resp, None + if finished: + break + yield ModelStatusCode.END, resp, None + + async def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. 
+ """ + assert isinstance(messages, list) + messages = self.template_parser(messages) + header, data = self.generate_request_data( + model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode + ) + + max_num_retries, errmsg = 0, '' + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + if self.orgs: + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + response = dict() + try: + async with aiohttp.ClientSession() as session: + async with session.post( + self.url, headers=header, json=data, proxy=self.proxies.get('https', self.proxies.get('http')) + ) as resp: + response = await resp.json() + return response['choices'][0]['message']['content'].strip() + except aiohttp.ClientConnectionError: + errmsg = 'Got connection error ' + str(traceback.format_exc()) + self.logger.error(errmsg) + continue + except aiohttp.ClientResponseError as e: + errmsg = 'Response error, got ' + str(e) + self.logger.error(errmsg) + continue + except json.JSONDecodeError: + errmsg = 'JsonDecode error, got ' + (await resp.text(errors='replace')) + self.logger.error(errmsg) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + errmsg = 'Find error message in response: ' + str(response['error']) + self.logger.error(errmsg) + except Exception as error: + errmsg = str(error) + '\n' + str(traceback.format_exc()) + self.logger.error(errmsg) + max_num_retries += 1 + + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}' + ) + + async def _stream_chat(self, messages: List[dict], **gen_params) -> AsyncGenerator[str, None]: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. 
+ """ + + async def streaming(raw_response): + async for chunk in raw_response.content: + if chunk: + decoded = chunk.decode('utf-8') + if decoded.startswith('data: [DONE]'): + return + if decoded[:5] == 'data:': + decoded = decoded[5:] + if decoded[0] == ' ': + decoded = decoded[1:] + else: + print(decoded) + continue + try: + response = json.loads(decoded) + if 'code' in response and response['code'] == -20003: + # Context exceeds maximum length + yield '' + return + + choice = response['choices'][0] + if choice['finish_reason'] == 'stop': + return + yield choice['delta'].get('content', '') + except Exception as exc: + msg = f'response {decoded} lead to exception of {str(exc)}' + self.logger.error(msg) + raise Exception(msg) from exc + + assert isinstance(messages, list) + + header, data = self.generate_request_data( + model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode + ) + + max_num_retries, errmsg = 0, '' + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + if self.orgs: + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + response = dict() + try: + async with aiohttp.ClientSession() as session: + async with session.post( + self.url, headers=header, json=data, proxy=self.proxies.get('https', self.proxies.get('http')) + ) as raw_response: + async for msg in streaming(raw_response): + yield msg + return + except aiohttp.ClientConnectionError: + errmsg = 'Got connection error ' + str(traceback.format_exc()) + self.logger.error(errmsg) + continue + except aiohttp.ClientResponseError as e: + errmsg = 'Response error, got ' + str(e) + self.logger.error(errmsg) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + errmsg = 'Find error message in response: ' + str(response['error']) + self.logger.error(errmsg) + except Exception as error: + errmsg = str(error) + '\n' + str(traceback.format_exc()) + self.logger.error(errmsg) + max_num_retries += 1 + + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}' + ) + + def generate_request_data(self, model_type, messages, gen_params, json_mode=False): + """ + Generates the request data for different model types. + + Args: + model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen'). + messages (list): The list of messages to be sent to the model. + gen_params (dict): The generation parameters. + json_mode (bool): Flag to determine if the response format should be JSON. + + Returns: + tuple: A tuple containing the header and the request data. 
+ """ + # Copy generation parameters to avoid modifying the original dictionary + gen_params = gen_params.copy() + + # Hold out 100 tokens due to potential errors in token calculation + max_tokens = min(gen_params.pop('max_new_tokens'), 4096) + if max_tokens <= 0: + return '', '' + + # Initialize the header + header = { + 'content-type': 'application/json', + } + + # Common parameters processing + gen_params['max_tokens'] = max_tokens + if 'stop_words' in gen_params: + gen_params['stop'] = gen_params.pop('stop_words') + if 'repetition_penalty' in gen_params: + gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty') + + data = process_model_params(model_type, messages, gen_params, json_mode) + + + return header, data + + def tokenize(self, prompt: str) -> list: + """Tokenize the input prompt. + + Args: + prompt (str): Input string. + + Returns: + list: token ids + """ + import tiktoken + + self.tiktoken = tiktoken + enc = self.tiktoken.encoding_for_model(self.model_type) + return enc.encode(prompt) + diff --git a/tests/test_llms/__init__.py b/tests/test_llms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_llms/test_gptstyleapi.py b/tests/test_llms/test_gptstyleapi.py new file mode 100644 index 00000000..05ad9630 --- /dev/null +++ b/tests/test_llms/test_gptstyleapi.py @@ -0,0 +1,182 @@ +from lagent.llms import GPTStyleAPI + +def chat_xinfrence(): + api_base = 'http://192.168.26.213:13000/v1/chat/completions' # oneapi + model_name = "deepseek-r1:1.5b" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-IXgCTwuoEwxL1CiBE4744688D8094521B70f4aDeE6830c5e", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, + stream=False, + stop_words=None, + ) + res = gpttool.chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + # res = gpttool.generate(inputs="世界第一高峰是") + print(res) +def chat_ollama(): + api_base = 'http://192.168.26.212:11434/api/chat' # ollama + model_name = "qwen:7b" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-IXgCTwuoEwxL1CiBE4744688D8094521B70f4aDeE6830c5e", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, + stream=False, + stop_words=None, + ) + res = gpttool.chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + print(res) + +def chat_direct(): + api_base = 'http://192.168.26.213/v1/chat/completions' # 直连 + model_name = "Baichuan2-Turbo" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-IXgCTwuoEwxL1CiBE4744688D8094521B70f4aDeE6830c5e", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, + stream=False, + stop_words=None, + ) + res = gpttool.chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + print(res) + +def chat_lmdeploy(): + api_base = 'http://192.168.26.212:24444/v1/chat/completions' # 直连 + model_name = "deepseek-r1:14b" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-IXgCTwuoEwxL1CiBE4744688D8094521B70f4aDeE6830c5e", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, + stream=False, + stop_words=None, + ) + res = gpttool.chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + print(res) + +def chat_oneapi(): + api_base = 'http://192.168.26.213:13000/v1/chat/completions' # 
oneapi + model_name = "deepseek-r1-14b" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-CZOUavQGNzkkQjZr626908A0011040F8B743C526F315D6Ee", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, + stream=False, + stop_words=None, + ) + res = gpttool.chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + print(res) + +def stream_chat_ollama(): + api_base = 'http://192.168.26.212:11434/api/chat' # ollama + model_name = "qwen:7b" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-IXgCTwuoEwxL1CiBE4744688D8094521B70f4aDeE6830c5e", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, + stream=False, + stop_words=None, + ) + res = gpttool.stream_chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + for status, content, _ in res: + print(content, end='', flush=True) + +def stream_chat_oneapi(): + api_base = 'http://192.168.26.213:13000/v1/chat/completions' # oneapi + model_name = "deepseek-r1-14b" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-CZOUavQGNzkkQjZr626908A0011040F8B743C526F315D6Ee", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, + stream=False, + stop_words=None, + ) + res = gpttool.stream_chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + for status, content, _ in res: + print(content, end='', flush=True) + +if __name__ == '__main__': + # chat_xinfrence() + # chat_direct() + # chat_ollama() + # chat_oneapi() + chat_lmdeploy() + + # #流式输出测试 + # stream_chat_ollama() + # stream_chat_oneapi() \ No newline at end of file From 74e5e681e868295d705c1c06ec2bc38abfd5f7b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B1=E6=A2=93=E5=92=B8?= Date: Thu, 6 Feb 2025 16:48:02 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=9B=B4=E6=96=B0stream=5Fchat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lagent/llms/openai_style.py | 5 ++++- lagent/version.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lagent/llms/openai_style.py b/lagent/llms/openai_style.py index 82a81670..6b18bb5e 100644 --- a/lagent/llms/openai_style.py +++ b/lagent/llms/openai_style.py @@ -323,7 +323,10 @@ def streaming(raw_response): choice = response['choices'][0] if choice['finish_reason'] == 'stop': return - yield choice['delta'].get('content', '') + res = choice['delta'].get('content', '') + if res is None or "null"==res: # 处理硅基流动的特例 + res = '' + yield res except Exception as exc: msg = f'response {decoded} lead to exception of {str(exc)}' self.logger.error(msg) diff --git a/lagent/version.py b/lagent/version.py index 01c3552e..925ae4ef 100644 --- a/lagent/version.py +++ b/lagent/version.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-__version__ = '0.5.0rc2' +__version__ = '0.5.0rc3' def parse_version_info(version_str): From e39f088e67891310cd6d58b8b8a67af003f9fd69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B1=E6=A2=93=E5=92=B8?= Date: Thu, 6 Feb 2025 16:48:41 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=A1=85=E5=9F=BA?= =?UTF-8?q?=E6=B5=81=E5=8A=A8api=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_llms/test_gptstyleapi.py | 66 ++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/tests/test_llms/test_gptstyleapi.py b/tests/test_llms/test_gptstyleapi.py index 05ad9630..b373808d 100644 --- a/tests/test_llms/test_gptstyleapi.py +++ b/tests/test_llms/test_gptstyleapi.py @@ -1,4 +1,4 @@ -from lagent.llms import GPTStyleAPI +from lagent.llms import GPTStyleAPI,GPTAPI def chat_xinfrence(): api_base = 'http://192.168.26.213:13000/v1/chat/completions' # oneapi @@ -98,7 +98,8 @@ def chat_lmdeploy(): def chat_oneapi(): api_base = 'http://192.168.26.213:13000/v1/chat/completions' # oneapi - model_name = "deepseek-r1-14b" + # model_name = "deepseek-r1-14b" + model_name = "Baichuan2-Turbo" gpttool = GPTStyleAPI( model_type=model_name, api_base=api_base, @@ -120,6 +121,30 @@ def chat_oneapi(): }]) print(res) +def chat_siliconflow(): + api_base = 'https://api.siliconflow.cn/v1/chat/completions' # oneapi + # model_name = "deepseek-r1-14b" + model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-srirwcmjqmbmyttandxidrtmlfqpxcigyacoabutufvdkkgl", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, + stream=False, + stop_words=None, + ) + res = gpttool.chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + print(res) def stream_chat_ollama(): api_base = 'http://192.168.26.212:11434/api/chat' # ollama model_name = "qwen:7b" @@ -134,7 +159,6 @@ def stream_chat_ollama(): top_k=40, temperature=0.8, repetition_penalty=1, - stream=False, stop_words=None, ) res = gpttool.stream_chat(inputs=[ @@ -147,7 +171,9 @@ def stream_chat_ollama(): def stream_chat_oneapi(): api_base = 'http://192.168.26.213:13000/v1/chat/completions' # oneapi - model_name = "deepseek-r1-14b" + # model_name = "deepseek-r1-14b" + model_name = "Baichuan2-Turbo" + # model_name = "qwen:7b" gpttool = GPTStyleAPI( model_type=model_name, api_base=api_base, @@ -159,7 +185,31 @@ def stream_chat_oneapi(): top_k=40, temperature=0.8, repetition_penalty=1, - stream=False, + stop_words=None, + ) + res = gpttool.stream_chat(inputs=[ + { + "role": "user", + "content": "世界第一高峰是" + }]) + for status, content, _ in res: + print(content, end='', flush=True) + +def stream_chat_siliconflow(): + api_base = 'https://api.siliconflow.cn/v1/chat/completions' # oneapi + # model_name = "deepseek-r1-14b" + model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" + gpttool = GPTStyleAPI( + model_type=model_name, + api_base=api_base, + key="sk-srirwcmjqmbmyttandxidrtmlfqpxcigyacoabutufvdkkgl", + retry=3, + meta_template=None, + max_new_tokens=512, + top_p=0.8, + top_k=40, + temperature=0.8, + repetition_penalty=1, stop_words=None, ) res = gpttool.stream_chat(inputs=[ @@ -175,8 +225,10 @@ def stream_chat_oneapi(): # chat_direct() # chat_ollama() # chat_oneapi() - chat_lmdeploy() + # chat_lmdeploy() + # chat_siliconflow() # #流式输出测试 # stream_chat_ollama() - # stream_chat_oneapi() \ No newline at end of file + # 
stream_chat_oneapi() + stream_chat_siliconflow() \ No newline at end of file
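
For reference, below is a minimal sketch of driving the asynchronous wrapper (AsyncGPTStyleAPI) introduced by this patch series; the tests above only exercise the synchronous GPTStyleAPI. The endpoint, model name, and key are placeholders mirroring the values used in the tests, not a verified configuration.

import asyncio

from lagent.llms import AsyncGPTStyleAPI


async def main():
    # Placeholder endpoint, model name, and key; substitute your own server's values.
    gpttool = AsyncGPTStyleAPI(
        model_type='deepseek-r1:1.5b',
        api_base='http://192.168.26.213:13000/v1/chat/completions',
        key='YOUR_API_KEY',
        max_new_tokens=512,
        temperature=0.8,
    )

    # chat() is a coroutine and accepts the same message list as GPTStyleAPI.chat().
    res = await gpttool.chat(inputs=[{'role': 'user', 'content': '世界第一高峰是'}])
    print(res)

    # stream_chat() is an async generator yielding (status, text, extra),
    # where `text` is the response accumulated so far.
    final = ''
    async for _, content, _ in gpttool.stream_chat(
            inputs=[{'role': 'user', 'content': '世界第一高峰是'}]):
        final = content
    print(final)


if __name__ == '__main__':
    asyncio.run(main())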