
Commit 2cd6ca5

Add and test the evaluation of RefCOCO+

1 parent: 06ccee0

File tree: 4 files changed, +75 −12 lines

README.md

Lines changed: 65 additions & 8 deletions
@@ -3,6 +3,10 @@
 Code and pre-trained models for **ViLBERT: Pretraining Task-Agnostic Visiolinguistic Representations for Vision-and-Language Tasks**.
 
 
+
+*Note: This is a beta release.*
+
+
 ## Repository Setup
 
 1. Create a fresh conda environment, and install all dependencies.
@@ -42,32 +46,85 @@ Check `README.md` under `data` for more details.
 |ViLBERT 6-Layer| RefCOCO+ |[Link]()|
 |ViLBERT 6-Layer| Image Retrieval |[Link]()|
 
+### Zero-Shot Image Retrieval
 
-## Visiolinguistic Pre-training
+We can directly use the pre-trained ViLBERT model for the zero-shot image retrieval task on Flickr30k.
 
-Once you extracted all the image features, to train the model:
+1: Download the model pretrained with the `Conceptual Caption` objective and put it under `save`.
 
+2: Update `features_h5path1` and `val_annotations_jsonpath` in `vlbert_task.yml` to load the Flickr30k test-set image features and JSON file (the default is the training features).
+
+3: Use the following command to evaluate the pre-trained 6-layer ViLBERT model (only single-GPU evaluation is supported for now):
+
+```bash
+python eval_retrieval.py --bert_model bert-base-uncased --from_pretrained save/bert_base_6_layer_6_connect/pytorch_model_9.bin --config_file config/bert_base_6layer_6conect.json --task 3 --split test --batch_size 1 --zero_shot
 ```
 
+### Image Retrieval
+
+1: Download the model pretrained with the `Image Retrieval` objective and put it under `save`.
+
+2: Update `features_h5path1` and `val_annotations_jsonpath` in `vlbert_task.yml` to load the Flickr30k test-set image features and JSON file (the default is the training features).
+
+3: Use the following command to evaluate the pre-trained 6-layer ViLBERT model (only single-GPU evaluation is supported for now):
+
+```bash
+python eval_retrieval.py --bert_model bert-base-uncased --from_pretrained save/RetrievalFlickr30k_bert_base_6layer_6conect-pretrained/pytorch_model_19.bin --config_file config/bert_base_6layer_6conect.json --task 3 --split test --batch_size 1
 ```
 
-train the model in a distributed setting:
+### VQA
+
+1: Download the model pretrained with the `VQA` objective and put it under `save`.
+
+2: To test on the held-out validation split, use the following command (a hedged sketch follows the empty block):
+
 ```
 
 ```
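The command above is left empty in this commit. By analogy with the `eval_tasks.py` call used for RefCOCO+ later in this diff, a plausible invocation might look like the following; the task index `0` and the checkpoint name are assumptions, not confirmed by the commit:

```bash
# Hypothetical sketch: --task 0 for VQA and the checkpoint path are assumptions.
python eval_tasks.py --bert_model bert-base-uncased --from_pretrained save/VQA_bert_base_6layer_6conect-pretrained/pytorch_model_19.bin --config_file config/bert_base_6layer_6conect.json --task 0
```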
 
-### Zero-Shot Image Retrieval
+### VCR
 
-We can directly use the Pre-trained ViLBERT model for zero-shot image retrieval tasks on Flickr30k.
+1: Download the model pretrained with the `VCR` objective and put it under `save`.
+
+2: To test on VCR Q->A (a hedged sketch follows the empty block):
+
+```
+
+```
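A plausible Q->A command, mirroring the other `eval_tasks.py` calls; the task index `1` and the checkpoint name are assumptions:

```bash
# Hypothetical sketch: --task 1 for VCR Q->A and the checkpoint path are assumptions.
python eval_tasks.py --bert_model bert-base-uncased --from_pretrained save/VCR_Q-A_bert_base_6layer_6conect-pretrained/pytorch_model_19.bin --config_file config/bert_base_6layer_6conect.json --task 1
```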
+
+3: To test on VCR QA->R (again, a hedged sketch follows):
 
-First, update `featyres_h5path1` and `val_annotations_jsonpath` in `vlbert_task.yml` to load the Flickr30k testset image feature and jsonfile (defualt is training feature).
+```
+
+```
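And for QA->R, again with the task index `2` and the checkpoint name as assumptions:

```bash
# Hypothetical sketch: --task 2 for VCR QA->R and the checkpoint path are assumptions.
python eval_tasks.py --bert_model bert-base-uncased --from_pretrained save/VCR_QA-R_bert_base_6layer_6conect-pretrained/pytorch_model_19.bin --config_file config/bert_base_6layer_6conect.json --task 2
```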
 
-Then, use the following command to evaluate pre-trained 6 layer ViLBERT model. (only support single GPU for evaluation now):
+### RefCOCO+
+
+1: Download the model pretrained with the `RefCOCO+` objective and put it under `save`.
+
+2: We use the pre-computed detections/masks from [MAttNet](https://github.com/lichengunc/MAttNet) for the fully-automatic comprehension task; check the MAttNet repository for more details.
+
+3: To test on the RefCOCO+ val set, use the following command:
 
 ```bash
-python eval_retrieval.py --bert_model bert-base-uncased --from_pretrained save/bert_base_6_layer_6_connect/pytorch_model_9.bin --config_file config/bert_base_6layer_6conect.json --task 3 --split test --batch_size 1 --zero_shot
+python eval_tasks.py --bert_model bert-base-uncased --from_pretrained save/refcoco+_bert_base_6layer_6conect-pretrained/pytorch_model_19.bin --config_file config/bert_base_6layer_6conect.json --task 4
+```
+
+## Visiolinguistic Pre-training
+
+Once you have extracted all the image features, train the model (a hedged sketch follows the empty block):
 
 ```
 
+```
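The pre-training command is left empty in this commit. A minimal sketch of what it might look like, assuming a Conceptual Captions pre-training script named `train_baseline.py`; the script name and every flag here are assumptions:

```bash
# Hypothetical sketch: script name and all flags are assumptions, not from this commit.
python train_baseline.py --bert_model bert-base-uncased --config_file config/bert_base_6layer_6conect.json --train_batch_size 512 --learning_rate 1e-4
```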
+
+To train the model in a distributed setting (a hedged sketch follows):
+
+```
+
+```
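The distributed command is also left empty. A sketch using PyTorch's standard `torch.distributed.launch` helper, with the training script name, its `--distributed` flag, and the GPU count all as assumptions:

```bash
# Hypothetical sketch: one process per GPU via the standard PyTorch launcher;
# the training script and its --distributed flag are assumptions.
python -m torch.distributed.launch --nproc_per_node=8 train_baseline.py --bert_model bert-base-uncased --config_file config/bert_base_6layer_6conect.json --distributed
```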
+
+
 
 ## TASKS
 
eval_tasks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import torch.nn as nn
2020

2121
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
22-
2322
from vilbert.task_utils import LoadDatasetEval, LoadLosses, ForwardModelsTrain, ForwardModelsVal, EvaluatingModel
2423

2524
import vilbert.utils as utils

vilbert/datasets/refer_expression_dataset.py

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,9 @@ def __init__(
 
         self.max_region_num = max_region_num
 
+        if not os.path.exists(os.path.join(dataroot, "cache")):
+            os.makedirs(os.path.join(dataroot, "cache"))
+
         cache_path = os.path.join(dataroot, "cache", task + '_' + split + '_' + str(max_seq_length) + "_" + str(max_region_num) + '.pkl')
         if not os.path.exists(cache_path):
             self.tokenize()
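The added guard creates the cache directory before the tokenization cache is written. An equivalent one-liner, assuming Python >= 3.2, avoids the separate existence check and its small race window when several processes start at once:

```python
import os

dataroot = "data/refcoco+"  # illustrative value; in the dataset this is a constructor argument

# exist_ok=True makes the call a no-op when the directory already exists,
# replacing the check-then-create pattern above.
os.makedirs(os.path.join(dataroot, "cache"), exist_ok=True)
```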

vilbert/datasets/vqa_dataset.py

Lines changed: 7 additions & 3 deletions
@@ -100,8 +100,8 @@ def __init__(
    ):
        super().__init__()
        self.split = split
-       ans2label_path = os.path.join('data', task, "cache", "trainval_ans2label.pkl")
-       label2ans_path = os.path.join('data', task, "cache", "trainval_label2ans.pkl")
+       ans2label_path = os.path.join(dataroot, "cache", "trainval_ans2label.pkl")
+       label2ans_path = os.path.join(dataroot, "cache", "trainval_label2ans.pkl")
        self.ans2label = cPickle.load(open(ans2label_path, "rb"))
        self.label2ans = cPickle.load(open(label2ans_path, "rb"))
        self.num_labels = len(self.ans2label)
@@ -110,7 +110,11 @@ def __init__(
        self._image_features_reader = image_features_reader
        self._tokenizer = tokenizer
        self._padding_index = padding_index
-       cache_path = os.path.join('data', task, "cache", task + '_' + split + '_' + str(max_seq_length) + '.pkl')
+
+       if not os.path.exists(os.path.join(dataroot, "cache")):
+           os.makedirs(os.path.join(dataroot, "cache"))
+
+       cache_path = os.path.join(dataroot, "cache", task + '_' + split + '_' + str(max_seq_length) + '.pkl')
        if not os.path.exists(cache_path):
            self.entries = _load_dataset(dataroot, split)
            self.tokenize(max_seq_length)
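The same pattern as in `refer_expression_dataset.py`: the hard-coded `'data'`/`task` prefix is replaced by the `dataroot` argument, so cache files land under the dataset root. A small runnable sketch of the resulting cache layout, with illustrative values (the real ones come from the dataset constructor):

```python
import os

# Illustrative values only; task, split, and max_seq_length are constructor arguments.
dataroot, task, split, max_seq_length = "data/VQA", "VQA", "train", 23

cache_path = os.path.join(
    dataroot, "cache", task + "_" + split + "_" + str(max_seq_length) + ".pkl"
)
print(cache_path)  # data/VQA/cache/VQA_train_23.pkl
```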
