|
139 | 139 | default=False,
|
140 | 140 | help="Whether to evaluate on TextVQA.",
|
141 | 141 | )
|
| 142 | + |
# Flag enabling the GQA evaluation pass (off by default).
parser.add_argument(
    "--eval_gqa", action="store_true", default=False, help="Whether to evaluate on GQA."
)
| 149 | + |
142 | 150 | parser.add_argument(
|
143 | 151 | "--eval_imagenet",
|
144 | 152 | action="store_true",
|
|
346 | 354 | default=None,
|
347 | 355 | )
|
348 | 356 |
|
# GQA Dataset
# All six GQA path options share the same shape (str, default None), so they
# are registered from a (flag, help) table instead of six verbatim calls.
for _gqa_flag, _gqa_help in (
    ("--gqa_train_image_dir_path", "Path to the gqa train images directory."),
    ("--gqa_train_questions_json_path", "Path to the gqa questions json file."),
    ("--gqa_train_annotations_json_path", "Path to the gqa annotations json file"),
    ("--gqa_test_image_dir_path", "Path to the gqa test images directory."),
    ("--gqa_test_questions_json_path", "Path to the gqa questions json file"),
    ("--gqa_test_annotations_json_path", "Path to the gqa annotations json file"),
):
    parser.add_argument(_gqa_flag, type=str, help=_gqa_help, default=None)
| 394 | + |
## Imagenet dataset
# Root directory of the ImageNet tree; defaults to /tmp when not provided.
parser.add_argument(
    "--imagenet_root",
    type=str,
    default="/tmp",
)
|
351 | 397 |
|
@@ -650,6 +696,44 @@ def main():
|
650 | 696 | "stddev": np.nanstd(scores),
|
651 | 697 | }
|
652 | 698 | )
|
| 699 | + |
| 700 | + if args.eval_gqa: |
| 701 | + print("Evaluating on GQA...") |
| 702 | + |
| 703 | + #load cached demonstration features on GQA |
| 704 | + if args.cached_demonstration_features is not None: |
| 705 | + cached_features = torch.load( |
| 706 | + f"{args.cached_demonstration_features}/imagenet.pkl", map_location="cpu" |
| 707 | + ) |
| 708 | + else: |
| 709 | + cached_features = None |
| 710 | + |
| 711 | + for shot in args.shots: |
| 712 | + scores = [] |
| 713 | + for seed, trial in zip(args.trial_seeds, range(args.num_trials)): |
| 714 | + gqa_score = evaluate_vqa( |
| 715 | + args=args, |
| 716 | + eval_model=eval_model, |
| 717 | + num_shots=shot, |
| 718 | + seed=seed, |
| 719 | + dataset_name="gqa", |
| 720 | + max_new_tokens=10, |
| 721 | + cached_features=cached_features, |
| 722 | + ) |
| 723 | + if args.rank == 0: |
| 724 | + print(f"Shots {shot} Trial {trial} GQA score: {gqa_score}") |
| 725 | + scores.append(gqa_score) |
| 726 | + |
| 727 | + if args.rank == 0: |
| 728 | + print(f"Shots {shot} Mean GQA score: {np.nanmean(scores)}") |
| 729 | + results["gqa"].append( |
| 730 | + { |
| 731 | + "shots": shot, |
| 732 | + "trials": scores, |
| 733 | + "mean": np.nanmean(scores), |
| 734 | + "stddev": np.nanstd(scores), |
| 735 | + } |
| 736 | + ) |
653 | 737 |
|
654 | 738 | if args.eval_imagenet:
|
655 | 739 | print("Evaluating on ImageNet...")
|
@@ -968,6 +1052,13 @@ def evaluate_vqa(
|
968 | 1052 | test_image_dir_path = args.textvqa_image_dir_path
|
969 | 1053 | test_questions_json_path = args.textvqa_test_questions_json_path
|
970 | 1054 | test_annotations_json_path = args.textvqa_test_annotations_json_path
|
| 1055 | + elif dataset_name == "gqa": |
| 1056 | + train_image_dir_path = args.gqa_train_image_dir_path |
| 1057 | + train_questions_json_path = args.gqa_train_questions_json_path |
| 1058 | + train_annotations_json_path = args.gqa_train_annotations_json_path |
| 1059 | + test_image_dir_path = args.gqa_test_image_dir_path |
| 1060 | + test_questions_json_path = args.gqa_test_questions_json_path |
| 1061 | + test_annotations_json_path = args.gqa_test_annotations_json_path |
971 | 1062 | else:
|
972 | 1063 | raise ValueError(f"Unsupported dataset: {dataset_name}")
|
973 | 1064 |
|
|
0 commit comments