Skip to content

Commit

Permalink
增加对文心一言的支持,支持文心一言的三个主要模型。 (#931)
Browse files Browse the repository at this point in the history
  • Loading branch information
XudongLiu authored Nov 17, 2023
1 parent cd2e998 commit e17e77b
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 1 deletion.
2 changes: 2 additions & 0 deletions config_example.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
"spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
"claude_api_secret":"",// 你的 Claude API Secret,用于 Claude 对话模型
"ernie_api_key": "",// 你的文心一言在百度云中的API Key,用于文心一言对话模型
"ernie_secret_key": "",// 你的文心一言在百度云中的Secret Key,用于文心一言对话模型


//== Azure ==
Expand Down
5 changes: 5 additions & 0 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def load_config_to_environ(key_list):
claude_api_secret = config.get("claude_api_secret", "")
os.environ["CLAUDE_API_SECRET"] = claude_api_secret

ernie_api_key = config.get("ernie_api_key", "")
os.environ["ERNIE_APIKEY"] = ernie_api_key
ernie_secret_key = config.get("ernie_secret_key", "")
os.environ["ERNIE_SECRETKEY"] = ernie_secret_key

load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
"azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])

Expand Down
96 changes: 96 additions & 0 deletions modules/models/ERNIE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from ..presets import *
from ..utils import *

from .base_model import BaseLLMModel


class ERNIE_Client(BaseLLMModel):
def __init__(self, model_name, api_key, secret_key) -> None:
super().__init__(model_name=model_name)
self.api_key = api_key
self.api_secret = secret_key
if None in [self.api_secret, self.api_key]:
raise Exception("请在配置文件或者环境变量中设置文心一言的API Key 和 Secret Key")

if self.model_name == "ERNIE-Bot-turbo":
self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token="
elif self.model_name == "ERNIE-Bot":
self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="
elif self.model_name == "ERNIE-Bot-4":
self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token="

def get_access_token(self):
"""
使用 AK,SK 生成鉴权签名(Access Token)
:return: access_token,或是None(如果错误)
"""
url = "https://aip.baidubce.com/oauth/2.0/token?client_id=" + self.api_key + "&client_secret=" + self.api_secret + "&grant_type=client_credentials"

payload = json.dumps("")
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}

response = requests.request("POST", url, headers=headers, data=payload)

return response.json()["access_token"]
def get_answer_stream_iter(self):
url = self.ERNIE_url + self.get_access_token()
system_prompt = self.system_prompt
history = self.history
if system_prompt is not None:
history = [construct_system(system_prompt), *history]

# 去除history中 history的role为system的
history = [i for i in history if i["role"] != "system"]

payload = json.dumps({
"messages":history,
"stream": True
})
headers = {
'Content-Type': 'application/json'
}

response = requests.request("POST", url, headers=headers, data=payload, stream=True)

if response.status_code == 200:
partial_text = ""
for line in response.iter_lines():
if len(line) == 0:
continue
line = json.loads(line[5:])
partial_text += line['result']
yield partial_text
else:
yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG


def get_answer_at_once(self):
url = self.ERNIE_url + self.get_access_token()
system_prompt = self.system_prompt
history = self.history
if system_prompt is not None:
history = [construct_system(system_prompt), *history]

# 去除history中 history的role为system的
history = [i for i in history if i["role"] != "system"]

payload = json.dumps({
"messages": history,
"stream": True
})
headers = {
'Content-Type': 'application/json'
}

response = requests.request("POST", url, headers=headers, data=payload, stream=True)

if response.status_code == 200:

return str(response.json()["result"]),len(response.json()["result"])
else:
return "获取资源错误", 0


3 changes: 3 additions & 0 deletions modules/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class ModelType(Enum):
Claude = 14
Qwen = 15
OpenAIVision = 16
ERNIE = 17

@classmethod
def get_type(cls, model_name: str):
Expand Down Expand Up @@ -188,6 +189,8 @@ def get_type(cls, model_name: str):
model_type = ModelType.Claude
elif "qwen" in model_name_lower:
model_type = ModelType.Qwen
elif "ernie" in model_name_lower:
model_type = ModelType.ERNIE
else:
model_type = ModelType.LLaMA
return model_type
Expand Down
3 changes: 3 additions & 0 deletions modules/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def get_model(
elif model_type == ModelType.Qwen:
from .Qwen import Qwen_Client
model = Qwen_Client(model_name, user_name=user_name)
elif model_type == ModelType.ERNIE:
from .ERNIE import ERNIE_Client
model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY"))
elif model_type == ModelType.Unknown:
raise ValueError(f"未知模型: {model_name}")
logging.info(msg)
Expand Down
17 changes: 16 additions & 1 deletion modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@
"讯飞星火大模型V3.0",
"讯飞星火大模型V2.0",
"讯飞星火大模型V1.5",
"Claude"
"Claude",
"ERNIE-Bot-turbo",
"ERNIE-Bot",
"ERNIE-Bot-4",
]

LOCAL_MODELS = [
Expand Down Expand Up @@ -146,6 +149,18 @@
"model_name": "Claude",
"token_limit": 4096,
},
"ERNIE-Bot-turbo": {
"model_name": "ERNIE-Bot-turbo",
"token_limit": 1024,
},
"ERNIE-Bot": {
"model_name": "ERNIE-Bot",
"token_limit": 1024,
},
"ERNIE-Bot-4": {
"model_name": "ERNIE-Bot-4",
"token_limit": 1024,
},
}

if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
Expand Down

0 comments on commit e17e77b

Please sign in to comment.