Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor Qwen2_VL class to update max_num_frames default value and handle video processing and context handling
Browse files (browse the repository at this point in the history)
pufanyi committed Jan 16, 2025
1 parent aa10a0c commit 40ef743
Showing 2 changed files with 64 additions and 6 deletions.
54 changes: 54 additions & 0 deletions alb.json

Large diffs are not rendered by default.

16 changes: 10 additions & 6 deletions lmms_eval/models/qwen2_vl.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ def __init__(
use_flash_attention_2: Optional[bool] = True,
max_pixels: int = 1605632,
min_pixels: int = 3136,
max_num_frames: int = 1,
max_num_frames: int = 20,
use_custom_video_loader: Optional[bool] = True,
fps: Optional[float] = None, # Only applicable if use_custom_video_loader is True
max_image_size: Optional[int] = 1024, # Only applicable if use_custom_video_loader is True
@@ -215,7 +215,7 @@ def _collate(x):
split = split[0]
if self.continual_mode and self.cache_mode == "resume":
doc_uuid = get_uuid(task, split, doc_id)
print(doc_uuid)
# print(doc_uuid)
if doc_uuid in self.response_cache:
content = self.response_cache[doc_uuid]
if content:
@@ -253,9 +253,9 @@ def _collate(x):
messages = []
processed_visuals = []
for i, context in enumerate(contexts):
# context += "\nPlease think step by step."
context += "\nPlease think step by step."

print("context", context)
# print("context", context)

if "<image>" in context:
context = context.split("<image>")
@@ -267,15 +267,17 @@ def _collate(x):

if len(visuals) > 0:
visual = visuals[i] if i < len(visuals) else None
print("visuals", visual)
# print("visuals", visual)
if isinstance(visual, Image.Image):
visual = [visual]
if isinstance(visual, (list, tuple)) and isinstance(visual[0], str):
assert len(visual) == 1, f"Expected a single video file but got {len(visual)} files"
visual = visual[0]
if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file
if self.use_custom_video_loader:
visual = read_video_pyav_base64(visual, num_frm=self.max_num_frames, fps=self.fps, img_format="JPEG", max_image_size=self.max_image_size)
image_contents = list(map(lambda x: f"data:image/jpeg;base64,{x}", visual))
if len(context) == 2:
print("image_contents", image_contents)
if len(image_contents) == 1:
message.append(
{
@@ -287,6 +289,8 @@ def _collate(x):
message.append({"role": "user", "content": [{"type": "text", "text": context[0]}, {"type": "image", "image": image_contents[-1]}, {"type": "text", "text": context[1]}]})
else:
message.append({"role": "user", "content": [{"type": "video", "video": image_contents}, {"type": "text", "text": context}]})
# with open("alb.json", "w") as f:
# json.dump(message, f, indent=4)
else:
vr = decord.VideoReader(visual)
first_frame = vr[0].asnumpy()

0 comments on commit 40ef743

Please sign in to comment.