LLaVA-Video-7B reproduce and Qwen2-VL-7B result #11

Open
xiaoqian-shen opened this issue Feb 21, 2025 · 6 comments

xiaoqian-shen commented Feb 21, 2025

Hi, thanks for sharing the code for your excellent work.

I’m curious why the results for Qwen2-VL-7B+Video-RAG aren’t reported on the VideoMME or MLVU benchmarks.

I tried using this repository and replaced the base model with Qwen2.5-VL-7B, but the results were not good.
Qwen2.5-VL-7B+Video-RAG
{'count': 21.844660194174757, 'ego': 50.85227272727273, 'findNeedle': 63.38028169014085, 'order': 47.87644787644788, 'plotQA': 63.63636363636363, 'anomaly_reco': 66.0, 'topic_reasoning': 86.31178707224335, 'Acc': 58.647654093836245}

I also tried to reproduce the result of LLaVA-Video-7B+Video-RAG, but it only achieves 67.29, while the performance reported in the paper is 72.4.
{'plotQA': 72.35621521335807, 'ego': 60.22727272727273, 'findNeedle': 76.61971830985915, 'anomaly_reco': 71.5, 'topic_reasoning': 87.83269961977186, 'count': 38.83495145631068, 'order': 52.123552123552116, 'Acc': 67.2953081876725}

Here is the code for LLaVA-Video-7B+Video-RAG; I have made no changes to the official code and only wrote a dataloader.

from PIL import Image
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration, WhisperProcessor, CLIPProcessor, CLIPModel
import copy
from decord import VideoReader, cpu
import numpy as np
import json
from tqdm import tqdm
import os
import easyocr
from rag_retriever_dynamic import retrieve_documents_with_dynamic
import re
import ast
import socket
import pickle
from filter_keywords import filter_keywords
from scene_graph import generate_scene_graph_description
import ffmpeg
import torchaudio
from transformers.trainer_pt_utils import IterableDatasetShard
import random

max_frames_num = 32
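# Auxiliary models: CLIP scores frame/text relevance for the detection (APE) branch,
# and Whisper transcribes the audio track for the ASR branch.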
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336", torch_dtype=torch.float16, device_map="auto")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
whisper_model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large",
    torch_dtype=torch.float16,
    device_map="auto"
)
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large")

class EvalDatasetMLVU(torch.utils.data.IterableDataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str = "evaluation/mlvu",
    ) -> None:
        super(EvalDatasetMLVU, self).__init__()

        self.data_path = data_path

        data_list = {
            "count": ("json/4_count.json", f"video/4_count", "video"), # 206
            "ego": ("json/3_ego.json", f"video/3_ego", "video"), # 357
            "findNeedle": ("json/2_needle.json", f"video/2_needle", "video"), # 355
            "order": ("json/5_order.json", f"video/5_order", "video"), # 259
            "plotQA": ("json/1_plotQA.json", f"video/1_plotQA", "video"),
            "anomaly_reco": (
                "json/6_anomaly_reco.json",
                f"video/6_anomaly_reco",
                "video",
            ),
            "topic_reasoning": (
                "json/7_topic_reasoning.json",
                f"video/7_topic_reasoning",
                "video",
            ),
        }

        list_data_dict = []
        for k, v in data_list.items():
            with open(os.path.join(data_path, v[0]), "r") as f:
                json_data = json.load(f)
            for data in json_data:
                question, answer = self.qa_template(data)
                list_data_dict.append(
                    {
                        "task_type": k,
                        "video_path": os.path.join(self.data_path, v[1], data["video"]),
                        "video_name": data["video"],
                        "question": data["question"],
                        "candidates": data["candidates"],
                        "prompt": question,
                        "answer": answer,
                        "duration": data["duration"],
                    }
                )

        random.shuffle(list_data_dict)
        self.data = list_data_dict

    def __len__(self) -> int:
        return len(self.data)

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def qa_template(self, data):
        question = f"Question: {data['question']}\n"
        question += "Options:\n"
        answer = data["answer"]
        answer_idx = -1
        for idx, c in enumerate(data["candidates"]):
            question += f"({chr(ord('A') + idx)}) {c}\n"
            if c == answer:
                answer_idx = idx
        answer = f"{chr(ord('A') + answer_idx)}"
        return question, answer

    # pyre-fixme[3]: Return type must be annotated.
    def __iter__(self):
        return iter(self.data)

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __getitem__(self, i):
        return self.data[i]

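# Sample frames at ~1 fps and, when the video is long or force_sample is set, uniformly
# sample max_frames_num frames instead; returns the frames, their timestamps, and the video duration.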
def process_video(video_path, max_frames_num, fps=1, force_sample=False):
    if max_frames_num == 0:
        return np.zeros((1, 336, 336, 3))
    vr = VideoReader(video_path, ctx=cpu(), num_threads=1)
    total_frame_num = len(vr)
    video_time = total_frame_num / vr.get_avg_fps()
    fps = round(vr.get_avg_fps()/fps)
    frame_idx = [i for i in range(0, len(vr), fps)]
    frame_time = [i/fps for i in frame_idx]
    if len(frame_idx) > max_frames_num or force_sample:
        sample_fps = max_frames_num
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
        frame_time = [i/vr.get_avg_fps() for i in frame_idx]
    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
    spare_frames = vr.get_batch(frame_idx).asnumpy()

    return spare_frames, frame_time, video_time

def extract_audio(video_path, audio_path):
    if not os.path.exists(audio_path):
        ffmpeg.input(video_path).output(audio_path, acodec='pcm_s16le', ac=1, ar='16k').run()

def chunk_audio(audio_path, chunk_length_s=30):
    speech, sr = torchaudio.load(audio_path)
    speech = speech.mean(dim=0)  
    speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(speech)  
    num_samples_per_chunk = chunk_length_s * 16000 
    chunks = []
    for i in range(0, len(speech), num_samples_per_chunk):
        chunks.append(speech[i:i + num_samples_per_chunk])
    return chunks

def transcribe_chunk(chunk):
    # chunks are already resampled to 16 kHz in chunk_audio
    inputs = whisper_processor(chunk, sampling_rate=16000, return_tensors="pt")
    inputs["input_features"] = inputs["input_features"].to(whisper_model.device, torch.float16)
    with torch.no_grad():
        predicted_ids = whisper_model.generate(
            inputs["input_features"],
            no_repeat_ngram_size=2,
            early_stopping=True
        )
    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription

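# Extract the audio track, split it into 30-second chunks, and transcribe each chunk with Whisper.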
def get_asr_docs(video_path, audio_path):

    full_transcription = []

    try:
        extract_audio(video_path, audio_path)
    except:
        return full_transcription
    audio_chunks = chunk_audio(audio_path, chunk_length_s=30)
    
    for chunk in audio_chunks:
        transcription = transcribe_chunk(chunk)
        full_transcription.append(transcription)

    return full_transcription

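# Run EasyOCR over the sampled frames and keep high-confidence, non-duplicate text as OCR documents.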
def get_ocr_docs(frames):
    reader = easyocr.Reader(['en']) 
    text_set = []
    ocr_docs = []
    for img in frames:
        ocr_results = reader.readtext(img)
        det_info = ""
        for result in ocr_results:
            text = result[1]
            confidence = result[2]
            if confidence > 0.5 and text not in text_set:
                det_info += f"{text}; "
                text_set.append(text)
        if len(det_info) > 0:
            ocr_docs.append(det_info)

    return ocr_docs

    
def save_frames(frames):
    file_paths = []
    for i, frame in enumerate(frames):
        img = Image.fromarray(frame)
        file_path = f'restore/frame_{i}.png'
        img.save(file_path)
        file_paths.append(file_path)
    return file_paths
    
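# Send the saved frame paths and the object prompt to the detection (APE) server listening
# on port 9999, and unpickle the detection results it returns.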
def get_det_docs(frames, prompt):
    prompt = ",".join(prompt)
    frames_path = save_frames(frames)
    res = []
    if len(frames) > 0:
        client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        client_socket.connect(('0.0.0.0', 9999))
        data = (frames_path, prompt)
        client_socket.send(pickle.dumps(data))
        result_data = client_socket.recv(4096)
        try:
            res = pickle.loads(result_data)
        except:
            res = []
    return res

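# Parse the per-frame "label: [bbox]" detection strings into object lists and turn them into
# scene-graph descriptions with optional location / relation / count information.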
def det_preprocess(det_docs, location, relation, number):

    scene_descriptions = []

    for det_doc_per_frame in det_docs:
        objects = []
        scene_description = ""
        if len(det_doc_per_frame) > 0:
            for obj_id, objs in enumerate(det_doc_per_frame.split(";")):
                obj_name = objs.split(":")[0].strip()
                obj_bbox = objs.split(":")[1].strip()
                obj_bbox = ast.literal_eval(obj_bbox)
                objects.append({"id": obj_id, "label": obj_name, "bbox": obj_bbox})

            scene_description = generate_scene_graph_description(objects, location, relation, number)
        scene_descriptions.append(scene_description)
    
    return scene_descriptions


# load your VLM
device = "cuda"
overwrite_config = {}
tokenizer, model, image_processor, max_length = load_pretrained_model(
    "lmms-lab/LLaVA-Video-7B-Qwen2", 
    None, 
    "llava_qwen", 
    torch_dtype="bfloat16", 
    device_map="auto", 
    overwrite_config=overwrite_config)  # Add any other thing you want to pass in llava_model_args
model.eval()
conv_template = "qwen_1_5"  # Make sure you use correct chat template for different models

dataset = EvalDatasetMLVU()
shard_dataset = IterableDatasetShard(
    dataset,
    batch_size=1
)

# The inference function of your VLM
def llava_inference(qs, video):
    if video is not None:
        question = DEFAULT_IMAGE_TOKEN + qs
    else:
        question = qs
    conv = copy.deepcopy(conv_templates[conv_template])
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    cont = model.generate(
        input_ids,
        images=video,
        modalities=["video"],
        do_sample=False,
        temperature=0,
        max_new_tokens=4096,
    )
    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip()
    return text_outputs


# super-parameters setting
rag_threshold = 0.3
clip_threshold = 0.3
beta = 3.0 

# Choose the auxiliary texts you want
USE_OCR = True
USE_ASR = True
USE_DET = True
print(f"---------------OCR{rag_threshold}: {USE_OCR}-----------------")
print(f"---------------ASR{rag_threshold}: {USE_ASR}-----------------")
print(f"---------------DET{beta}-{clip_threshold}: {USE_DET}-----------------")
print(f"---------------Frames: {max_frames_num}-----------------")

output = []

for line in tqdm(shard_dataset):
    video_name = line.get("video_name", None)
    answer = line.get("answer", None)
    prompt = line.get("prompt", None)
    questions = line.get("question", None)
    task_type = line.get("task_type", None)
    video_path = line.get("video_path", None)
    candidates = line.get("candidates", None)
    subtitle_path = line.get("subtitle", None)
    node_list = line.get("node_list", None)
    letters = line.get("letters", ["A", "B", "C", "D"])
    llm_info = line.get("info", None)

    question = prompt

    try:
        frames, frame_time, video_time = process_video(video_path, max_frames_num, 1, force_sample=True)
        raw_video = [f for f in frames]

        video = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].cuda().bfloat16()
        video = [video]
    except:
        continue

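    # Pre-compute CLIP pixel values for every sampled frame; they are used below to score
    # which frames are worth sending to the detector.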
    if USE_DET:
        video_tensor = []
        for frame in raw_video:
            processed = clip_processor(images=frame, return_tensors="pt")["pixel_values"].to(clip_model.device, dtype=torch.float16)
            video_tensor.append(processed.squeeze(0))
        video_tensor = torch.stack(video_tensor, dim=0)

    if USE_OCR:
        ocr_docs_total = get_ocr_docs(frames)

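    # ASR transcripts are cached under restore/audio/<video>.txt so Whisper only runs once per video.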
    if USE_ASR:
        if os.path.exists(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt")):
            with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'r', encoding='utf-8') as f:
                asr_docs_total = f.readlines()
        else:
            audio_path = os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".wav")
            asr_docs_total = get_asr_docs(video_path, audio_path)
            with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'w', encoding='utf-8') as f:
                for doc in asr_docs_total:
                    f.write(doc + '\n')

    # step 0: get cot information
    retrieve_pmt_0 = question
    # you can change this decouple prompt to fit your requirements
    retrieve_pmt_0 += "\nTo answer the question step by step, you can provide your retrieve request to assist you by the following json format:"
    retrieve_pmt_0 += '''{
        "ASR": Optional[str]. The subtitles of the video that may be relevant to the question you want to retrieve, in two sentences. If you do not need this information, please return null.
        "DET": Optional[list]. (The output must include only physical entities, not abstract concepts, less than five entities) All the physical entities and their locations related to the question you want to retrieve, not abstract concepts. If you do not need this information, please return null.
        "TYPE": Optional[list]. (The output must be specified as null or a list containing only one or more of the following strings: 'location', 'number', 'relation'. No other values are valid for this field) The information you want to obtain about the detected objects. If you need the object location in the video frame, output "location"; if you need the number of a specific object, output "number"; if you need the positional relationship between objects, output "relation". 
    }
    ## Example 1: 
    Question: How many blue balloons are over the long table in the middle of the room at the end of this video? A. 1. B. 2. C. 3. D. 4.
    Your retrieve can be:
    {
        "ASR": "The location and the color of balloons, the number of the blue balloons.",
        "DET": ["blue balloons", "long table"],
        "TYPE": ["relation", "number"]
    }
    ## Example 2: 
    Question: In the lower left corner of the video, what color is the woman wearing on the right side of the man in black clothes? A. Blue. B. White. C. Red. D. Yellow.
    Your retrieve can be:
    {
        "ASR": null,
        "DET": ["the man in black", "woman"],
        "TYPE": ["location", "relation"]
    }
    ## Example 3: 
    Question: In which country is the comedy featured in the video recognized worldwide? A. China. B. UK. C. Germany. D. United States.
    Your retrieve can be:
    {
        "ASR": "The country recognized worldwide for its comedy.",
        "DET": null,
        "TYPE": null
    }
    Note that you don't need to answer the question in this step, so you don't need any information about the video or image. You only need to provide your retrieve request (it's optional), and I will help you retrieve the information you want. Please provide the json format.'''

    json_request = llava_inference(retrieve_pmt_0, None)

    # step 1: get docs information
    query = [question]

    # APE fetch
    if USE_DET:
        det_docs = []
        try:
            request_det = json.loads(json_request)["DET"]
            request_det = filter_keywords(request_det)
            clip_text = ["A picture of " + txt for txt in request_det]
            if len(clip_text) == 0:
                clip_text = ["A picture of object"]
        except:
            request_det = None
            clip_text = ["A picture of object"]

        clip_inputs = clip_processor(text=clip_text, return_tensors="pt", padding=True, truncation=True).to(clip_model.device)
        clip_img_feats = clip_model.get_image_features(video_tensor)
        with torch.no_grad():
            text_features = clip_model.get_text_features(**clip_inputs)
            similarities = (clip_img_feats @ text_features.T).squeeze(0).mean(1).cpu()
            similarities = np.array(similarities, dtype=np.float64)
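            # Rescale similarities so they sum to alpha = beta * (num_frames / 16); frames whose
            # rescaled score exceeds clip_threshold are later selected for detection.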
            alpha = beta * (len(similarities) / 16)
            similarities = similarities * alpha / np.sum(similarities)

        del clip_inputs, clip_img_feats, text_features
        torch.cuda.empty_cache()

        det_top_idx = [idx for idx in range(max_frames_num) if similarities[idx] > clip_threshold]
            
        if request_det is not None and len(request_det) > 0:
            det_docs = get_det_docs(frames[det_top_idx], request_det)  

            L, R, N = False, False, False
            try:
                det_retrieve_info = json.loads(json_request)["TYPE"]
            except:
                det_retrieve_info = None
            if det_retrieve_info is not None:
                if "location" in det_retrieve_info:
                    L = True
                if "relation" in det_retrieve_info:
                    R = True
                if "number" in det_retrieve_info:
                    N = True
            det_docs = det_preprocess(det_docs, location=L, relation=R, number=N)  # pre-process of APE information


    # OCR fetch
    if USE_OCR:
        try:
            request_det = json.loads(json_request)["DET"]
            request_det = filter_keywords(request_det)
        except:
            request_det = None
        ocr_docs = []
        if len(ocr_docs_total) > 0:
            ocr_query = query.copy()
            if request_det is not None and len(request_det) > 0:
                ocr_query.extend(request_det)
            ocr_docs, _ = retrieve_documents_with_dynamic(ocr_docs_total, ocr_query, threshold=rag_threshold)

    # ASR fetch
    if USE_ASR:
        asr_docs = []
        try:
            request_asr = json.loads(json_request)["ASR"]
        except:
            request_asr = None
        if len(asr_docs_total) > 0:
            asr_query = query.copy()
            if request_asr is not None:
                asr_query.append(request_asr)
            asr_docs, _ = retrieve_documents_with_dynamic(asr_docs_total, asr_query, threshold=rag_threshold)

    qs = ""
    if USE_DET and len(det_docs) > 0:
        for i, info in enumerate(det_docs):
            if len(info) > 0:
                qs += f"Frame {str(det_top_idx[i]+1)}: " + info + "\n"
        if len(qs) > 0:
            qs = f"\nThe video has {str(max_frames_num)} frames in total; the detected objects' information in specific frames: " + qs
    if USE_ASR and len(asr_docs) > 0:
        qs += "\nVideo Automatic Speech Recognition information (given in chronological order of the video): " + " ".join(asr_docs)
    if USE_OCR and len(ocr_docs) > 0:
        qs += "\nVideo OCR information (given in chronological order of the video): " + "; ".join(ocr_docs)
    qs += "Select the best answer to the following multiple-choice question based on the video and the information (if given). Respond with only the letter (A, B, C, or D) of the correct option. Question: " + question  # you can change this prompt

    res = llava_inference(qs, video)
    line["pred"] = res
    print(line)
    output.append(line)

with open("videorag_llava_video_7b.json", "w") as f:
    json.dump(output, f)

Could you help with reproducing LLaVA-Video-7B result? And is it possible for you to provide an official result for Qwen2-VL-7B+Video-RAG, or Qwen2.5-VL-7B+Video-RAG?

Thanks very much!

Leon1207 (Owner) commented Feb 22, 2025

Hi!
We apologize for not including the results of Qwen2-VL-7B; since Video-RAG is a plug-and-play module that can be inserted into any LVLM, we did not have enough time to conduct experiments on Qwen2-VL-7B. Would you mind giving us some time to reproduce its results on Video-MME and MLVU? Once we have reproduced them, we will update this issue. Appreciated!
As for the result of LLaVA-Video-7B on MLVU, we used 64 frames for evaluation to get 72.4, while your code uses 32 frames. We also provide our code under the evals dir, together with our result file, with the following results:

[Image: screenshot of the MLVU per-task results]

MLVU_LLaVA-Video-7B+Video-RAG.json
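For reference, switching the script above to the 64-frame setting should only require changing the frame budget; a minimal sketch, assuming everything else in the script stays the same:

max_frames_num = 64  # paper setting for MLVU; the reproduction above used 32
# the rest of the pipeline is unchanged, e.g.
frames, frame_time, video_time = process_video(video_path, max_frames_num, 1, force_sample=True)
video = [image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].cuda().bfloat16()]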

xiaoqian-shen (Author) commented:

Thanks for your response! I will rerun with 64 frames, and I would also sincerely appreciate it if you have time to provide the Qwen-VL results.

Leon1207 (Owner) commented Feb 23, 2025

> Thanks for your response! I will rerun with 64 frames, and I would also sincerely appreciate it if you have time to provide the Qwen-VL results.

Hi!
We first reproduced the results on Video-MME at the 64-frame setting; MLVU is on the way. Here are the results:

  • Qwen2-VL-7B: 60.1 | 72.7, 57.4, 50.2 (Overall | Short, Medium, Long)
  • Qwen2-VL-7B + Video-RAG: 65.7 | 73.6, 64.1, 59.3

If you have more requirements, please let me know. Thanks!

Leon1207 (Owner) commented:

Hi!
The results on MLVU at the 64-frame setting are as follows:

  • Qwen2-VL-7B: 60.4
  • Qwen2-VL-7B + Video-RAG: 61.0

If you have more requirements, please let me know. Thanks!

xiaoqian-shen (Author) commented:

Thanks for your reply! I also got a similar result for Qwen2.5-VL + Video-RAG.
{'count': 24.271844660194176, 'ego': 54.26136363636363, 'findNeedle': 67.04225352112675, 'order': 51.35135135135135, 'plotQA': 67.90352504638219, 'anomaly_reco': 71.0, 'topic_reasoning': 88.97338403041825, 'Acc': 62.28150873965041}

I am wondering what the peak performance of Video-RAG is when taking in more frames. Since the base model Qwen2-VL can achieve 65.7 on MLVU by taking in more frames, do you have any insight into whether taking in more frames would make Video-RAG better than the original model's performance?

Leon1207 (Owner) commented:

Hi!
We believe increasing the number of frames can boost Video-RAG; however, due to resource constraints, we have limited capacity to explore its peak performance (the official frame count of Qwen2-VL on Video-MME can go up to 768).
We have already run Video-RAG on LongVA-7B with 8 to 256 frames:

[Image: LongVA-7B + Video-RAG results from 8 to 256 frames]

Hope this result can help you.
