Commit 82e2a3c

add readme for zero-shot image retrieval
1 parent 9ab3e05 commit 82e2a3c

9 files changed: 45 additions & 619 deletions

README.md

Lines changed: 10 additions & 8 deletions
@@ -10,8 +10,8 @@ Code and pre-trained models for **ViLBERT: Pretraining Task-Agnostic Visiolinguistic Representations for Vision-and-Language Tasks**
 ```text
 conda create -n vilbert python=3.6
 conda activate vilbert
-git clone https://github.com/jiasenlu/ViLBert
-cd ViLBert
+git clone https://github.com/jiasenlu/vilbert_v0.1
+cd vilbert_v0.1
 pip install -r requirements.txt
 ```

@@ -45,13 +45,13 @@ Check `README.md` under `data` for more details.
 
 ## Visiolinguistic Pre-training
 
-To train the model:
+Once you have extracted all the image features, train the model with:
 
 ```

 ```
 
-Distributed Training:
+To train the model in a distributed setting:
 ```

 ```
@@ -60,10 +60,12 @@ Distributed Training:
 
 We can directly use the pre-trained ViLBERT model for zero-shot image retrieval tasks on Flickr30k.
 
-To evaluate on Flickr30k:
+First, update `features_h5path1` and `val_annotations_jsonpath` in `vlbert_tasks.yml` to load the Flickr30k test-set image features and JSON file (the default is the training features).
 
-```
-python
+Then, use the following command to evaluate the pre-trained 6-layer ViLBERT model (only single-GPU evaluation is supported for now):
+
+```bash
+python eval_retrieval.py --bert_model bert-base-uncased --from_pretrained save/bert_base_6_layer_6_connect/pytorch_model_9.bin --config_file config/bert_base_6layer_6conect.json --task 3 --split test --batch_size 1 --zero_shot
 ```
 
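A minimal sketch of how the two config keys above can be read back out of `vlbert_tasks.yml`, assuming the retrieval task's section is named `TASK3` to line up with `--task 3` (the section name is an assumption; the key names come from the README change above, and the loading idiom matches the `eval_tasks.py` diff below):

```python
import yaml
from easydict import EasyDict as edict

# Load the task config the same way eval_tasks.py does.
with open("vlbert_tasks.yml", "r") as f:
    task_cfg = edict(yaml.safe_load(f))

# "TASK3" is an assumed section name for the retrieval task (--task 3).
# features_h5path1 should point at the Flickr30k test-set image features
# and val_annotations_jsonpath at the matching annotation JSON file.
retrieval = task_cfg.TASK3
print(retrieval.features_h5path1)
print(retrieval.val_annotations_jsonpath)
```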

@@ -91,7 +93,7 @@ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train_tasks.py
 ```
 
 ### Image Retrieval
-```
+```bash
 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train_tasks.py --bert_model bert-base-uncased --from_pretrained save/bert_base_6_layer_6_connect_freeze_0/pytorch_model_8.bin --config_file config/bert_base_6layer_6conect.json --learning_rate 4e-5 --num_workers 9 --tasks 11 --save_name pretrained
 ```
 
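For context on these `torch.distributed.launch` invocations: the launcher spawns `--nproc_per_node` copies of the script (one per GPU) and passes each copy a distinct `--local_rank`. A minimal sketch of the entry-point pattern such a script needs, mirroring the `set_device`/`init_process_group` calls visible in the `eval_tasks.py` diff below (the surrounding script is hypothetical):

```python
import argparse
import torch

parser = argparse.ArgumentParser()
# torch.distributed.launch appends --local_rank=<i> to each spawned process.
parser.add_argument("--local_rank", type=int, default=-1)
args = parser.parse_args()

if args.local_rank != -1:
    # One process per GPU: pin this process to its device, then join the
    # NCCL process group so tensors can be synchronized across processes.
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend="nccl")
    device = torch.device("cuda", args.local_rank)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```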

eval_tasks.py

Lines changed: 3 additions & 9 deletions
@@ -88,7 +88,7 @@ def main():
         "Positive power of 2: static loss scaling value.\n",
     )
     parser.add_argument(
-        "--num_workers", type=int, default=16, help="Number of workers in the dataloader."
+        "--num_workers", type=int, default=10, help="Number of workers in the dataloader."
     )
     parser.add_argument(
         "--save_name",
@@ -97,10 +97,7 @@ def main():
         help="save name for training.",
     )
     parser.add_argument(
-        "--use_chunk", default=0, type=float, help="whether use chunck for parallel training."
-    )
-    parser.add_argument(
-        "--batch_size", default=1024, type=int, help="what is the batch size?"
+        "--batch_size", default=1000, type=int, help="what is the batch size?"
     )
     parser.add_argument(
         "--tasks", default='', type=str, help="1-2-3... training task separate by -"
@@ -117,7 +114,7 @@ def main():
 
     args = parser.parse_args()
     with open('vlbert_tasks.yml', 'r') as f:
-        task_cfg = edict(yaml.load(f))
+        task_cfg = edict(yaml.safe_load(f))
 
     random.seed(args.seed)
     np.random.seed(args.seed)
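The `yaml.load` to `yaml.safe_load` switch is the standard PyYAML hardening: calling `yaml.load` without an explicit `Loader` is deprecated since PyYAML 5.1, and the full loader can construct arbitrary Python objects from tagged YAML. A small self-contained illustration:

```python
import yaml

# safe_load only builds plain scalars, lists, and dicts, so a config
# file cannot smuggle in object-constructing tags.
doc = "num_workers: 10\nbatch_size: 1000\n"
print(yaml.safe_load(doc))  # {'num_workers': 10, 'batch_size': 1000}

# safe_load rejects arbitrary-object tags outright:
try:
    yaml.safe_load("!!python/object/apply:os.getcwd []")
except yaml.YAMLError as exc:
    print("blocked:", type(exc).__name__)
```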
@@ -150,7 +147,6 @@ def main():
         torch.cuda.set_device(args.local_rank)
         device = torch.device("cuda", args.local_rank)
         n_gpu = 1
-        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
         torch.distributed.init_process_group(backend="nccl")
 
     logger.info(
@@ -202,12 +198,10 @@ def main():
 
     no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
 
-    print("***** Running training *****")
     print(" Num Iters: ", task_num_iters)
     print(" Batch size: ", task_batch_size)
 
     model.eval()
-    # when run evaluate, we run each task sequentially.
     for task_id in task_ids:
         results = []
         others = []
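The deleted comment described behavior the loop still has: evaluation runs each task sequentially with the model in eval mode. A minimal sketch of that loop shape, with toy stand-ins for the model and per-task dataloaders (the real code builds these from `vlbert_tasks.yml`):

```python
import torch
from torch import nn

# Toy stand-ins (hypothetical names, for shape only).
model = nn.Linear(4, 2)
task_dataloaders = {"TASK3": [torch.randn(8, 4) for _ in range(3)]}

model.eval()                    # disable dropout for deterministic scoring
with torch.no_grad():           # no gradients are needed at eval time
    for task_id, loader in task_dataloaders.items():  # one task at a time
        results = [model(batch) for batch in loader]
        print(task_id, "batches evaluated:", len(results))
```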

parallel/data_parallel.py

Lines changed: 0 additions & 226 deletions
This file was deleted.
