
Commit e83c642

Author: Oscar Lo
Commit message: added mantis_eval
1 parent c1bdb64 commit e83c642

File tree: 6 files changed, +110 -33 lines

open_flamingo/eval/eval_datasets.py

Lines changed: 15 additions & 2 deletions
@@ -15,6 +15,7 @@
     "vizwiz",
     "textvqa",
     "gqa",
+    "mantiseval",
     "hateful_memes",
     "imagenet",
 ]
@@ -107,14 +108,26 @@ def get_img_path(self, question):
             return os.path.join(self.image_dir_path, question["image_id"])
         elif self.dataset_name == "textvqa" or self.dataset_name == "gqa":
             return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
+        elif self.dataset_name == "mantiseval":
+            img_paths = []
+            for img_id in question['image_id']:
+                img_paths.append(os.path.join(self.image_dir_path, f"{img_id}.jpg"))
+            return img_paths
         else:
             raise Exception(f"Unknown VQA dataset {self.dataset_name}")
 
     def __getitem__(self, idx):
         question = self.questions[idx]
         img_path = self.get_img_path(question)
-        image = Image.open(img_path)
-        image.load()
+        if self.dataset_name == "mantiseval":
+            image = []
+            for path in img_path:
+                img = Image.open(path)
+                img.load()
+                image.append(img)
+        else:
+            image = Image.open(img_path)
+            image.load()
         results = {
             "image": image,
             "question": question["question"],

open_flamingo/eval/eval_models/blip.py

Lines changed: 9 additions & 1 deletion
@@ -5,7 +5,7 @@
 
 from transformers import Blip2Processor, Blip2ForConditionalGeneration
 from eval_models.eval_model import BaseEvalModel
-from utils import unwrap_model
+from utils import unwrap_model, combine_images
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 
@@ -27,9 +27,14 @@ def required_args(self):
 
     def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
         batch_images = None
+        for i in range(len(batch)):
+            if len(batch[i]) > 1:
+                batch[i] = combine_images(batch[i])
+        """
         assert all(
             len(example) == 1 for example in batch
         ), "BLIP-2 only supports one image per example"
+        """
         for example in batch:
             if batch_images is None:
                 batch_images = self.processor.image_processor(
@@ -111,6 +116,9 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
 
     def get_gqa_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_mantiseval_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
 
     def get_coco_prompt(self, caption=None) -> str:
         return f"A photo of {caption if caption is not None else ''}"

open_flamingo/eval/eval_models/open_flamingo.py

Lines changed: 3 additions & 0 deletions
@@ -291,6 +291,9 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
     def get_gqa_prompt(self, question, answer=None) -> str:
         return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
+    def get_mantiseval_prompt(self, question, answer=None) -> str:
+        return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
+
     def get_coco_prompt(self, caption=None) -> str:
         return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
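The new OpenFlamingo prompt reuses the GQA template: with an answer it renders as an in-context demonstration terminated by <|endofchunk|>, and with answer=None the answer slot is left open for generation. For illustration (the question text below is invented):

def get_mantiseval_prompt(question, answer=None) -> str:
    # Same template as the hunk above
    return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

print(get_mantiseval_prompt("Which image has more people?", "the first one"))
# <image>Question:Which image has more people? Short answer:the first one<|endofchunk|>
print(get_mantiseval_prompt("Which image has more people?"))
# <image>Question:Which image has more people? Short answer: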

open_flamingo/eval/evaluate.py

Lines changed: 40 additions & 30 deletions
@@ -34,7 +34,7 @@
     HatefulMemesDataset,
 )
 from ok_vqa_utils import postprocess_ok_vqa_generation
-from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation
+from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation, compute_mantis_accuracy
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -152,29 +152,30 @@
     default=None,
 )
 
-## VQAV2, OK-VQA, VizWiz, TextVQA, GQA Datasets
-for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa']:
+## VQAV2, OK-VQA, VizWiz, TextVQA, GQA, Mantis-Eval Datasets
+for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa', 'mantiseval']:
     parser.add_argument(
-        f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' else f"--{task}_train_image_dir_path",
+        f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' or task=='mantiseval' else f"--{task}_train_image_dir_path",
         type=str,
         default=None,
    )
-    if task!='gqa' and task!='textvqa':
+    if task != 'mantiseval':
+        if task!='gqa' and task!='textvqa':
+            parser.add_argument(
+                f"--{task}_test_image_dir_path",
+                type=str,
+                default=None,
+            )
         parser.add_argument(
-            f"--{task}_test_image_dir_path",
+            f"--{task}_train_questions_json_path",
+            type=str,
+            default=None,
+        )
+        parser.add_argument(
+            f"--{task}_train_annotations_json_path",
             type=str,
             default=None,
         )
-    parser.add_argument(
-        f"--{task}_train_questions_json_path",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        f"--{task}_train_annotations_json_path",
-        type=str,
-        default=None,
-    )
     parser.add_argument(
         f"--{task}_test_questions_json_path",
         type=str,
@@ -315,7 +316,7 @@ def main():
                 }
             )
 
-    for vqa_task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    for vqa_task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         if var_args[f"eval_{vqa_task}"]:
             print(f"Evaluating on {vqa_task}...")
 
@@ -601,16 +602,16 @@ def evaluate_vqa(
         float: accuracy score
     """
     var_args = vars(args)
-    for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
        if dataset_name == task:
            task = task
-    train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"]
-    train_questions_json_path = var_args[f"{task}_train_questions_json_path"]
-    train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"]
-    test_image_dir_path = var_args[f"{task}_test_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"]
+    train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" and task!="mantiseval" else f"{task}_image_dir_path"]
+    train_questions_json_path = var_args[f"{task}_train_questions_json_path"] if task!="mantiseval" else var_args[f"{task}_test_questions_json_path"]
+    train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"] if task!="mantiseval" else var_args[f"{task}_test_annotations_json_path"]
+    test_image_dir_path = var_args[f"{task}_test_image_dir_path" if task!="textvqa" and task!="gqa" and task!="mantiseval" else f"{task}_image_dir_path"]
     test_questions_json_path = var_args[f"{task}_test_questions_json_path"]
     test_annotations_json_path = var_args[f"{task}_test_annotations_json_path"]
-    if dataset_name not in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    if dataset_name not in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         raise ValueError(f"Unsupported dataset: {dataset_name}")
 
     train_dataset = VQADataset(
@@ -675,7 +676,10 @@ def evaluate_vqa(
                context_images = [x["image"] for x in batch_demo_samples[i]]
            else:
                context_images = []
-            batch_images.append(context_images + [batch["image"][i]])
+            if dataset_name == "mantiseval":
+                batch_images.append(context_images + batch["image"][i])
+            else:
+                batch_images.append(context_images + [batch["image"][i]])
 
            context_text = "".join(
                [
@@ -703,7 +707,7 @@ def evaluate_vqa(
             num_beams=num_beams,
             length_penalty=length_penalty,
         )
-
+
         process_function = (
             postprocess_ok_vqa_generation
             if dataset_name == "okvqa"
@@ -732,11 +736,17 @@ def evaluate_vqa(
         f.write(json.dumps(all_predictions, indent=4))
 
     if test_annotations_json_path is not None:
-        acc = compute_vqa_accuracy(
-            f"{dataset_name}results_{random_uuid}.json",
-            test_questions_json_path,
-            test_annotations_json_path,
-        )
+        if dataset_name == "mantiseval":
+            acc = compute_mantis_accuracy(
+                f"{dataset_name}results_{random_uuid}.json",
+                test_annotations_json_path,
+            )
+        else:
+            acc = compute_vqa_accuracy(
+                f"{dataset_name}results_{random_uuid}.json",
+                test_questions_json_path,
+                test_annotations_json_path,
+            )
         # delete the temporary file
         os.remove(f"{dataset_name}results_{random_uuid}.json")
 
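In the argument plumbing above, mantiseval only gets a shared --mantiseval_image_dir_path plus test question/annotation JSONs, and evaluate_vqa then reuses the test JSONs as the in-context "train" pool. A small trace of how those lookups resolve for task == "mantiseval" (the paths are placeholders, not documented defaults):

task = "mantiseval"
var_args = {
    "mantiseval_image_dir_path": "mantis_eval/images",  # hypothetical paths
    "mantiseval_test_questions_json_path": "mantis_eval/questions.json",
    "mantiseval_test_annotations_json_path": "mantis_eval/annotations.json",
}

# Both the few-shot pool and the test split read from the same image dir,
# and the test question/annotation files double as the few-shot pool.
train_image_dir_path = var_args[f"{task}_image_dir_path"]
test_image_dir_path = var_args[f"{task}_image_dir_path"]
train_questions_json_path = var_args[f"{task}_test_questions_json_path"]
train_annotations_json_path = var_args[f"{task}_test_annotations_json_path"]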

open_flamingo/eval/utils.py

Lines changed: 23 additions & 0 deletions
@@ -3,6 +3,7 @@
 import random
 import torch.nn as nn
 from contextlib import suppress
+from PIL import Image
 
 
 def random_seed(seed=42, rank=0):
@@ -122,3 +123,25 @@ def get_autocast(precision):
         return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
     else:
         return suppress
+
+def combine_images(images):
+    img_heights, _ = zip(*(img.size for img in images))
+    avg_height = sum(img_heights) // len(img_heights)
+    for i, img in enumerate(images):
+        images[i] = img.resize((int(img.size[0] * avg_height / img.size[1]), avg_height))
+    resized_heights, resized_widths = zip(*(img.size for img in images))
+    total_width = sum(resized_widths)
+    max_height = max(resized_heights)
+    new_img = Image.new("RGB", (total_width + 10 * (len(images) - 1), max_height))
+    x_offset = 0
+    for i, img in enumerate(images):
+        if i > 0:
+            new_img.paste(Image.new("RGB", (1, max_height), (0, 0, 0)), (x_offset, 0))
+            x_offset += 1
+            new_img.paste(Image.new("RGB", (8, max_height), (255, 255, 255)), (x_offset, 0))
+            x_offset += 8
+            new_img.paste(Image.new("RGB", (1, max_height), (0, 0, 0)), (x_offset, 0))
+            x_offset += 1
+        new_img.paste(img, (x_offset, 0))
+        x_offset += img.size[0]
+    return new_img
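Note that PIL's Image.size is (width, height), so the img_heights unpacking in the hunk above actually reads widths. A re-sketch of the apparent intent with explicit names, purely for clarity (this is not the committed function): resize every image to a common height, then paste them left to right separated by an 8 px white strip framed by 1 px black borders (10 px per gap, matching the canvas width in the diff).

from PIL import Image

def combine_images_sketch(images, pad=8, border=1):
    # Resize all images to the average height, preserving aspect ratio.
    avg_height = sum(img.size[1] for img in images) // len(images)
    resized = [img.resize((int(img.size[0] * avg_height / img.size[1]), avg_height)) for img in images]
    # One separator per gap: black | white | black = border + pad + border px.
    sep = border + pad + border
    total_width = sum(img.size[0] for img in resized) + sep * (len(resized) - 1)
    canvas = Image.new("RGB", (total_width, avg_height))
    x = 0
    for i, img in enumerate(resized):
        if i > 0:
            canvas.paste(Image.new("RGB", (border, avg_height), (0, 0, 0)), (x, 0))
            x += border
            canvas.paste(Image.new("RGB", (pad, avg_height), (255, 255, 255)), (x, 0))
            x += pad
            canvas.paste(Image.new("RGB", (border, avg_height), (0, 0, 0)), (x, 0))
            x += border
        canvas.paste(img, (x, 0))
        x += img.size[0]
    return canvas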

open_flamingo/eval/vqa_metric.py

Lines changed: 20 additions & 0 deletions
@@ -553,6 +553,26 @@ def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_p
 
     return vqaEval.accuracy["overall"]
 
+def compute_mantis_accuracy(result_json_path, annotation_json_path):
+    dataset = json.load(open(annotation_json_path, "r"))
+    gt_ans = {}
+    for ann in dataset["annotations"]:
+        gt_ans[ann["question_id"]] = {"answer": ann["answers"][0]["answer"], "type": ann["question_type"]}
+    results = json.load(open(result_json_path, "r"))
+    assert type(results) == list, "results is not an array of objects"
+    correct = 0
+    for res in results:
+        res_ans = res["answer"].lower().strip('()\n ')
+        if gt_ans[res["question_id"]]["type"] == "multi-choice":
+            if len(res_ans) > 1:
+                for c in res_ans:
+                    if c.isalpha():
+                        res_ans = c
+                        break
+        if res_ans == gt_ans[res["question_id"]]["answer"].lower().strip('()\n '):
+            correct+=1
+    acc = correct / len(results)
+    return acc
 
 def postprocess_vqa_generation(predictions):
     answer = re.split("Question|Answer|Short", predictions, 1)[0]
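compute_mantis_accuracy lower-cases each predicted answer and strips surrounding parentheses, whitespace, and newlines; for multi-choice questions, any longer response is reduced to its first alphabetic character before comparison with the ground-truth letter. A tiny worked example of that normalization (the prediction strings are invented):

def normalize_multichoice(ans: str) -> str:
    # Same normalization as the multi-choice branch above.
    ans = ans.lower().strip('()\n ')
    if len(ans) > 1:
        for c in ans:
            if c.isalpha():
                return c
    return ans

for pred in ["(A)", "A. the left image", " a \n"]:
    print(normalize_multichoice(pred))
# prints "a" three times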
