[Model] Add new model: Prism #622

Open · wants to merge 18 commits into main
16 changes: 13 additions & 3 deletions run.py
@@ -13,9 +13,13 @@
def build_model_from_config(cfg):
    import vlmeval.api
    import vlmeval.vlm
    import vlmeval.composite

    config = cp.deepcopy(cfg)
    assert 'class' in config
    cls_name = config.pop('class')
    if hasattr(vlmeval.composite, cls_name):
        return getattr(vlmeval.composite, cls_name)(supported_VLM, **config)
    if hasattr(vlmeval.api, cls_name):
        return getattr(vlmeval.api, cls_name)(**config)
    elif hasattr(vlmeval.vlm, cls_name):
@@ -165,6 +169,14 @@ def main():

    for _, model_name in enumerate(args.model):
        model = None
        if use_config:
            model = build_model_from_config(cfg['model'][model_name])
            if model_name == 'Prism':
                fronted_name = cfg['model']['Prism']['model']['fronted']['model']
                backend_name = cfg['model']['Prism']['model']['backend']['model']
                backend_name = backend_name.replace('/', '-')
                model_name = model_name + '_' + fronted_name + '_' + backend_name

        date, commit_id = timestr('day'), githash(digits=8)
        eval_id = f"T{date}_G{commit_id}"

@@ -179,9 +191,6 @@ def main():
        if not osp.exists(pred_root):
            os.makedirs(pred_root, exist_ok=True)

        if use_config:
            model = build_model_from_config(cfg['model'][model_name])

        for _, dataset_name in enumerate(args.data):
            try:
                result_file_base = f'{model_name}_{dataset_name}.xlsx'
@@ -281,6 +290,7 @@ def main():
                    model = model_name  # which is only a name

                # Perform the Inference

                if dataset.MODALITY == 'VIDEO':
                    model = infer_data_job_video(
                        model,
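A note for reviewers: the Prism branch above assumes a config entry shaped roughly like the sketch below. This is inferred from the keys that build_model_from_config and the Prism class read; the concrete model names are placeholders, not values taken from this PR.

prism_cfg = {
    'model': {
        'Prism': {
            'class': 'Prism',  # popped by build_model_from_config to select the class
            'model': {
                'fronted': {                        # spelling follows the PR's own key name
                    'model': 'prismcaptioner-7b',   # must be a key of supported_VLM
                    'prompt_type': 'generic',       # must be a key of generic_prompt_mapping
                },
                'backend': {
                    'model': 'gpt-4o',              # remapped through gpt_version_map in prism.py
                },
            },
        },
    },
}
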
7 changes: 7 additions & 0 deletions vlmeval/composite/__init__.py
@@ -0,0 +1,7 @@
import torch

# torch.set_grad_enabled(False)
# torch.manual_seed(1234)


from .prism import Prism
190 changes: 190 additions & 0 deletions vlmeval/composite/prism.py
@@ -0,0 +1,190 @@
import torch
import re
from vlmeval.api import OpenAIWrapper, SiliconFlowAPI
from vlmeval.utils import track_progress_rich
import os


# remap the gpt model name
gpt_version_map = {
    'gpt-4-0409': 'gpt-4-turbo-2024-04-09',
    'gpt-4-0125': 'gpt-4-0125-preview',
    'gpt-4-turbo': 'gpt-4-1106-preview',
    'gpt-4-0613': 'gpt-4-0613',
    'chatgpt-1106': 'gpt-3.5-turbo-1106',
    'chatgpt-0613': 'gpt-3.5-turbo-0613',
    'chatgpt-0125': 'gpt-3.5-turbo-0125',
    'gpt-4o': 'gpt-4o-2024-05-13',
}

# # map the model name to the api type
# reasoning_mapping = {
#     'llama3-70b-chat': 'silicon',
#     'Mixtral-8x22B-chat': 'silicon',
#     'deepseek-ai/DeepSeek-V2-Chat': 'silicon',
# }
#
# # stop_tokens for deploying vllm
# stop_tokens = {
#     'llama3-70b-chat': ["<|eot_id|>"],
# }

mapping = {}
mapping.update(gpt_version_map)

# mapping.update(reasoning_mapping)

prompt_human1 = ('Describe the fine-grained content of the image, including scenes, objects,'
                 ' relationships, instance location, and any text present.')
prompt_human2 = ('Describe the fine-grained content of the image, including scenes, objects, '
                 'relationships, instance location, background and any text present. Please skip '
                 'generating statements for non-existent contents and describe all you see. ')
prompt_gpt1 = 'Given the image below, please provide a detailed description of what you see.'
prompt_gpt2 = 'Analyze the image below and describe the main elements and their relationship.'
prompt_cot = ('Describe the fine-grained content of the image, including scenes, objects, relationships,'
              ' instance location, and any text present. Let\'s think step by step.')
prompt_decompose = ('Decompose the image into several parts and describe the fine-grained content of the '
                    'image part by part, including scenes, objects, relationships, instance location, and'
                    ' any text present.')

generic_prompt_mapping = {
    'generic': prompt_human1,
    'human1': prompt_human1,
    'gpt1': prompt_gpt1,
    'gpt2': prompt_gpt2,
    'human2': prompt_human2,
    'cot': prompt_cot,
    'decompose': prompt_decompose,
}


class Prism():

    def __init__(self, supported_VLM, **kwargs):
        self.supported_VLM = supported_VLM
        self.config = kwargs

        self.model_name_fronted = self.config['model']['fronted']['model']
        self.model_name_backend = self.config['model']['backend']['model']
        self.fronted_prompt_type = self.config['model']['fronted']['prompt_type']

        self.model_fronted = supported_VLM[self.model_name_fronted]() if (
            isinstance(self.model_name_fronted, str)) else None
        self.model_backend = Reasoning(model=self.model_name_backend)

    def set_dump_image(self, dump_image):
        if hasattr(self.model_fronted, 'set_dump_image'):
            self.model_fronted.set_dump_image(dump_image)

    def generate(self, message, dataset=None):

        # swap the raw question for the frontend captioning prompt
        question = message[1]['value']
        prompt_fronted = self.build_fronted_prompt()
        message[1]['value'] = prompt_fronted

        # stage 1: the frontend VLM describes the image
        is_api = getattr(self.model_fronted, 'is_api', False)
        if is_api:
            response_fronted = self.fronted_api(message=message, dataset=dataset)
        else:
            response_fronted = self.model_fronted.generate(message=message, dataset=dataset)

        print("----fronted output----\n" + response_fronted + "\n----backend output----")

        # stage 2: the backend LLM answers the question from the description
        response_backend = self.model_backend.generate(question, response_fronted)

        return response_backend

    def fronted_api(self, message, dataset=None):
        result = self.model_fronted.generate(message)
        # gen_func = self.model_fronted.generate
        # struct = {}
        # struct['message'] = message
        # struct['dataset'] = dataset
        # result = track_progress_rich(gen_func, [struct])
        return result

    def build_fronted_prompt(self):
        prompt = generic_prompt_mapping[self.fronted_prompt_type]
        return prompt


class Reasoning:
    def __init__(self, model):
        self.model = LLMWrapper(model)

    def generate(self, question, des):
        prompt = build_infer_prompt_external(question, des)
        return self.model.generate(prompt)


def build_infer_prompt_external(question, des):
    if not question.endswith('\n'):
        question += '\n'
    if not question.lower().startswith('question:') and not question.lower().startswith('hint:'):
        question = 'Question: ' + question
    if not des.endswith('\n'):
        des += '\n'
    description = 'Description: ' + des
    role = ('You are an excellent text-based reasoning expert. You are required to answer the question'
            ' based on the detailed description of the image.\n\n')

    prompt = role + description + question
    return prompt


class LLMWrapper:

    def __init__(self, model_name, max_tokens=512, verbose=True, retry=5):

        # api bases, openai default
        # self.deepseek_api_base = 'https://api.deepseek.com/v1/chat/completions'

        # server settings of vllm
        # self.PORT = 8080
        # self.vllm_api_base = f'http://localhost:{self.PORT}/v1/chat/completions'

        self.prism_llm_api_base = os.environ['PRISM_LLM_API_BASE']

        if model_name.endswith('-2048'):
            model_name = model_name.replace('-2048', '')
            max_tokens = 2048

        if model_name in gpt_version_map:
            gpt_version = gpt_version_map[model_name]
            model = OpenAIWrapper(gpt_version, max_tokens=max_tokens, verbose=verbose, retry=retry)
        else:
            # use your api
            api_key = os.environ['PRISM_LLM_API_KEY']
            model = SiliconFlowAPI(model_name, api_base=self.prism_llm_api_base, key=api_key,
                                   system_prompt='You are a helpful assistant.', verbose=verbose, retry=retry)
            # model = OpenAIWrapper(model_name, api_base=self.prism_llm_api_base, key=api_key,
            #                       max_tokens=max_tokens, system_prompt='You are a helpful assistant.',
            #                       verbose=verbose, retry=retry)

        # elif reasoning_mapping[model_name] == 'vllm':
        #     model = OpenAIWrapper(model_name, api_base=self.vllm_api_base, max_tokens=max_tokens,
        #                           system_prompt='You are a helpful assistant.', verbose=verbose, retry=retry,
        #                           stop=stop_tokens[model_name])
        # elif reasoning_mapping[model_name] == 'deepseek':
        #     deepseek_key = os.environ['SILICON_API_KEY']
        #     model = OpenAIWrapper(model_name, api_base=self.deepseek_api_base, key=deepseek_key,
        #                           max_tokens=max_tokens, system_prompt='You are a helpful assistant.',
        #                           verbose=verbose, retry=retry)
        # else:
        #     print('Unknown API model for inference')

        self.model = model

    def generate(self, prompt, **kwargs):
        response = self.model.generate(prompt, **kwargs)
        return response

    @staticmethod
    def api_models():
        gpt_models = list(gpt_version_map.keys())
        api_models = gpt_models.copy()
        # api_models.extend(list(reasoning_mapping.keys()))
        return api_models
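
Taken together, Prism is a two-stage pipeline: the frontend VLM turns the image into a dense caption, and the backend LLM answers the original question from that caption alone. A minimal usage sketch, assuming the PRISM_LLM_API_BASE / PRISM_LLM_API_KEY environment variables are set and the frontend weights are available; the model names and image path are illustrative:

from vlmeval.config import supported_VLM
from vlmeval.composite import Prism

cfg = {
    'model': {
        'fronted': {'model': 'prismcaptioner-2b', 'prompt_type': 'generic'},
        'backend': {'model': 'gpt-4o'},
    }
}
model = Prism(supported_VLM, **cfg)

# VLMEvalKit-style interleaved message; Prism reads the question from message[1]
message = [
    dict(type='image', value='demo.jpg'),
    dict(type='text', value='How many people appear in the image?'),
]
print(model.generate(message))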
19 changes: 18 additions & 1 deletion vlmeval/config.py
@@ -357,6 +357,23 @@
    'h2ovl-mississippi-1b': partial(H2OVLChat, model_path='h2oai/h2ovl-mississippi-800m'),
}

prismcaptioner_series = {
    'prismcaptioner-7b': partial(
        LLaVA_XTuner_Wrapper,
        llm_path='internlm/internlm2-chat-7b',
        llava_path='Yuxuan-Qiao/PrismCaptioner-7B',
        visual_select_layer=-2,
        prompt_template='internlm2_chat',
        visual_encoder_path='google/siglip-so400m-patch14-384'),
    'prismcaptioner-2b': partial(
        LLaVA_XTuner_Wrapper,
        llm_path='internlm/internlm2-chat-1_8b',
        llava_path='Yuxuan-Qiao/PrismCaptioner-2B',
        visual_select_layer=-2,
        prompt_template='internlm2_chat',
        visual_encoder_path='google/siglip-so400m-patch14-384'),
}

supported_VLM = {}

model_groups = [
@@ -368,7 +385,7 @@
    mantis_series, mmalaya_series, phi3_series, xgen_mm_series, qwen2vl_series,
    slime_series, eagle_series, moondream_series, llama_series, molmo_series,
    kosmos_series, points_series, nvlm_series, vintern_series, h2ovl_series, aria_series,
    smolvlm_series
    smolvlm_series, prismcaptioner_series
]

for grp in model_groups:
1 change: 0 additions & 1 deletion vlmeval/dataset/image_vqa.py
@@ -11,7 +11,6 @@
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *
from ..utils import track_progress_rich
import ipdb


class ImageVQADataset(ImageBaseDataset):
1 change: 1 addition & 0 deletions vlmeval/vlm/__init__.py
@@ -59,3 +59,4 @@
from .h2ovl_mississippi import H2OVLChat
from .falcon_vlm import Falcon2VLM
from .smolvlm import SmolVLM
from .prismcaptioner import LLaVA_XTuner_Wrapper
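
One practical consequence of the run.py change above: when a Prism entry is evaluated, the frontend and backend names are folded into the result-file prefix, so different pairings do not overwrite each other's outputs. A quick sketch of the effect (model names illustrative; the DeepSeek name appears only in prism.py's commented-out mapping):

# model_name rewriting performed in run.py for the 'Prism' entry
fronted_name = 'prismcaptioner-7b'
backend_name = 'deepseek-ai/DeepSeek-V2-Chat'.replace('/', '-')
model_name = 'Prism' + '_' + fronted_name + '_' + backend_name
# -> 'Prism_prismcaptioner-7b_deepseek-ai-DeepSeek-V2-Chat'
# results then land in files like f'{model_name}_{dataset_name}.xlsx'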