
Commit 44bb513

gorilla support (#173) (#177)
We support Gorilla for tool use in the latest code. #173
2 parents b24d2fe + 7164604 commit 44bb513

File tree

10 files changed, +212 -18 lines:

.env.template
pilot/configs/config.py
pilot/configs/model_config.py
pilot/model/adapter.py
pilot/model/llm_out/chatglm_llm.py
pilot/model/llm_out/falcon_llm.py
pilot/model/llm_out/gorilla_llm.py
pilot/out_parser/base.py
pilot/server/chat_adapter.py
pilot/source_embedding/pdf_embedding.py

.env.template

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@ LLM_MODEL=vicuna-13b
 MODEL_SERVER=http://127.0.0.1:8000
 LIMIT_MODEL_CONCURRENCY=5
 MAX_POSITION_EMBEDDINGS=4096
-
+QUANTIZE_QLORA=True
 ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
 # SMART_LLM_MODEL=vicuna-13b
@@ -112,4 +112,4 @@ PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
 #*******************************************************************#
 # ** SUMMARY_CONFIG
 #*******************************************************************#
-SUMMARY_CONFIG=FAST
+SUMMARY_CONFIG=FAST

pilot/configs/config.py

Lines changed: 3 additions & 0 deletions

@@ -146,6 +146,9 @@ def __init__(self) -> None:
         self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
         self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
 
+        # QLoRA
+        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
+
         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
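Note: os.getenv returns a string here, so CFG.QLoRA stays truthy even when QUANTIZE_QLORA is set to "False"; only an empty value would make it falsy in the `if CFG.QLoRA:` branch of adapter.py. A minimal sketch of a stricter boolean parse (the env_flag helper is illustrative, not part of this commit):

import os

def env_flag(name: str, default: str = "True") -> bool:
    # Treat common "off" spellings as disabled; anything else counts as enabled.
    return os.getenv(name, default).strip().lower() not in ("", "0", "false", "no")

QUANTIZE_QLORA = env_flag("QUANTIZE_QLORA")  # hypothetical usage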

pilot/configs/model_config.py

Lines changed: 2 additions & 1 deletion

@@ -35,14 +35,15 @@
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
+    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
+    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
     "proxyllm": "proxyllm",
 }
 
 # Load model config
 ISLOAD_8BIT = True
 ISDEBUG = False
 
-
 VECTOR_SEARCH_TOP_K = 10
 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
 KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(

pilot/model/adapter.py

Lines changed: 53 additions & 8 deletions

@@ -1,11 +1,26 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-from functools import cache
-from typing import List
-
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
+import torch
+from typing import List
+from functools import cache
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaTokenizer,
+    BitsAndBytesConfig,
+)
 from pilot.configs.model_config import DEVICE
+from pilot.configs.config import Config
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype="bfloat16",
+    bnb_4bit_use_double_quant=False,
+)
+CFG = Config()
 
 
 class BaseLLMAdaper:
@@ -97,16 +112,44 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer
 
 
-class GuanacoAdapter(BaseLLMAdaper):
+class FalconAdapater(BaseLLMAdaper):
+    """falcon Adapter"""
+
+    def match(self, model_path: str):
+        return "falcon" in model_path
+
+    def loader(self, model_path: str, from_pretrained_kwagrs: dict):
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+
+        if CFG.QLoRA:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                load_in_4bit=True,  # quantize
+                quantization_config=bnb_config,
+                device_map={"": 0},
+                trust_remote_code=True,
+                **from_pretrained_kwagrs,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                device_map={"": 0},
+                **from_pretrained_kwagrs,
+            )
+        return model, tokenizer
+
+
+class GorillaAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
 
     def match(self, model_path: str):
-        return "guanaco" in model_path
+        return "gorilla" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
         model = AutoModelForCausalLM.from_pretrained(
-            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
         )
         return model, tokenizer
 
@@ -166,6 +209,8 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
+register_llm_model_adapters(FalconAdapater)
+register_llm_model_adapters(GorillaAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 
 # just for test, remove this later
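Note: adapters are picked by a substring match on the configured model path, so the new "falcon-40b" and "gorilla-7b" entries in model_config.py are routed to FalconAdapater and GorillaAdapter respectively. A hypothetical sketch of how a registry of this shape resolves an adapter (the lookup function and list names are illustrative, not quoted from adapter.py):

llm_model_adapters = []  # populated by register_llm_model_adapters(...)

def get_llm_model_adapter(model_path: str):
    # The first adapter whose match() accepts the path wins,
    # e.g. "gorilla" in ".../models/gorilla-7b".
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")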

pilot/model/llm_out/chatglm_llm.py

Lines changed: 4 additions & 1 deletion

@@ -51,7 +51,10 @@ def chatglm_generate_stream(
         # else:
         #     once_conversation.append(f"""###system:{message} """)
 
-    query = messages[-2].split("human:")[1]
+    try:
+        query = messages[-2].split("human:")[1]
+    except IndexError:
+        query = messages[-3].split("human:")[1]
     print("Query Message: ", query)
     # output = ""
     # i = 0
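Note: the try/except works because str.split("human:") returns a one-element list when the marker is absent, so indexing [1] raises IndexError and the previous message is used instead. A toy illustration with invented messages:

messages = ["###system: you are a helpful assistant", "###human: what is QLoRA?", ""]
try:
    query = messages[-2].split("human:")[1]   # "###human: ..." splits cleanly
except IndexError:
    query = messages[-3].split("human:")[1]   # fallback when messages[-2] lacks "human:"
print(query.strip())  # what is QLoRA?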

pilot/model/llm_out/falcon_llm.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+import torch
+import copy
+from threading import Thread
+from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+
+
+def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    tokenizer.bos_token_id = 1
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out
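Note: falcon_generate_output runs model.generate on a background thread and yields the text accumulated from TextIteratorStreamer, so each value produced is the full response so far. An illustrative usage sketch, assuming the Falcon weights, a GPU and bitsandbytes are available locally (the model path shown is an assumption):

from pilot.model.adapter import FalconAdapater
from pilot.model.llm_out.falcon_llm import falcon_generate_output

model, tokenizer = FalconAdapater().loader("./models/falcon-40b", {})  # path is illustrative

params = {"prompt": "What is QLoRA?", "stop": "###"}
text = ""
for partial in falcon_generate_output(model, tokenizer, params, device="cuda"):
    text = partial  # cumulative output generated so far
print(text)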

pilot/model/llm_out/gorilla_llm.py

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+import torch
+
+
+@torch.inference_mode()
+def generate_stream(
+    model, tokenizer, params, device, context_len=42048, stream_interval=2
+):
+    """Fork from https://github.com/ShishirPatil/gorilla/blob/main/inference/serve/gorilla_cli.py"""
+    prompt = params["prompt"]
+    l_prompt = len(prompt)
+    max_new_tokens = int(params.get("max_new_tokens", 1024))
+    stop_str = params.get("stop", None)
+
+    input_ids = tokenizer(prompt).input_ids
+    output_ids = list(input_ids)
+    input_echo_len = len(input_ids)
+    max_src_len = context_len - max_new_tokens - 8
+    input_ids = input_ids[-max_src_len:]
+    past_key_values = out = None
+
+    for i in range(max_new_tokens):
+        if i == 0:
+            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
+            logits = out.logits
+            past_key_values = out.past_key_values
+        else:
+            out = model(
+                input_ids=torch.as_tensor([[token]], device=device),
+                use_cache=True,
+                past_key_values=past_key_values,
+            )
+            logits = out.logits
+            past_key_values = out.past_key_values
+
+        last_token_logits = logits[0][-1]
+
+        probs = torch.softmax(last_token_logits, dim=-1)
+        token = int(torch.multinomial(probs, num_samples=1))
+        output_ids.append(token)
+
+        if token == tokenizer.eos_token_id:
+            stopped = True
+        else:
+            stopped = False
+
+        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+            tmp_output_ids = output_ids[input_echo_len:]
+            output = tokenizer.decode(
+                tmp_output_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+            )
+            pos = output.rfind(stop_str, l_prompt)
+            if pos != -1:
+                output = output[:pos]
+                stopped = True
+            yield output
+
+        if stopped:
+            break
+
+    del past_key_values
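Note: generate_stream decodes one token at a time, reusing past_key_values so each step only feeds the newly sampled token through the model, and it yields the decoded completion every stream_interval steps or when EOS / the stop string is hit. An illustrative usage sketch, assuming a local Gorilla checkpoint loaded through GorillaAdapter (the path and prompt are assumptions):

from pilot.model.adapter import GorillaAdapter
from pilot.model.llm_out.gorilla_llm import generate_stream

model, tokenizer = GorillaAdapter().loader("./models/gorilla-7b", {})  # path is illustrative
model.to("cuda")

params = {"prompt": "I want to translate English to Chinese.", "max_new_tokens": 256, "stop": "###"}
result = ""
for partial in generate_stream(model, tokenizer, params, device="cuda"):
    result = partial  # text generated so far, with the prompt echo stripped
print(result)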

pilot/out_parser/base.py

Lines changed: 2 additions & 0 deletions

@@ -61,6 +61,8 @@ def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
 
             # stream out output
             output = data["text"][11:].replace("<s>", "").strip()
+
+            # TODO gorilla and falcon output
         else:
             output = data["text"].strip()
 

pilot/server/chat_adapter.py

Lines changed: 24 additions & 2 deletions

@@ -3,7 +3,6 @@
 
 from functools import cache
 from typing import List
-
 from pilot.model.llm_out.vicuna_base_llm import generate_stream
 
 
@@ -96,6 +95,18 @@ def get_generate_stream_func(self):
         return guanaco_generate_stream
 
 
+class FalconChatAdapter(BaseChatAdpter):
+    """Model chat adapter for Guanaco"""
+
+    def match(self, model_path: str):
+        return "falcon" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.falcon_llm import falcon_generate_output
+
+        return falcon_generate_output
+
+
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "proxyllm" in model_path
@@ -106,10 +117,21 @@ def get_generate_stream_func(self):
         return proxyllm_generate_stream
 
 
+class GorillaChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "gorilla" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gorilla_llm import generate_stream
+
+        return generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
-
+register_llm_model_chat_adapter(FalconChatAdapter)
+register_llm_model_chat_adapter(GorillaChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)

pilot/source_embedding/pdf_embedding.py

Lines changed: 6 additions & 4 deletions

@@ -4,10 +4,10 @@
 
 from langchain.document_loaders import PyPDFLoader
 from langchain.schema import Document
+from langchain.text_splitter import SpacyTextSplitter
 
 from pilot.configs.config import Config
 from pilot.source_embedding import SourceEmbedding, register
-from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 
 CFG = Config()
 
@@ -24,10 +24,12 @@ def __init__(self, file_path, vector_store_config):
     @register
     def read(self):
        """Load from pdf path."""
-        # loader = UnstructuredPaddlePDFLoader(self.file_path)
         loader = PyPDFLoader(self.file_path)
-        textsplitter = CHNDocumentSplitter(
-            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
+        # textsplitter = CHNDocumentSplitter(
+        #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
+        # )
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
         )
         return loader.load_and_split(textsplitter)
 
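Note: SpacyTextSplitter loads the named spaCy pipeline when it is constructed, so zh_core_web_sm has to be installed (for example via `python -m spacy download zh_core_web_sm`) or the splitter will fail before any PDF is read. A minimal standalone sketch of the same splitting call, with an invented sample string:

from langchain.text_splitter import SpacyTextSplitter

# Assumes `pip install spacy` and `python -m spacy download zh_core_web_sm` have been run.
textsplitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200)
chunks = textsplitter.split_text("这是一个示例文档。" * 300)  # invented sample text
print(len(chunks))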
