Skip to content

Commit

Permalink
Add support for Deepseek-vl2 and Language only inputs (#153)
Browse files Browse the repository at this point in the history
* initial setup

* fix language model loading

* add prompt util

* language model working

* formatting

* fix layer norm value

* fix deepseek merged features

* add support for 27B

* add language model only and deepseek-vl-v2 inputs

* formatting

* remove unused code

* remove duplicate

* add source

* add source

* remove unused

* bump version

* format

* add tests

* fix deprecation warning (tests)
  • Loading branch information
Blaizzy authored Dec 22, 2024
1 parent 3f5e162 commit 398cb62
Show file tree
Hide file tree
Showing 13 changed files with 2,490 additions and 21 deletions.
2 changes: 1 addition & 1 deletion mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .utils import generate, get_model_path, load, load_config, load_image_processor

DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
DEFAULT_IMAGE = ["http://images.cocodataset.org/val2017/000000039769.jpg"]
DEFAULT_IMAGE = []
DEFAULT_PROMPT = "What are these?"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.5
Expand Down
10 changes: 10 additions & 0 deletions mlx_vlm/models/deepseek_vl_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .deepseek_vl_v2 import (
DeepseekVLV2Processor,
LanguageModel,
Model,
ModelConfig,
ProjectorConfig,
TextConfig,
VisionConfig,
VisionModel,
)
264 changes: 264 additions & 0 deletions mlx_vlm/models/deepseek_vl_v2/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""

import dataclasses
from enum import IntEnum, auto
from typing import Dict, List


class SeparatorStyle(IntEnum):
    """Separator conventions used when flattening a conversation to text."""

    # Explicit values mirror exactly what ``auto()`` would have assigned,
    # starting at 1 — spelled out so the wire values are obvious at a glance.
    DeepSeek = 1
    DeepSeekV2 = 2
    PLAIN = 3
    ALIGNMENT = 4


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history.

    The history is flattened into a single prompt string by ``get_prompt``
    according to ``sep_style``.
    """

    # The name of this template.
    name: str
    # The template of the system prompt; ``{system_message}`` is substituted.
    system_template: str = "{system_message}"
    # The system message.
    system_message: str = ""
    # The names of the two speaking roles.  The previous default was
    # ``(("USER", "ASSISTANT"),)`` — a one-element tuple wrapping the pair
    # (trailing-comma typo); fixed to the flat pair.  Every template
    # registered in this module passes ``roles`` explicitly anyway.
    roles: List[str] = ("USER", "ASSISTANT")
    # All messages.  Each item is ``[role, message]``.  ``default_factory``
    # gives each instance its own mutable list; the previous ``()`` default
    # was an immutable shared tuple, so ``append_message`` raised
    # AttributeError on a default-constructed Conversation.
    messages: List[List[str]] = dataclasses.field(default_factory=list)
    # The number of few shot examples at the head of ``messages``.
    offset: int = 0
    # The separator style and configurations.
    sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is EOS token).
    stop_str: str = None
    # Stops generation if meeting any token in this list.
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Render the conversation into one prompt string per ``sep_style``.

        Raises:
            ValueError: if ``sep_style`` is not a known style.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.DeepSeek:
            seps = [self.sep, self.sep2]
            # An empty system prompt contributes nothing (no leading separator).
            ret = system_prompt + seps[0] if system_prompt else ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    # Alternate sep (after user turns) and sep2 (after
                    # assistant turns).
                    ret += role + ": " + message + seps[i % 2]
                else:
                    # Open turn: end with "role:" so the model continues it.
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DeepSeekV2:
            ret = system_prompt + self.sep if system_prompt else ""
            for role, message in self.messages:
                # Empty messages are skipped entirely (the original had a
                # no-op ``ret = ret`` else-branch here).
                if message:
                    if role == "User":
                        ret += "<|sft▁begin|>\n" + message + self.sep
                    else:
                        ret += message + self.sep2
            return ret
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if isinstance(message, tuple):
                        # Multimodal entries look like (text, ..., ...);
                        # keep only the text — TODO confirm tuple layout.
                        message, _, _ = message
                    # The original if/else emitted the same thing for both
                    # turn parities; collapsed into one statement.
                    ret += message + seps[i % 2]
            return ret
        elif self.sep_style == SeparatorStyle.ALIGNMENT:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    if i % 2 == 0:
                        # Even turns are replaced by the image placeholder.
                        ret += "<image>\n" + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the
        prompt, so we need to update it in-place after getting the response
        from a model.
        """
        self.messages[-1][1] = message

    def reset_message(self):
        """Clear the conversation history."""
        self.messages = []

    def to_gradio_chatbot(self):
        """Convert the history (past ``offset``) to gradio chatbot pairs."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            elif msg is not None:
                # Trailing None assistant slots (awaiting a reply) are omitted.
                ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        """Return an independent copy (message pairs are re-listed)."""
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Serializable summary of the template and its history."""
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


# A global registry for all conversation templates, keyed by template name.
conv_templates: Dict[str, "Conversation"] = {}


def register_conv_template(template: "Conversation", override: bool = False) -> None:
    """Register a new conversation template under ``template.name``.

    Args:
        template: the template to register.
        override: if True, silently replace any existing template of the
            same name.

    Raises:
        ValueError: if the name is already registered and ``override`` is
            False.  (Previously an ``assert``, which is stripped under
            ``python -O`` and would have allowed silent overwrites.)
    """
    if not override and template.name in conv_templates:
        raise ValueError(f"{template.name} has been registered.")

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Look up the registered template *name* and return a fresh copy of it.

    Copying means callers can append messages without mutating the shared
    registry entry.
    """
    template = conv_templates[name]
    return template.copy()


register_conv_template(
Conversation(
name="deepseek",
system_template="{system_message}",
# system_message="You are a helpful assistant. Please answer truthfully and write out your "
# "thinking step by step to be sure you get the right answer.",
system_message="",
roles=("<|User|>", "<|Assistant|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.DeepSeek,
sep="\n\n",
sep2="<|end▁of▁sentence|>",
stop_token_ids=[100001],
stop_str=["User:", "<|end▁of▁sentence|>"],
)
)

register_conv_template(
Conversation(
name="deepseekv2",
system_template="{system_message}",
system_message="",
roles=("|<User>|", "|<Assistant>|"),
messages=(),
offset=0,
sep_style=SeparatorStyle.DeepSeekV2,
sep="\n<|sft▁end|>",
sep2="<|end▁of▁sentence|>",
stop_token_ids=[100001],
stop_str=["User:", "<|end▁of▁sentence|>"],
)
)


register_conv_template(
Conversation(
name="plain",
system_template="",
system_message="",
roles=("", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="",
sep2="",
stop_token_ids=[100001],
stop_str=["</s>"],
)
)


register_conv_template(
Conversation(
name="alignment",
system_template="",
system_message="",
roles=("", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.ALIGNMENT,
sep="",
sep2="",
stop_token_ids=[100001],
stop_str=["</s>"],
)
)


if __name__ == "__main__":
print("deepseek template:")
conv = get_conv_template("deepseek")
Loading

0 comments on commit 398cb62

Please sign in to comment.