Commit 358cecc

Author: Oscar Lo
Commit message: added gqa as eval dataset
1 parent a5378a8 commit 358cecc

File tree

4 files changed (+99, -1 lines)


open_flamingo/eval/eval_datasets.py

Lines changed: 2 additions & 1 deletion
@@ -14,6 +14,7 @@
     "okvqa",
     "vizwiz",
     "textvqa",
+    "gqa",
     "hateful_memes",
     "imagenet",
 ]
@@ -104,7 +105,7 @@ def get_img_path(self, question):
             )
         elif self.dataset_name == "vizwiz":
             return os.path.join(self.image_dir_path, question["image_id"])
-        elif self.dataset_name == "textvqa":
+        elif self.dataset_name == "textvqa" or self.dataset_name == "gqa":
             return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
         else:
             raise Exception(f"Unknown VQA dataset {self.dataset_name}")
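With this change, GQA reuses the TextVQA image-path convention: the question record's image_id is joined onto the image directory with a .jpg suffix. A minimal standalone sketch of that lookup, using a made-up directory and question record:

import os

# Hypothetical example values; GQA question records carry an "image_id" field.
image_dir_path = "/data/gqa/images"
question = {"image_id": "n161313", "question": "Is the sky blue?"}

# Mirrors the new "textvqa"/"gqa" branch of get_img_path.
img_path = os.path.join(image_dir_path, f"{question['image_id']}.jpg")
print(img_path)  # /data/gqa/images/n161313.jpg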

open_flamingo/eval/eval_models/blip.py

Lines changed: 3 additions & 0 deletions
@@ -108,6 +108,9 @@ def get_vizwiz_prompt(self, question, answer=None) -> str:
 
     def get_textvqa_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_gqa_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
 
     def get_coco_prompt(self, caption=None) -> str:
         return f"A photo of {caption if caption is not None else ''}"

open_flamingo/eval/eval_models/open_flamingo.py

Lines changed: 3 additions & 0 deletions
@@ -287,6 +287,9 @@ def get_vizwiz_prompt(self, question, answer=None) -> str:
 
     def get_textvqa_prompt(self, question, answer=None) -> str:
         return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
+
+    def get_gqa_prompt(self, question, answer=None) -> str:
+        return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
     def get_coco_prompt(self, caption=None) -> str:
         return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

open_flamingo/eval/evaluate.py

Lines changed: 91 additions & 0 deletions
@@ -139,6 +139,14 @@
     default=False,
     help="Whether to evaluate on TextVQA.",
 )
+
+parser.add_argument(
+    "--eval_gqa",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on GQA.",
+)
+
 parser.add_argument(
     "--eval_imagenet",
     action="store_true",
@@ -346,6 +354,44 @@
     default=None,
 )
 
+# GQA Dataset
+parser.add_argument(
+    "--gqa_train_image_dir_path",
+    type=str,
+    help="Path to the gqa train images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_train_questions_json_path",
+    type=str,
+    help="Path to the gqa questions json file.",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_train_annotations_json_path",
+    type=str,
+    help="Path to the gqa annotations json file",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_test_image_dir_path",
+    type=str,
+    help="Path to the gqa test images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_test_questions_json_path",
+    type=str,
+    help="Path to the gqa questions json file",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_test_annotations_json_path",
+    type=str,
+    help="Path to the gqa annotations json file",
+    default=None,
+)
+
 ## Imagenet dataset
 parser.add_argument("--imagenet_root", type=str, default="/tmp")
 
@@ -650,6 +696,44 @@ def main():
                         "stddev": np.nanstd(scores),
                     }
                 )
+
+    if args.eval_gqa:
+        print("Evaluating on GQA...")
+
+        #load cached demonstration features on GQA
+        if args.cached_demonstration_features is not None:
+            cached_features = torch.load(
+                f"{args.cached_demonstration_features}/imagenet.pkl", map_location="cpu"
+            )
+        else:
+            cached_features = None
+
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                gqa_score = evaluate_vqa(
+                    args=args,
+                    eval_model=eval_model,
+                    num_shots=shot,
+                    seed=seed,
+                    dataset_name="gqa",
+                    max_new_tokens=10,
+                    cached_features=cached_features,
+                )
+                if args.rank == 0:
+                    print(f"Shots {shot} Trial {trial} GQA score: {gqa_score}")
+                    scores.append(gqa_score)
+
+            if args.rank == 0:
+                print(f"Shots {shot} Mean GQA score: {np.nanmean(scores)}")
+                results["gqa"].append(
+                    {
+                        "shots": shot,
+                        "trials": scores,
+                        "mean": np.nanmean(scores),
+                        "stddev": np.nanstd(scores),
+                    }
+                )
 
     if args.eval_imagenet:
         print("Evaluating on ImageNet...")
@@ -968,6 +1052,13 @@ def evaluate_vqa(
         test_image_dir_path = args.textvqa_image_dir_path
         test_questions_json_path = args.textvqa_test_questions_json_path
         test_annotations_json_path = args.textvqa_test_annotations_json_path
+    elif dataset_name == "gqa":
+        train_image_dir_path = args.gqa_train_image_dir_path
+        train_questions_json_path = args.gqa_train_questions_json_path
+        train_annotations_json_path = args.gqa_train_annotations_json_path
+        test_image_dir_path = args.gqa_test_image_dir_path
+        test_questions_json_path = args.gqa_test_questions_json_path
+        test_annotations_json_path = args.gqa_test_annotations_json_path
     else:
         raise ValueError(f"Unsupported dataset: {dataset_name}")
 
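Putting the evaluate.py pieces together: a GQA run is enabled with --eval_gqa, and the six --gqa_*_path arguments feed the new elif branch in evaluate_vqa, which forwards them to the train/test dataset paths. A hedged sketch of the flag-to-variable wiring, with invented example paths:

# Invented example paths; the flag names come from the argparse additions above.
argv = [
    "--eval_gqa",
    "--gqa_train_image_dir_path", "/data/gqa/images",
    "--gqa_train_questions_json_path", "/data/gqa/train_questions.json",
    "--gqa_train_annotations_json_path", "/data/gqa/train_annotations.json",
    "--gqa_test_image_dir_path", "/data/gqa/images",
    "--gqa_test_questions_json_path", "/data/gqa/testdev_questions.json",
    "--gqa_test_annotations_json_path", "/data/gqa/testdev_annotations.json",
]
# Inside evaluate_vqa(dataset_name="gqa"), these map one-to-one onto the
# train_*/test_* path variables used to build the train and test question sets.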
