-
-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for Deepseek-vl2 and Language only inputs (#153)
* initial setup * fix language model loading * add prompt util * language model working * formatting * fix layer norm value * fix deepseek merged features * add support for 27B * add language model only and deepseek-vl-v2 inputs * formatting * remove unused code * remove duplicate * add source * add source * remove unused * bump version * format * add tests * fix deprecation warning (tests)
- Loading branch information
Showing
13 changed files
with
2,490 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from .deepseek_vl_v2 import ( | ||
DeepseekVLV2Processor, | ||
LanguageModel, | ||
Model, | ||
ModelConfig, | ||
ProjectorConfig, | ||
TextConfig, | ||
VisionConfig, | ||
VisionModel, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,264 @@ | ||
""" | ||
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py | ||
""" | ||
|
||
import dataclasses
from enum import IntEnum, auto
from typing import Dict, List, Optional, Tuple, Union
|
||
|
||
class SeparatorStyle(IntEnum):
    """Enumerates the prompt-separator conventions supported by Conversation.

    Values are explicit (matching what ``auto()`` produced: 1..4) so the
    numbering is stable and obvious at a glance.
    """

    DeepSeek = 1
    DeepSeekV2 = 2
    PLAIN = 3
    ALIGNMENT = 4
|
||
|
||
@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history.

    Instances are normally obtained via ``get_conv_template`` (which returns
    an independent ``copy()``) rather than constructed directly.
    """

    # The name of this template.
    name: str
    # The template of the system prompt; "{system_message}" is substituted.
    system_template: str = "{system_message}"
    # The system message inserted into ``system_template``.
    system_message: str = ""
    # The names of the two roles (user first, assistant second).
    # FIX: the previous default ``(("USER", "ASSISTANT"),)`` had a stray
    # trailing comma, making it a 1-tuple that wraps the pair, so
    # ``roles[1]`` raised IndexError on a default-constructed instance.
    roles: Tuple[str, str] = ("USER", "ASSISTANT")
    # All messages. Each item is ``[role, message]``.
    # FIX: the previous default was the tuple ``()``, on which
    # ``append_message`` raised AttributeError; ``default_factory=list``
    # gives each instance its own mutable list.
    messages: List[List[str]] = dataclasses.field(default_factory=list)
    # The number of few-shot example messages skipped by the exporters.
    offset: int = 0
    # The separator style and configurations.
    sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
    sep: str = "\n"
    sep2: Optional[str] = None
    # Stop criteria (the default one is the EOS token). The registered
    # templates pass a list of strings here, hence the Union annotation.
    stop_str: Union[str, List[str], None] = None
    # Stops generation if meeting any token in this list.
    stop_token_ids: Optional[List[int]] = None

    def get_prompt(self) -> str:
        """Render the full prompt string according to ``sep_style``.

        Raises:
            ValueError: if ``sep_style`` is not a known style.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.DeepSeek:
            seps = [self.sep, self.sep2]
            # Empty/None system prompt contributes nothing.
            ret = "" if not system_prompt else system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    # Alternate separators: seps[0] after even (user) turns,
                    # seps[1] after odd (assistant) turns.
                    ret += role + ": " + message + seps[i % 2]
                else:
                    # Open the turn with "<role>:" so the model continues it.
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DeepSeekV2:
            seps = [self.sep, self.sep2]
            ret = "" if not system_prompt else system_prompt + seps[0]
            for role, message in self.messages:
                if message:
                    # NOTE(review): this compares against the literal "User",
                    # but the registered deepseekv2 template uses the role
                    # string "|<User>|", so in practice every non-empty
                    # message takes the else branch. Preserved as-is —
                    # confirm intent against upstream before changing.
                    if role == "User":
                        ret += "<|sft▁begin|>\n" + message + self.sep
                    else:
                        ret += message + self.sep2
                # Empty messages contribute nothing (the original had a
                # dead ``ret = ret`` no-op here).
            return ret
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    # Both parity branches of the original appended the same
                    # expression; collapsed into one statement.
                    ret += message + seps[i % 2]
            return ret
        elif self.sep_style == SeparatorStyle.ALIGNMENT:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        # Even (user) turns are replaced by an image
                        # placeholder; the text is intentionally dropped.
                        ret += "<image>\n" + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the
        prompt, so we need to update it in-place after getting the response
        from a model.
        """
        self.messages[-1][1] = message

    def reset_message(self):
        """Clear the conversation history."""
        self.messages = []

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Returns a list of ``[user_msg, assistant_msg]`` pairs, skipping the
        first ``offset`` messages.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format.

        Assistant slots holding ``None`` (pending generation) are omitted.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        """Return a deep-enough copy: the message list (and each pair) is new."""
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Return a plain-dict summary of the template and its history."""
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
|
||
|
||
# A global registry of all conversation templates, keyed by template name.
# Populated by register_conv_template(); read via get_conv_template().
conv_templates: Dict[str, Conversation] = {}
|
||
|
||
def register_conv_template(template: Conversation, override: bool = False):
    """Add *template* to the global registry under its name.

    Registering an already-registered name is an error unless ``override``
    is true, in which case the existing entry is silently replaced.
    """
    duplicate = template.name in conv_templates
    assert override or not duplicate, f"{template.name} has been registered."
    conv_templates[template.name] = template
|
||
|
||
def get_conv_template(name: str) -> Conversation:
    """Look up a registered template and return an independent copy of it."""
    template = conv_templates[name]
    return template.copy()
|
||
|
||
# Template for DeepSeek chat models: "role: message" turns separated by
# blank lines; assistant turns are terminated by the EOS sentence token.
register_conv_template(
    Conversation(
        name="deepseek",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"],
    )
)

# Template for DeepSeek-V2 SFT-style prompts.
# NOTE(review): the role strings are "|<User>|" / "|<Assistant>|" (pipes
# outside the angle brackets), unlike the "deepseek" template above, and
# get_prompt's DeepSeekV2 branch compares roles to the literal "User" —
# preserved as found; verify against upstream before relying on it.
register_conv_template(
    Conversation(
        name="deepseekv2",
        system_template="{system_message}",
        system_message="",
        roles=("|<User>|", "|<Assistant>|"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeekV2,
        sep="\n<|sft▁end|>",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"],
    )
)


# Bare concatenation template: messages are joined with no role markers or
# separators (sep and sep2 are empty strings).
register_conv_template(
    Conversation(
        name="plain",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=["</s>"],
    )
)


# Alignment-pretraining template: even (user) turns are replaced by an
# "<image>" placeholder; only assistant text is kept verbatim.
register_conv_template(
    Conversation(
        name="alignment",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=["</s>"],
    )
)
|
||
|
||
if __name__ == "__main__": | ||
print("deepseek template:") | ||
conv = get_conv_template("deepseek") |
Oops, something went wrong.