Skip to content

Commit 35e77f0

Browse files
committed
update documentation and example scripts
1 parent b0ff9a4 commit 35e77f0

15 files changed

+258
-24
lines changed

docs/inputs.png

190 KB
Loading

docs/signature.png

91.9 KB
Loading

docs/xattn_langstream.png

82.2 KB
Loading

open_flamingo/eval/README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# OpenFlamingo Evaluation Suite
2-
32
This is the evaluation module of OpenFlamingo. It contains a set of utilities for evaluating multimodal models on various benchmarking datasets.
43

54
*This module is a work in progress! We will be updating this README as it develops. In the meantime, if you notice an issue, please file a Bug Report or Feature Request [here](https://github.com/mlfoundations/open_flamingo/issues/new/choose).*
@@ -19,6 +18,15 @@ This is the evaluation module of OpenFlamingo. It contains a set of utilities fo
1918

2019
When evaluating a model using `num_shots` shots, we sample the exemplars from the training split. Performance is evaluated on a disjoint test split, subsampled to `--num_samples` examples (or using the full test split if `--num_samples=-1`).
2120

21+
## Supported models
22+
This evaluation module interfaces with models using the `EvalModel` class defined in `eval/eval_models/eval_model.py`. The `EvalModel` wrapper standardizes the generation and rank classification interfaces.
23+
24+
To help standardize VLM evaluations, we have implemented EvalModel wrappers for models from three code repositories:
25+
26+
* This open_flamingo repository, i.e. all models created using this repository's `src` code
27+
* The pretrained [BLIP-2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) models. Note that this model can only take in one image per input sequence; this is not to be confused with the BLIP-like implementation in the open_flamingo repository, which can take in arbitrarily interleaved image/text sequences
28+
* Huggingface's [IDEFICS](https://huggingface.co/blog/idefics) models
29+
2230
## Sample scripts
2331
Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun`. We provide a sample Slurm evaluation script in `open_flamingo/open_flamingo/scripts/run_eval.sh`.
2432

open_flamingo/scripts/fill_vqa_testdev_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Helper scripts to prepare a vqa test-dev evaluation for EvalAI submission.
2+
Helper scripts to prepare a Vizwiz or VQAv2 test-dev evaluation for EvalAI submission.
33
Note: EvalAI requires VQAv2 submissions to have predictions for all the questions in the test2015 set, not just the test-dev set.
44
Given a json with a subset of the vqa questions, fill in the rest of the questions with an empty string as the model prediction.
55
"""

open_flamingo/scripts/run_eval.sh renamed to open_flamingo/scripts/run_eval_ddp.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Notes:
99
- VQAv2 test-dev and test-std annotations are not publicly available.
1010
To evaluate on these splits, please follow the VQAv2 instructions and submit to EvalAI.
1111
This script will evaluate on the val split.
12+
- Vizwiz test-dev annotations are also not publicly available; please go through EvalAI.
1213
com
1314

1415
export PYTHONFAULTHANDLER=1
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#!/bin/bash
2+
#SBATCH --nodes=1
3+
#SBATCH --ntasks-per-node=2
4+
#SBATCH --gpus-per-task=1
5+
6+
<<com
7+
Example Slurm evaluation script.
8+
Notes:
9+
- VQAv2 test-dev and test-std annotations are not publicly available.
10+
To evaluate on these splits, please follow the VQAv2 instructions and submit to EvalAI.
11+
This script will evaluate on the val split.
12+
- Vizwiz test-dev annotations are also not publicly available; please go through EvalAI.
13+
com
14+
15+
export PYTHONFAULTHANDLER=1
16+
export CUDA_LAUNCH_BLOCKING=0
17+
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
18+
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
19+
export MASTER_PORT=$(shuf -i 0-65535 -n 1)
20+
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
21+
22+
echo go $COUNT_NODE
23+
echo $HOSTNAMES
24+
25+
export PYTHONPATH="$PYTHONPATH:open_flamingo"
26+
srun --cpu_bind=v --accel-bind=gn python
27+
deepspeed open_flamingo/open_flamingo/eval/evaluate.py \
28+
--vision_encoder_path ViT-L-14 \
29+
--vision_encoder_pretrained openai\
30+
--lm_path anas-awadalla/mpt-1b-redpajama-200b \
31+
--tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
32+
--cross_attn_every_n_layers 1 \
33+
--checkpoint_path "openflamingo/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt" \
34+
--results_file "results.json" \
35+
--precision fp32 \
36+
--batch_size 8 \
37+
--deepspeed \
38+
--eval_coco \
39+
--eval_vqav2 \
40+
--eval_flickr30 \
41+
--eval_ok_vqa \
42+
--eval_textvqa \
43+
--eval_vizwiz \
44+
--eval_hateful_memes \
45+
--coco_train_image_dir_path "/path/to/mscoco_karpathy/train2014" \
46+
--coco_val_image_dir_path "/path/to/mscoco_karpathy/val2014" \
47+
--coco_karpathy_json_path "/path/to/mscoco_karpathy/dataset_coco.json" \
48+
--coco_annotations_json_path "/path/to/mscoco_karpathy/annotations/captions_val2014.json" \
49+
--vqav2_train_image_dir_path "/path/to/vqav2/train2014" \
50+
--vqav2_train_annotations_json_path "/path/to/vqav2/v2_mscoco_train2014_annotations.json" \
51+
--vqav2_train_questions_json_path "/path/to/vqav2/v2_OpenEnded_mscoco_train2014_questions.json" \
52+
--vqav2_test_image_dir_path "/path/to/vqav2/val2014" \
53+
--vqav2_test_annotations_json_path "/path/to/vqav2/v2_mscoco_val2014_annotations.json" \
54+
--vqav2_test_questions_json_path "/path/to/vqav2/v2_OpenEnded_mscoco_val2014_questions.json" \
55+
--flickr_image_dir_path "/path/to/flickr30k/flickr30k-images" \
56+
--flickr_karpathy_json_path "/path/to/flickr30k/dataset_flickr30k.json" \
57+
--flickr_annotations_json_path "/path/to/flickr30k/dataset_flickr30k_coco_style.json" \
58+
--ok_vqa_train_image_dir_path "/path/to/okvqa/train2014" \
59+
--ok_vqa_train_annotations_json_path "/path/to/okvqa/mscoco_train2014_annotations.json" \
60+
--ok_vqa_train_questions_json_path "/path/to/okvqa/OpenEnded_mscoco_train2014_questions.json" \
61+
--ok_vqa_test_image_dir_path "/path/to/okvqa/val2014" \
62+
--ok_vqa_test_annotations_json_path "/path/to/okvqa/mscoco_val2014_annotations.json" \
63+
--ok_vqa_test_questions_json_path "/path/to/okvqa/OpenEnded_mscoco_val2014_questions.json" \
64+
--textvqa_image_dir_path "/path/to/textvqa/train_images/" \
65+
--textvqa_train_questions_json_path "/path/to/textvqa/train_questions_vqa_format.json" \
66+
--textvqa_train_annotations_json_path "/path/to/textvqa/train_annotations_vqa_format.json" \
67+
--textvqa_test_questions_json_path "/path/to/textvqa/val_questions_vqa_format.json" \
68+
--textvqa_test_annotations_json_path "/path/to/textvqa/val_annotations_vqa_format.json" \
69+
--vizwiz_train_image_dir_path "/path/to/v7w/train" \
70+
--vizwiz_test_image_dir_path "/path/to/v7w/val" \
71+
--vizwiz_train_questions_json_path "/path/to/v7w/train_questions_vqa_format.json" \
72+
--vizwiz_train_annotations_json_path "/path/to/v7w/train_annotations_vqa_format.json" \
73+
--vizwiz_test_questions_json_path "/path/to/v7w/val_questions_vqa_format.json" \
74+
--vizwiz_test_annotations_json_path "/path/to/v7w/val_annotations_vqa_format.json" \
75+
--hateful_memes_image_dir_path "/path/to/hateful_memes/img" \
76+
--hateful_memes_train_annotations_json_path "/path/to/hateful_memes/train.json" \
77+
--hateful_memes_test_annotations_json_path "/path/to/hateful_memes/dev.json" \
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#!/bin/bash
2+
#SBATCH --nodes 1
3+
#SBATCH --ntasks-per-node=8
4+
#SBATCH --gpus-per-task=1
5+
#SBATCH --time=5-00:00:00
6+
#SBATCH --job-name=openflamingo
7+
8+
export PYTHONFAULTHANDLER=1
9+
export CUDA_LAUNCH_BLOCKING=0
10+
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
11+
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
12+
export MASTER_PORT=$(shuf -i 0-65535 -n 1)
13+
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
14+
15+
export PYTHONPATH="$PYTHONPATH:open_flamingo"
16+
srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \
17+
--lm_path meta-llama/Llama-2-13b \
18+
--tokenizer_path meta-llama/Llama-2-13b \
19+
--model_family flamingo \
20+
--cross_attn_every_n_layers 4 \
21+
--dataset_resampled \
22+
--batch_size_mmc4 16 \
23+
--batch_size_laion 32 \
24+
--train_num_samples_mmc4 125000\
25+
--train_num_samples_laion 250000 \
26+
--loss_multiplier_laion 0.2 \
27+
--workers=4 \
28+
--run_name "fsdp" \
29+
--num_epochs 480 \
30+
--warmup_steps 0 \
31+
--mmc4_textsim_threshold 0.0 \
32+
--laion_shards "/path/to/laion-samples/{000000..000001}.tar" \
33+
--mmc4_shards "/path/to/mmc4-samples/{000000..000001}.tar" \
34+
--report_to_wandb
Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
11
#!/bin/bash
22
#SBATCH --nodes 1
3-
#SBATCH --ntasks-per-node=6
3+
#SBATCH --ntasks-per-node=8
44
#SBATCH --gpus-per-task=1
5-
#SBATCH --account=efml
6-
#SBATCH --partition=gpu
7-
#SBATCH --time=48:00:00
8-
#SBATCH --job-name=flamingo
5+
#SBATCH --time=5-00:00:00
6+
#SBATCH --job-name=openflamingo
97

108
export PYTHONFAULTHANDLER=1
119
export CUDA_LAUNCH_BLOCKING=0
1210
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
1311
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
14-
export MASTER_PORT=15000
12+
export MASTER_PORT=$(shuf -i 0-65535 -n 1)
1513
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
16-
export HF_DATASETS_CACHE="/gscratch/efml/anasa2/.huggingface" TRANSFORMERS_CACHE="/gscratch/efml/anasa2/.huggingface"
1714

1815
export PYTHONPATH="$PYTHONPATH:open_flamingo"
1916
srun --cpu_bind=v --accel-bind=gn python
2017

2118
deepspeed open_flamingo/open_flamingo/train/train.py \
2219
--lm_path meta-llama/Llama-2-13b \
2320
--tokenizer_path meta-llama/Llama-2-13b \
21+
--model_family flamingo \
2422
--cross_attn_every_n_layers 4 \
2523
--dataset_resampled \
2624
--batch_size_mmc4 16 \
@@ -34,7 +32,6 @@ deepspeed open_flamingo/open_flamingo/train/train.py \
3432
--num_epochs 480 \
3533
--warmup_steps 0 \
3634
--mmc4_textsim_threshold 0.0 \
37-
--laion_shards "/mmfs1/gscratch/efml/anasa2/laion-samples/{000000..000001}.tar" \
38-
--mmc4_shards "/mmfs1/gscratch/efml/anasa2/mmc4-samples/shard_{0..1}-000000000.tar" \
39-
--gradient_checkpointing \
40-
--report_to_wandb \
35+
--laion_shards "/path/to/laion-samples/{000000..000001}.tar" \
36+
--mmc4_shards "/path/to/mmc4-samples/{000000..000001}.tar" \
37+
--report_to_wandb
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/bin/bash
2+
#SBATCH --nodes 1
3+
#SBATCH --ntasks-per-node=8
4+
#SBATCH --gpus-per-task=1
5+
#SBATCH --time=5-00:00:00
6+
#SBATCH --job-name=openflamingo
7+
8+
<<com
9+
To use FSDP, please make sure to use Pytorch Nightly > 2.0.1!
10+
com
11+
12+
export PYTHONFAULTHANDLER=1
13+
export CUDA_LAUNCH_BLOCKING=0
14+
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
15+
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16+
export MASTER_PORT=$(shuf -i 0-65535 -n 1)
17+
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
18+
19+
export PYTHONPATH="$PYTHONPATH:open_flamingo"
20+
srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \
21+
--lm_path meta-llama/Llama-2-13b \
22+
--tokenizer_path meta-llama/Llama-2-13b \
23+
--model_family flamingo \
24+
--cross_attn_every_n_layers 4 \
25+
--dataset_resampled \
26+
--batch_size_mmc4 16 \
27+
--batch_size_laion 32 \
28+
--fsdp \
29+
--fsdp_sharding_strategy hybrid \
30+
--train_num_samples_mmc4 125000\
31+
--train_num_samples_laion 250000 \
32+
--loss_multiplier_laion 0.2 \
33+
--workers=4 \
34+
--run_name "fsdp" \
35+
--num_epochs 480 \
36+
--warmup_steps 0 \
37+
--mmc4_textsim_threshold 0.0 \
38+
--laion_shards "/path/to/laion-samples/{000000..000001}.tar" \
39+
--mmc4_shards "/path/to/mmc4-samples/{000000..000001}.tar" \
40+
--report_to_wandb

0 commit comments

Comments
 (0)