Add support for llava_hf video, better loading logic for llava_hf ckpt (
kcz358 authored Sep 17, 2024
1 parent e20d5d6 commit 9f8d1b4
Showing 1 changed file with 56 additions and 9 deletions.
65 changes: 56 additions & 9 deletions lmms_eval/models/llava_hf.py
@@ -1,11 +1,15 @@
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from decord import VideoReader, cpu
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoProcessor,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
@@ -21,10 +25,23 @@
from loguru import logger as eval_logger

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_VIDEO_TOKEN = "<video>"

# Default chat for llava-hf/llava-1.5 models: https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0
VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"

model_map = {
"llava": LlavaForConditionalGeneration,
"llava_next": LlavaNextForConditionalGeneration,
}

try:
    from transformers import LlavaOnevisionForConditionalGeneration

    model_map["llava_onevision"] = LlavaOnevisionForConditionalGeneration
except Exception as e:
    eval_logger.debug("Transformers version does not support llava-onevision. Skipping.")


@register_model("llava_hf")
class LlavaHf(lmms):
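
The Vicuna chat template defined above can be rendered on its own to see exactly what prompt string the llava-1.5-style models receive. A minimal sketch using jinja2; the example question is made up, and importing the constant assumes the lmms_eval package and its dependencies are installed:

from jinja2 import Template

from lmms_eval.models.llava_hf import VICUNA_CHAT_TEMPLATE

# Render a single-turn user message; "</s>" as eos_token is an assumption
# (it is only used for assistant turns and is not reached here).
prompt = Template(VICUNA_CHAT_TEMPLATE).render(
    messages=[{"role": "user", "content": "<image>\nWhat is shown in this image?"}],
    eos_token="</s>",
    add_generation_prompt=True,
)
print(prompt)
# A chat between a curious user and an artificial intelligence assistant. ...
# USER: <image>\nWhat is shown in this image? ASSISTANT: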
@@ -57,6 +74,7 @@ def __init__(
        chat_template: Optional[str] = None,
        use_cache: bool = True,
        specified_eot_token_id: Optional[int] = None,
        max_frames_num: Optional[int] = 32,
        **kwargs,
    ) -> None:
        super().__init__()
@@ -73,13 +91,11 @@ def __init__(
        if isinstance(dtype, str) and dtype != "auto":
            dtype = getattr(torch, dtype)

if "1.5" in pretrained:
self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
elif "1.6" in pretrained:
self._model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
else:
eval_logger.info("Not sure whether you use 1.5 or 1.6. Use 1.5 by default. This might cause bugs if you are actually using 1.6")
self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
config = AutoConfig.from_pretrained(pretrained)
self.max_frames_num = max_frames_num
model_type = getattr(config, "model_type", "llava")
model_type = model_map[model_type]
self._model = model_type.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)

self.pretrained = pretrained
self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
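
Instead of matching substrings like "1.5" or "1.6" in the checkpoint name, the loader now reads model_type from the checkpoint config and resolves the class through model_map. A minimal standalone sketch of that lookup, assuming network or cache access to the (real) llava-hf checkpoints listed; everything else mirrors the lines above:

from transformers import (
    AutoConfig,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
)

# Local copy of the module-level model_map above (llava_onevision omitted,
# since it needs a sufficiently recent transformers release).
model_map = {
    "llava": LlavaForConditionalGeneration,
    "llava_next": LlavaNextForConditionalGeneration,
}

for ckpt in ["llava-hf/llava-1.5-7b-hf", "llava-hf/llava-v1.6-vicuna-7b-hf"]:
    config = AutoConfig.from_pretrained(ckpt)  # fetches config.json only
    cls = model_map[getattr(config, "model_type", "llava")]
    print(ckpt, "->", cls.__name__)
# llava-hf/llava-1.5-7b-hf -> LlavaForConditionalGeneration
# llava-hf/llava-v1.6-vicuna-7b-hf -> LlavaNextForConditionalGeneration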
@@ -239,6 +255,17 @@ def flatten(self, input):
                new_list.append(j)
        return new_list

    def load_video(self, video_path, max_frames_num):
        if type(video_path) == str:
            vr = VideoReader(video_path, ctx=cpu(0))
        else:
            vr = VideoReader(video_path[0], ctx=cpu(0))
        total_frame_num = len(vr)
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
        spare_frames = vr.get_batch(frame_idx).asnumpy()
        return spare_frames  # (frames, height, width, channels)
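
load_video decodes only a uniform subset of frames: np.linspace picks max_frames_num indices spread evenly from the first to the last frame. A small worked example of the index math (the 100-frame clip length is made up):

import numpy as np

# Worked example of the sampling in load_video: a hypothetical 100-frame clip
# reduced to 8 uniformly spaced frames (indices are truncated to ints).
total_frame_num = 100
max_frames_num = 8
frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
print(frame_idx)  # [0, 14, 28, 42, 56, 70, 84, 99]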

    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

@@ -265,6 +292,12 @@ def _collate(x):
            split = split[0]
            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
            visuals = self.flatten(visuals)
            if len(visuals) == 0:
                task_type = "text"
            elif isinstance(visuals[0], PIL.Image.Image):
                task_type = "image"
            elif isinstance(visuals[0], str):
                task_type = "video"
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
@@ -284,7 +317,10 @@ def _collate(x):

            # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
            if DEFAULT_IMAGE_TOKEN not in context:
                image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
                if task_type == "image":
                    image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
                elif task_type == "video":
                    image_tokens = [DEFAULT_VIDEO_TOKEN] * len(visuals)
                image_tokens = " ".join(image_tokens)
                context = f"{image_tokens}\n{context}"
            # Apply chat template
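
For a video request the same prepending step inserts <video> placeholders instead of <image> ones, one per visual, before the chat template is applied. A quick illustration with a single made-up video path and question:

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_VIDEO_TOKEN = "<video>"

visuals = ["/path/to/clip.mp4"]  # hypothetical video path
context = "What is happening in this video?"
task_type = "video"

# Mirrors the branch above: prepend one placeholder per visual.
if DEFAULT_IMAGE_TOKEN not in context:
    if task_type == "image":
        image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
    elif task_type == "video":
        image_tokens = [DEFAULT_VIDEO_TOKEN] * len(visuals)
    context = f"{' '.join(image_tokens)}\n{context}"

print(context)
# <video>
# What is happening in this video?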
@@ -301,7 +337,18 @@ def _collate(x):
            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")

            inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype)
            if task_type == "video":
                try:
                    visuals = [self.load_video(visuals, self.max_frames_num)]
                except Exception as e:
                    res.append("")
                    eval_logger.info(f"Error {e} when loading video : {visuals}")
                    pbar.update(1)

            if task_type == "image":
                inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype)
            elif task_type == "video":
                inputs = self._image_processor(videos=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype)

            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
            if "max_new_tokens" not in gen_kwargs:
