import torch
from PIL import Image
import warnings
from .base import BaseModel
from ..smp import splitlen, get_cache_path
from transformers import AutoTokenizer, AutoConfig
from torchvision.transforms import Compose, Resize, Lambda, ToTensor, Normalize
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


class AKI(BaseModel):
    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self,
                 name,
                 ckpt_pth=None,
                 **kwargs):

        self.name = name
        try:
            from open_flamingo.src.modeling_aki import AKI
        except ImportError:
            raise ImportError('Please first install AKIVLM from https://github.com/sony/aki')

        # replace GenerationMixin to modify attention mask handling
        from transformers.generation.utils import GenerationMixin
        from open_flamingo import _aki_update_model_kwargs_for_generation
        GenerationMixin._update_model_kwargs_for_generation = _aki_update_model_kwargs_for_generation

        config = AutoConfig.from_pretrained(ckpt_pth)
        tokenizer = AutoTokenizer.from_pretrained(ckpt_pth)
        model = AKI.from_pretrained(ckpt_pth, tokenizer=tokenizer)

        n_px = getattr(config, "n_px", 384)
        norm_mean = getattr(config, "norm_mean", 0.5)
        norm_std = getattr(config, "norm_std", 0.5)

        image_processor = Compose([
            Resize((n_px, n_px), interpolation=BICUBIC, antialias=True),
            Lambda(lambda x: x.convert('RGB')),
            ToTensor(),
            Normalize(mean=(norm_mean, norm_mean, norm_mean), std=(norm_std, norm_std, norm_std))
        ])
        self.model = model.eval().cuda()

        tokenizer.padding_side = 'left'
        tokenizer.add_eos_token = False
        self.tokenizer = tokenizer
        self.image_proc = image_processor

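        # Default generation config below amounts to greedy decoding
        # (do_sample=False, temperature=0.0); any extra keyword arguments
        # passed to the constructor override these defaults.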
        kwargs_default = {
            'max_new_tokens': 512,
            'temperature': 0.0,
            'do_sample': False,
            'eos_token_id': tokenizer.eos_token_id,
        }
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default

    def apply_prompt_template(self, query):
        SYSTEM_BASE = "A chat between a curious user and an artificial intelligence assistant."
        SYSTEM_DETAIL = "The assistant gives helpful, detailed, and polite answers to the user's questions."
        SYSTEM_MESSAGE = SYSTEM_BASE + " " + SYSTEM_DETAIL
        SYSTEM_MESSAGE_ROLE = '<|system|>' + '\n' + SYSTEM_MESSAGE + '<|end|>\n'

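        # For illustration (derived from the template below), a query such as
        # "<image>What is in this picture?" renders to:
        #   <|system|>
        #   A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>
        #   <|user|>
        #   <image>What is in this picture?<|end|>
        #   <|assistant|>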
        s = (
            f'{SYSTEM_MESSAGE_ROLE}'
            f'<|user|>\n{query}<|end|>\n<|assistant|>\n'
        )
        return s

    def generate_inner(self, message, dataset=None):
        vision_x, prompt = [], ''
        for msg in message:
            if msg['type'] == 'image':
                img = Image.open(msg['value']).convert('RGB')

                # [NOTE]: only the first image is used when a sample contains multiple images
                if len(vision_x) == 0:
                    vision_x.append(self.image_proc(img).unsqueeze(0))
                    prompt += '<image>'
                else:
                    warnings.warn('======Only the first image is used in the input.')
            elif msg['type'] == 'text':
                prompt += msg['value']
                # prompt += f"\nAnswer the question using a single word or phrase. {msg['value']}"  # for YorN

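        # The unsqueeze calls below presumably build the (B, T_img, F, C, H, W)
        # layout expected by open_flamingo: unsqueeze(1) adds the per-sample
        # image dimension and unsqueeze(0) adds the batch dimension.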
        vision_x = torch.cat(vision_x, dim=0) if len(vision_x) > 1 else vision_x[0]
        vision_x = vision_x.unsqueeze(1).unsqueeze(0)
        prompt = self.apply_prompt_template(prompt)
        lang_x = self.tokenizer([prompt], return_tensors='pt')

        generated_text = self.model.generate(
            vision_x=vision_x.cuda(),
            lang_x=lang_x['input_ids'].cuda(),
            attention_mask=lang_x['attention_mask'].cuda(),
            **self.kwargs)
        generated_text = self.tokenizer.decode(generated_text[0], skip_special_tokens=True)
        return generated_text
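

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the checkpoint path is a
# placeholder, and the message format follows the {'type': ..., 'value': ...}
# convention consumed by generate_inner above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = AKI(name='AKI', ckpt_pth='/path/to/aki_checkpoint')  # hypothetical path
    message = [
        {'type': 'image', 'value': 'demo.jpg'},            # hypothetical image file
        {'type': 'text', 'value': 'Describe this image.'},
    ]
    print(model.generate_inner(message))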