@@ -34,7 +34,7 @@
     HatefulMemesDataset,
 )
 from ok_vqa_utils import postprocess_ok_vqa_generation
-from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation
+from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation, compute_mantis_accuracy
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -152,29 +152,30 @@
         default=None,
     )
 
-## VQAV2, OK-VQA, VizWiz, TextVQA, GQA Datasets
-for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa']:
+## VQAV2, OK-VQA, VizWiz, TextVQA, GQA, Mantis-Eval Datasets
+for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa', 'mantiseval']:
     parser.add_argument(
-        f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' else f"--{task}_train_image_dir_path",
+        f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' or task=='mantiseval' else f"--{task}_train_image_dir_path",
         type=str,
         default=None,
     )
-    if task!='gqa' and task!='textvqa':
+    if task != 'mantiseval':
+        if task!='gqa' and task!='textvqa':
+            parser.add_argument(
+                f"--{task}_test_image_dir_path",
+                type=str,
+                default=None,
+            )
         parser.add_argument(
-            f"--{task}_test_image_dir_path",
+            f"--{task}_train_questions_json_path",
+            type=str,
+            default=None,
+        )
+        parser.add_argument(
+            f"--{task}_train_annotations_json_path",
             type=str,
             default=None,
         )
-        parser.add_argument(
-            f"--{task}_train_questions_json_path",
-            type=str,
-            default=None,
-        )
-        parser.add_argument(
-            f"--{task}_train_annotations_json_path",
-            type=str,
-            default=None,
-        )
     parser.add_argument(
         f"--{task}_test_questions_json_path",
         type=str,
|
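A quick way to read the restructured registration loop: mantiseval behaves like gqa/textvqa in using a single image directory, and additionally skips every train-split flag. The snippet below restates only the flag-name logic from the hunk above for illustration; it is not part of the patch (the test-side question/annotation flags are registered unconditionally just past the hunk's cut-off, as `evaluate_vqa` later reads both).

```python
# Illustrative restatement of the argparse logic above; not part of the patch.
def flags_for(task: str) -> list[str]:
    flags = []
    # gqa, textvqa, and mantiseval share one image dir across splits
    if task in ('gqa', 'textvqa', 'mantiseval'):
        flags.append(f"--{task}_image_dir_path")
    else:
        flags.append(f"--{task}_train_image_dir_path")
    # mantiseval has no train split here, so no train-side flags at all
    if task != 'mantiseval':
        if task not in ('gqa', 'textvqa'):
            flags.append(f"--{task}_test_image_dir_path")
        flags.append(f"--{task}_train_questions_json_path")
        flags.append(f"--{task}_train_annotations_json_path")
    # test-side question/annotation flags are registered for every task
    flags.append(f"--{task}_test_questions_json_path")
    flags.append(f"--{task}_test_annotations_json_path")
    return flags

print(flags_for('mantiseval'))
# ['--mantiseval_image_dir_path',
#  '--mantiseval_test_questions_json_path',
#  '--mantiseval_test_annotations_json_path']
```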
@@ -315,7 +316,7 @@ def main():
         }
     )
 
-    for vqa_task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    for vqa_task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         if var_args[f"eval_{vqa_task}"]:
             print(f"Evaluating on {vqa_task}...")
 
@@ -601,16 +602,16 @@ def evaluate_vqa(
         float: accuracy score
     """
     var_args = vars(args)
-    for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         if dataset_name == task:
             task = task
-    train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"]
-    train_questions_json_path = var_args[f"{task}_train_questions_json_path"]
-    train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"]
-    test_image_dir_path = var_args[f"{task}_test_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"]
+    train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" and task!="mantiseval" else f"{task}_image_dir_path"]
+    train_questions_json_path = var_args[f"{task}_train_questions_json_path"] if task!="mantiseval" else var_args[f"{task}_test_questions_json_path"]
+    train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"] if task!="mantiseval" else var_args[f"{task}_test_annotations_json_path"]
+    test_image_dir_path = var_args[f"{task}_test_image_dir_path" if task!="textvqa" and task!="gqa" and task!="mantiseval" else f"{task}_image_dir_path"]
     test_questions_json_path = var_args[f"{task}_test_questions_json_path"]
     test_annotations_json_path = var_args[f"{task}_test_annotations_json_path"]
-    if dataset_name not in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    if dataset_name not in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         raise ValueError(f"Unsupported dataset: {dataset_name}")
 
     train_dataset = VQADataset(
|
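One consequence of the path resolution above: for mantiseval the train-side paths alias the test-split files, so any in-context demonstrations are sampled from the evaluation split itself. A condensed restatement with invented paths, for illustration only:

```python
# Hypothetical values; only the aliasing pattern mirrors the hunk above.
var_args = {
    "mantiseval_image_dir_path": "data/mantis_eval/images",
    "mantiseval_test_questions_json_path": "data/mantis_eval/questions.json",
    "mantiseval_test_annotations_json_path": "data/mantis_eval/annotations.json",
}
train_image_dir_path = test_image_dir_path = var_args["mantiseval_image_dir_path"]
train_questions_json_path = var_args["mantiseval_test_questions_json_path"]      # falls back to test
train_annotations_json_path = var_args["mantiseval_test_annotations_json_path"]  # falls back to test
```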
@@ -675,7 +676,10 @@ def evaluate_vqa(
             context_images = [x["image"] for x in batch_demo_samples[i]]
         else:
             context_images = []
-        batch_images.append(context_images + [batch["image"][i]])
+        if dataset_name == "mantiseval":
+            batch_images.append(context_images + batch["image"][i])
+        else:
+            batch_images.append(context_images + [batch["image"][i]])
 
         context_text = "".join(
             [
|
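The new branch above accounts for Mantis-Eval being a multi-image benchmark: assuming `batch["image"][i]` is already a list of images for this dataset (a single image elsewhere), list concatenation rather than wrapping keeps each entry of `batch_images` a flat list. With strings standing in for PIL images:

```python
# Strings stand in for PIL images; the mantiseval sample shape is assumed.
context_images = ["demo_1", "demo_2"]   # few-shot demo images
single = "query"                        # single-image tasks: one image
multi = ["query_a", "query_b"]          # mantiseval: a list of images

print(context_images + [single])  # ['demo_1', 'demo_2', 'query']
print(context_images + multi)     # ['demo_1', 'demo_2', 'query_a', 'query_b']
```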
@@ -703,7 +707,7 @@ def evaluate_vqa(
             num_beams=num_beams,
             length_penalty=length_penalty,
         )
-        
+
         process_function = (
             postprocess_ok_vqa_generation
             if dataset_name == "okvqa"
@@ -732,11 +736,17 @@ def evaluate_vqa(
             f.write(json.dumps(all_predictions, indent=4))
 
     if test_annotations_json_path is not None:
-        acc = compute_vqa_accuracy(
-            f"{dataset_name}results_{random_uuid}.json",
-            test_questions_json_path,
-            test_annotations_json_path,
-        )
+        if dataset_name == "mantiseval":
+            acc = compute_mantis_accuracy(
+                f"{dataset_name}results_{random_uuid}.json",
+                test_annotations_json_path,
+            )
+        else:
+            acc = compute_vqa_accuracy(
+                f"{dataset_name}results_{random_uuid}.json",
+                test_questions_json_path,
+                test_annotations_json_path,
+            )
         # delete the temporary file
         os.remove(f"{dataset_name}results_{random_uuid}.json")
 
|
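`compute_mantis_accuracy` itself does not appear in this diff; only its call shape is visible (a results file and an annotations file, with no questions file). Below is a minimal sketch of what such a scorer could look like, assuming exact-match scoring over the `{"question_id", "answer"}` records this script writes and an annotations file mapping question ids to single gold answers; all file-format details here are assumptions, not the PR's actual implementation.

```python
import json


def compute_mantis_accuracy(result_json_path, annotation_json_path):
    """Hypothetical sketch, not the PR's code: exact-match accuracy after
    light normalization. Assumes results like
    [{"question_id": ..., "answer": ...}, ...] and annotations like
    {question_id: gold_answer, ...}; the real formats may differ."""
    with open(result_json_path) as f:
        predictions = {str(r["question_id"]): r["answer"] for r in json.load(f)}
    with open(annotation_json_path) as f:
        gold = json.load(f)

    def norm(s):
        return str(s).strip().lower()

    correct = sum(
        norm(predictions.get(qid, "")) == norm(answer)
        for qid, answer in gold.items()
    )
    # Match compute_vqa_accuracy's convention of a 0-100 score (assumed).
    return 100.0 * correct / max(len(gold), 1)
```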