
Commit 6c85ef7

[Model] Add AKI model (#853)
1 parent 4350db5 commit 6c85ef7

File tree

6 files changed: 109 additions, 0 deletions


README.md

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ Note that some VLMs may not be able to run under certain transformer versions, w
 - **Please use** `transformers==4.36.2` **for**: `Moondream1`.
 - **Please use** `transformers==4.37.0` **for**: `LLaVA series`, `ShareGPT4V series`, `TransCore-M`, `LLaVA (XTuner)`, `CogVLM Series`, `EMU2 Series`, `Yi-VL Series`, `MiniCPM-[V1/V2]`, `OmniLMM-12B`, `DeepSeek-VL series`, `InternVL series`, `Cambrian Series`, `VILA Series`, `Llama-3-MixSenseV1_1`, `Parrot-7B`, `PLLaVA Series`.
 - **Please use** `transformers==4.40.0` **for**: `IDEFICS2`, `Bunny-Llama3`, `MiniCPM-Llama3-V2.5`, `360VL-70B`, `Phi-3-Vision`, `WeMM`.
+- **Please use** `transformers==4.42.0` **for**: `AKI`.
 - **Please use** `transformers==4.44.0` **for**: `Moondream2`, `H2OVL series`.
 - **Please use** `transformers==4.45.0` **for**: `Aria`.
 - **Please use** `transformers==latest` **for**: `LLaVA-Next series`, `PaliGemma-3B`, `Chameleon series`, `Video-LLaVA-7B-HF`, `Ovis series`, `Mantis series`, `MiniCPM-V2.6`, `OmChat-v2.0-13B-sinlge-beta`, `Idefics-3`, `GLM-4v-9B`, `VideoChat2-HD`, `RBDash_72b`, `Llama-3.2 series`, `Kosmos series`.
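
Since these pins are mutually exclusive, it is worth verifying the active `transformers` build before launching an AKI run. A minimal pre-flight check (a sketch based on the 4.42.0 pin above; not part of the commit):

```python
# Sanity check: the README pins AKI to transformers==4.42.0.
import transformers

expected = "4.42.0"
if transformers.__version__ != expected:
    raise RuntimeError(
        f"AKI expects transformers=={expected}, found {transformers.__version__}"
    )
```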

docs/ja/README_ja.md

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ PS: The Japanese README does not necessarily include all the latest updates
 - **Please use** `transformers==4.33.0` **for**: `Qwen series`, `Monkey series`, `InternLM-XComposer series`, `mPLUG-Owl2`, `OpenFlamingo v2`, `IDEFICS series`, `VisualGLM`, `MMAlaya`, `ShareCaptioner`, `MiniGPT-4 series`, `InstructBLIP series`, `PandaGPT`, `VXVERSE`, `GLM-4v-9B`.
 - **Please use** `transformers==4.37.0` **for**: `LLaVA series`, `ShareGPT4V series`, `TransCore-M`, `LLaVA (XTuner)`, `CogVLM series`, `EMU2 series`, `Yi-VL series`, `MiniCPM-[V1/V2]`, `OmniLMM-12B`, `DeepSeek-VL series`, `InternVL series`, `Cambrian series`, `VILA-VL series`.
 - **Please use** `transformers==4.40.0` **for**: `IDEFICS2`, `Bunny-Llama3`, `MiniCPM-Llama3-V2.5`, `360VL-70B`, `Phi-3-Vision`, `WeMM`.
+- **Please use** `transformers==4.42.0` **for**: `AKI`.
 - **Please use** `transformers==latest` **for**: `LLaVA-Next series`, `PaliGemma-3B`, `Chameleon-VL series`, `Video-LLaVA-7B-HF`, `Ovis1.5 series`, `Mantis series`, `MiniCPM-V2.6`.

docs/zh-CN/README_zh-CN.md

Lines changed: 1 addition & 0 deletions

@@ -59,6 +59,7 @@
 - **Please use** `transformers==4.33.0` **for**: `Qwen series`, `Monkey series`, `InternLM-XComposer Series`, `mPLUG-Owl2`, `OpenFlamingo v2`, `IDEFICS series`, `VisualGLM`, `MMAlaya`, `ShareCaptioner`, `MiniGPT-4 series`, `InstructBLIP series`, `PandaGPT`, `VXVERSE`.
 - **Please use** `transformers==4.37.0` **for**: `LLaVA series`, `ShareGPT4V series`, `TransCore-M`, `LLaVA (XTuner)`, `CogVLM Series`, `EMU2 Series`, `Yi-VL Series`, `MiniCPM-[V1/V2]`, `OmniLMM-12B`, `DeepSeek-VL series`, `InternVL series`, `Cambrian Series`, `VILA Series`, `Llama-3-MixSenseV1_1`, `Parrot-7B`, `PLLaVA Series`.
 - **Please use** `transformers==4.40.0` **for**: `IDEFICS2`, `Bunny-Llama3`, `MiniCPM-Llama3-V2.5`, `360VL-70B`, `Phi-3-Vision`, `WeMM`.
+- **Please use** `transformers==4.42.0` **for**: `AKI`.
 - **Please use** `transformers==latest` **for**: `LLaVA-Next series`, `PaliGemma-3B`, `Chameleon series`, `Video-LLaVA-7B-HF`, `Ovis series`, `Mantis series`, `MiniCPM-V2.6`, `OmChat-v2.0-13B-sinlge-beta`, `Idefics-3`, `GLM-4v-9B`, `VideoChat2-HD`.

 **How to test whether a VLM can run properly:**

vlmeval/config.py

Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@
 }

 ungrouped = {
+    "AKI": partial(AKI, name="AKI", ckpt_pth="Sony/AKI-4B-phi-3.5-mini"),
     "TransCore_M": partial(TransCoreM, root=TransCore_ROOT),
     "PandaGPT_13B": partial(PandaGPT, name="PandaGPT_13B", root=PandaGPT_ROOT),
     "flamingov2": partial(

vlmeval/vlm/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -90,3 +90,4 @@
 from .ola import Ola
 from .ursa import UrsaChat
 from .vlm_r1 import VLMR1Chat
+from .aki import AKI

vlmeval/vlm/aki.py (new file)

Lines changed: 104 additions & 0 deletions

```python
import torch
from PIL import Image
import warnings
from .base import BaseModel
from ..smp import splitlen, get_cache_path
from transformers import AutoTokenizer, AutoConfig
from torchvision.transforms import Compose, Resize, Lambda, ToTensor, Normalize

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


class AKI(BaseModel):
    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self,
                 name,
                 ckpt_pth=None,
                 **kwargs):

        self.name = name
        try:
            from open_flamingo.src.modeling_aki import AKI
        except ImportError:
            raise ImportError('Please first install AKIVLM from https://github.com/sony/aki')

        # replace GenerationMixin to modify attention mask handling
        from transformers.generation.utils import GenerationMixin
        from open_flamingo import _aki_update_model_kwargs_for_generation
        GenerationMixin._update_model_kwargs_for_generation = _aki_update_model_kwargs_for_generation

        config = AutoConfig.from_pretrained(ckpt_pth)
        tokenizer = AutoTokenizer.from_pretrained(ckpt_pth)
        model = AKI.from_pretrained(ckpt_pth, tokenizer=tokenizer)

        n_px = getattr(config, "n_px", 384)
        norm_mean = getattr(config, "norm_mean", 0.5)
        norm_std = getattr(config, "norm_std", 0.5)

        image_processor = Compose([
            Resize((n_px, n_px), interpolation=BICUBIC, antialias=True),  # BICUBIC fallback resolved above
            Lambda(lambda x: x.convert('RGB')),
            ToTensor(),
            Normalize(mean=(norm_mean, norm_mean, norm_mean), std=(norm_std, norm_std, norm_std))
        ])
        self.model = model.eval().cuda()

        tokenizer.padding_side = 'left'
        tokenizer.add_eos_token = False
        self.tokenizer = tokenizer
        self.image_proc = image_processor

        kwargs_default = {
            'max_new_tokens': 512,
            'temperature': 0.0,
            'do_sample': False,
            'eos_token_id': tokenizer.eos_token_id,
        }
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default

    def apply_prompt_template(self, query):
        SYSTEM_BASE = "A chat between a curious user and an artificial intelligence assistant."
        SYSTEM_DETAIL = "The assistant gives helpful, detailed, and polite answers to the user's questions."
        SYSTEM_MESSAGE = SYSTEM_BASE + " " + SYSTEM_DETAIL
        SYSTEM_MESSAGE_ROLE = '<|system|>' + '\n' + SYSTEM_MESSAGE + '<|end|>\n'

        s = (
            f'{SYSTEM_MESSAGE_ROLE}'
            f'<|user|>\n{query}<|end|>\n<|assistant|>\n'
        )
        return s

    def generate_inner(self, message, dataset=None):
        vision_x, prompt = [], ''
        for msg in message:
            if msg['type'] == 'image':
                img = Image.open(msg['value']).convert('RGB')

                # [NOTE]: only the first image is used if a sample contains multiple images
                if len(vision_x) == 0:
                    vision_x.append(self.image_proc(img).unsqueeze(0))
                    prompt += '<image>'
                else:
                    warnings.warn('======Only the first image is used in the input.')
            elif msg['type'] == 'text':
                prompt += msg['value']
                # prompt += f"\nAnswer the question using a single word or phrase. {msg['value']}"  # for YorN

        # expand to (B=1, T_img=1, F=1, C, H, W), the open_flamingo-style input layout
        vision_x = torch.cat(vision_x, dim=0) if len(vision_x) > 1 else vision_x[0]
        vision_x = vision_x.unsqueeze(1).unsqueeze(0)
        prompt = self.apply_prompt_template(prompt)
        lang_x = self.tokenizer([prompt], return_tensors='pt')

        generated_text = self.model.generate(
            vision_x=vision_x.cuda(),
            lang_x=lang_x['input_ids'].cuda(),
            attention_mask=lang_x['attention_mask'].cuda(),
            **self.kwargs)
        generated_text = self.tokenizer.decode(generated_text[0], skip_special_tokens=True)
        return generated_text
```
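
For orientation, the template above yields a Phi-3.5-style chat prompt. A standalone sketch reproducing the string `apply_prompt_template` builds (the `<image>` placeholder is prepended by `generate_inner`; the query here is illustrative):

```python
# Reconstructs the prompt format used by AKI.apply_prompt_template.
SYSTEM_MESSAGE = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)
query = '<image>What is in the picture?'  # illustrative query
prompt = f'<|system|>\n{SYSTEM_MESSAGE}<|end|>\n<|user|>\n{query}<|end|>\n<|assistant|>\n'
print(prompt)
```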
