
Commit e847a3f

Features: multi-LLM model support (#78)

Features: multi-LLM support.
- model_adapter for loading multiple models
- chat_adapter for chatting with the models
2 parents a2f3ccc + a537ce0 commit e847a3f

File tree

15 files changed: +463 -72 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ lib/
 lib64/
 parts/
 sdist/
+models
 var/
 wheels/
 models/

README.md

Lines changed: 8 additions & 0 deletions
@@ -29,6 +29,10 @@ Currently, we have released multiple key features, which are listed below to dem
 - Unified vector storage/indexing of knowledge base
 - Support for unstructured data such as PDF, Markdown, CSV, and WebURL
 
+- Multi LLMs Support
+  - Supports multiple large language models, currently supporting Vicuna (7b, 13b) and ChatGLM-6b (int4, int8)
+  - TODO: codegen2, codet5p
+
 
 ## Demo
 
@@ -175,6 +179,10 @@ Notice: the webserver need to connect llmserver, so you need change the .env f
 We provide a user interface for Gradio, which allows you to use DB-GPT through our user interface. Additionally, we have prepared several reference articles (written in Chinese) that introduce the code and principles related to our project.
 - [LLM Practical In Action Series (1) — Combined Langchain-Vicuna Application Practical](https://medium.com/@cfqcsunny/llm-practical-in-action-series-1-combined-langchain-vicuna-application-practical-701cd0413c9f)
 
+### Multi LLMs Usage
+
+To use multiple models, modify the LLM_MODEL parameter in the .env configuration file to switch between the models.
+
 ## Acknowledgement
 
 The achievements of this project are thanks to the technical community, especially the following projects:
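To make the switch described in the new "Multi LLMs Usage" section concrete, here is a minimal sketch (not part of this commit) of reading an .env-driven model choice. It assumes a python-dotenv setup and a `.env` line such as `LLM_MODEL=chatglm-6b`; in the project itself this is handled by `pilot.configs.config.Config`.

```python
# Illustrative only, not code from this commit: reading the LLM_MODEL switch.
# Assumes python-dotenv and a .env file containing e.g. `LLM_MODEL=chatglm-6b`.
import os
from dotenv import load_dotenv

load_dotenv()                                     # pick up the .env file
llm_model = os.getenv("LLM_MODEL", "vicuna-13b")  # e.g. "vicuna-7b", "chatglm-6b-int4"
print(f"LLM_MODEL set to: {llm_model}")
```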

README.zh.md

Lines changed: 8 additions & 0 deletions
@@ -26,6 +26,10 @@ DB-GPT is an open-source experimental GPT project built around databases, using local
 - Unified vector storage/indexing of the knowledge base
 - Support for unstructured data, including PDF, Markdown, CSV, and WebURL
 
+- Multi-model support
+  - Supports multiple large language models; Vicuna (7b, 13b) and ChatGLM-6b (int4, int8) are currently supported
+  - TODO: codet5p, codegen2
+
 ## Demo
 
 The demo runs on an RTX 4090 GPU; [YouTube link](https://www.youtube.com/watch?v=1PWI6F89LPo)
@@ -178,6 +182,10 @@ $ python webserver.py
 2. [LLM Practical In Action Series (2) — DB-GPT Alibaba Cloud Deployment Guide](https://zhuanlan.zhihu.com/p/629467580)
 3. [LLM Practical In Action Series (3) — Principles and Usage of the DB-GPT Plugin Model](https://zhuanlan.zhihu.com/p/629623125)
 
+
+### Multi-Model Usage
+Modify the LLM_MODEL parameter in the .env configuration file to switch between models.
+
 ## Acknowledgements
 
 The achievements of this project are thanks to the technical community, especially the following projects.

examples/embdserver.py

Lines changed: 26 additions & 13 deletions
@@ -5,14 +5,21 @@
 import json
 import time
 import uuid
+import os
+import sys
 from urllib.parse import urljoin
 import gradio as gr
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_PATH)
+
+
 from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
 
 
-vicuna_stream_path = "generate_stream"
+llmstream_stream_path = "generate_stream"
 
 CFG = Config()
 
@@ -21,38 +28,44 @@ def generate(query):
     template_name = "conv_one_shot"
     state = conv_templates[template_name].copy()
 
-    pt = PromptTemplate(
-        template=conv_qa_prompt_template,
-        input_variables=["context", "question"]
-    )
+    # pt = PromptTemplate(
+    #     template=conv_qa_prompt_template,
+    #     input_variables=["context", "question"]
+    # )
 
-    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
-                       question=query)
+    # result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+    #                    question=query)
 
-    print(result)
+    # print(result)
 
-    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[0], query)
     state.append_message(state.roles[1], None)
 
     prompt = state.get_prompt()
     params = {
-        "model": "vicuna-13b",
+        "model": "chatglm-6b",
        "prompt": prompt,
-        "temperature": 0.7,
+        "temperature": 1.0,
         "max_new_tokens": 1024,
         "stop": "###"
     }
 
     response = requests.post(
-        url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(CFG.MODEL_SERVER, llmstream_stream_path), data=json.dumps(params)
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
                 state.messages[-1][-1] = output + "▌"
                 yield(output)
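The streaming loop above relies on the llmserver protocol: JSON chunks separated by b"\0", where Vicuna-style models echo the prompt (hence skip_echo_len) while ChatGLM's chat API returns only the reply. A self-contained sketch of that parsing, with a made-up byte stream standing in for the HTTP response:

```python
# Self-contained sketch (not from the diff) of the chunk handling above.
# A fake byte stream stands in for the llmserver HTTP response.
import json

prompt = "USER: hello###Assistant:"
raw_stream = (
    json.dumps({"error_code": 0, "text": prompt + " Hi"}).encode() + b"\0"
    + json.dumps({"error_code": 0, "text": prompt + " Hi there!"}).encode() + b"\0"
)

# Same arithmetic as in embdserver.py: length of the echoed prompt to slice off.
skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3

for chunk in raw_stream.split(b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            # Vicuna branch: drop the echoed prompt. A ChatGLM-style reply
            # would already contain only the answer, so it is used as-is.
            output = data["text"][skip_echo_len:].strip()
            print(output)  # "Hi", then "Hi there!"
```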

pilot/configs/model_config.py

Lines changed: 7 additions & 2 deletions
@@ -16,12 +16,17 @@
 
 nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
 
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 LLM_MODEL_CONFIG = {
     "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
     "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "vicuna-7b": os.path.join(MODEL_PATH, "vicuna-7b"),
     "text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
-    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
+    "codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
+    "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
+    "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
 }
 
 # Load model config
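The revised DEVICE line adds Apple's mps backend between cuda and cpu. A minimal sketch of the fallback order (a PyTorch build with MPS support is needed for the middle branch to trigger):

```python
# Minimal check of the cuda -> mps -> cpu fallback introduced above.
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Selected device: {DEVICE}")

# A tensor placed on the selected device; on Apple Silicon with a recent
# PyTorch this lands on mps, otherwise on cuda or cpu.
x = torch.ones(2, 2, device=DEVICE)
print(x.device)
```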

pilot/conversation.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@
     "port": CFG.LOCAL_DB_PORT
 }
 
+ROLE_USER = "USER"
+ROLE_ASSISTANT = "Assistant"
+
 class SeparatorStyle(Enum):
     SINGLE = auto()
     TWO = auto()

pilot/model/adapter.py

Lines changed: 24 additions & 5 deletions
@@ -9,6 +9,8 @@
     AutoModel
 )
 
+from pilot.configs.model_config import DEVICE
+
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
     We will support those model, which performance resemble ChatGPT """
@@ -61,13 +63,29 @@ class ChatGLMAdapater(BaseLLMAdaper):
     """LLM Adatpter for THUDM/chatglm-6b"""
     def match(self, model_path: str):
         return "chatglm" in model_path
-
+
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        model = AutoModel.from_pretrained(
-            model_path, trust_remote_code=True, **from_pretrained_kwargs
-        ).half().cuda()
-        return model, tokenizer
+
+        if DEVICE != "cuda":
+            model = AutoModel.from_pretrained(
+                model_path, trust_remote_code=True, **from_pretrained_kwargs
+            ).float()
+            return model, tokenizer
+        else:
+            model = AutoModel.from_pretrained(
+                model_path, trust_remote_code=True, **from_pretrained_kwargs
+            ).half().cuda()
+            return model, tokenizer
+
+class CodeGenAdapter(BaseLLMAdaper):
+    pass
+
+class StarCoderAdapter(BaseLLMAdaper):
+    pass
+
+class T5CodeAdapter(BaseLLMAdaper):
+    pass
 
 class KoalaLLMAdapter(BaseLLMAdaper):
     """Koala LLM Adapter which Based LLaMA """
@@ -91,6 +109,7 @@ def match(self, model_path: str):
 
 
 register_llm_model_adapters(VicunaLLMAdapater)
+register_llm_model_adapters(ChatGLMAdapater)
 # TODO Default support vicuna, other model need to tests and Evaluate
 
 register_llm_model_adapters(BaseLLMAdaper)
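The adapters are matched first-come-first-served against the model path, with BaseLLMAdaper registered last as the catch-all. A stripped-down sketch of that registry pattern; the class and helper names below are simplified rather than the module's exact API:

```python
# Stripped-down illustration of the registry pattern in pilot/model/adapter.py;
# names are simplified and not the module's exact API.
llm_model_adapters = []

def register_llm_model_adapters(cls):
    llm_model_adapters.append(cls())

class VicunaAdapter:
    def match(self, model_path: str) -> bool:
        return "vicuna" in model_path

class ChatGLMAdapter:
    def match(self, model_path: str) -> bool:
        return "chatglm" in model_path

class FallbackAdapter:
    def match(self, model_path: str) -> bool:
        return True  # catch-all, registered last

register_llm_model_adapters(VicunaAdapter)
register_llm_model_adapters(ChatGLMAdapter)
register_llm_model_adapters(FallbackAdapter)

def get_adapter(model_path: str):
    # First registered adapter whose match() accepts the path wins.
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"no adapter for {model_path}")

print(type(get_adapter("./models/chatglm-6b-int4")).__name__)  # ChatGLMAdapter
```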

pilot/model/chat.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

pilot/model/chatglm_llm.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+
+from pilot.conversation import ROLE_USER, ROLE_ASSISTANT
+
+@torch.inference_mode()
+def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
+
+    """Generate text using chatglm model's chat api """
+    prompt = params["prompt"]
+    temperature = float(params.get("temperature", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    stop = params.get("stop", "###")
+    echo = params.get("echo", False)
+
+    generate_kwargs = {
+        "do_sample": True if temperature > 1e-5 else False,
+        "top_p": top_p,
+        "repetition_penalty": 1.0,
+        "logits_processor": None
+    }
+
+    if temperature > 1e-5:
+        generate_kwargs["temperature"] = temperature
+
+    # TODO, Fix this
+    hist = []
+
+    messages = prompt.split(stop)
+
+    # Add history chat to hist for model.
+    for i in range(1, len(messages) - 2, 2):
+        hist.append((messages[i].split(ROLE_USER + ":")[1], messages[i+1].split(ROLE_ASSISTANT + ":")[1]))
+
+    query = messages[-2].split(ROLE_USER + ":")[1]
+    print("Query Message: ", query)
+    output = ""
+    i = 0
+    for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):
+        if echo:
+            output = query + " " + response
+        else:
+            output = response
+
+        yield output
+
+    yield output
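The history reconstruction above assumes a prompt in which turns are separated by the stop string and prefixed with the role constants from pilot/conversation.py. A small self-contained sketch of that parsing; the prompt text is invented for illustration, and the real layout comes from the project's conversation templates:

```python
# Invented prompt illustrating the layout chatglm_generate_stream expects:
# turns separated by the stop string, each prefixed with a role constant.
ROLE_USER = "USER"
ROLE_ASSISTANT = "Assistant"
stop = "###"

prompt = (
    f"A chat between a user and an assistant.{stop}"
    f"{ROLE_USER}: What is DB-GPT?{stop}"
    f"{ROLE_ASSISTANT}: An experimental database-centric GPT project.{stop}"
    f"{ROLE_USER}: Which models are supported?{stop}"
    f"{ROLE_ASSISTANT}:"
)

messages = prompt.split(stop)

# Past (user, assistant) pairs become ChatGLM history, as in the loop above.
hist = []
for i in range(1, len(messages) - 2, 2):
    hist.append((messages[i].split(ROLE_USER + ":")[1],
                 messages[i + 1].split(ROLE_ASSISTANT + ":")[1]))

# The second-to-last segment is the pending user turn.
query = messages[-2].split(ROLE_USER + ":")[1]
print(hist)
print("Query:", query.strip())
```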

pilot/model/llm/monkey_patch.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+import transformers
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2].clone()
+    x2 = x[..., x.shape[-1] // 2 :].clone()
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
+        self.head_dim
+    )
+
+    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        raise ValueError(
+            f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+            f" {attn_weights.size()}"
+        )
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.max(
+            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+        )
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+        query_states.dtype
+    )
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+def replace_llama_attn_with_non_inplace_operations():
+    """Avoid bugs in mps backend by not using in-place operations."""
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
+
+import transformers
+
+
+def replace_llama_attn_with_non_inplace_operations():
+    """Avoid bugs in mps backend by not using in-place operations."""
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
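The commit does not show where the patch is applied; a hypothetical usage sketch, assuming it is called before loading a LLaMA-family model when running on Apple's mps backend (the model path is only an example taken from LLM_MODEL_CONFIG):

```python
# Hypothetical call site for the monkey patch; the commit itself does not show
# where it is applied. The model path is only an example.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations

device = "mps" if torch.backends.mps.is_available() else "cpu"
if device == "mps":
    # Swap in the out-of-place LlamaAttention.forward before loading weights,
    # so the mps backend does not trip over in-place operations.
    replace_llama_attn_with_non_inplace_operations()

model_path = "./models/vicuna-7b"  # example entry from LLM_MODEL_CONFIG
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
```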
