
Commit

add instructions for training
jiasenlu committed Aug 22, 2019
1 parent 40d675c commit 8c60dd2
Showing 3 changed files with 23 additions and 28 deletions.
48 changes: 23 additions & 25 deletions README.md
@@ -1,9 +1,8 @@
# ViLBERT <img src="fig/vilbert_trim.png" width="45">

Code and pre-trained models for **[ViLBERT: Pretraining Task-Agnostic Visiolinguistic Representations for Vision-and-Language Tasks](https://arxiv.org/abs/1908.02265)**.

<span style="color:blue">*Note: This codebase is still in beta release, aimed at replicating the paper's performance.*</span>

## Repository Setup

@@ -25,7 +24,7 @@ conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
3. Install apex, following the instructions at https://github.com/NVIDIA/apex (a sketch is shown below).
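
A minimal sketch of that install, assuming the from-source instructions in the NVIDIA/apex README at the time (the exact flags may have changed since):

```bash
# Build apex from source with CUDA and C++ extensions (flags as in the apex README circa 2019).
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
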
## Data Setup

Check `README.md` under `data` for details on setting up the data, and `vlbert_tasks.yml` for the per-task configuration.


## Pre-trained model for Evaluation
@@ -38,6 +37,8 @@ Check `README.md` under `data` for more details.
|ViLBERT 6-Layer| RefCOCO+ |[Link]()|
|ViLBERT 6-Layer| Image Retrieval |[Link]()|

## Evaluation

### Zero-Shot Image Retrieval

We can directly use the pre-trained ViLBERT model for zero-shot image retrieval on Flickr30k.
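
The evaluation command for this section is collapsed in the diff view. As a minimal sketch only, assuming zero-shot retrieval is run through the same `eval_tasks.py` interface used elsewhere in this README; the task id and checkpoint path below are placeholders, not values from this commit:

```bash
# Hypothetical invocation: the task id (here 8) and the checkpoint path are placeholders.
python eval_tasks.py --bert_model bert-base-uncased --from_pretrained save/pretrained/pytorch_model_8.bin --config_file config/bert_base_6layer_6conect.json --tasks 8
```
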
@@ -104,48 +105,45 @@ python eval_tasks.py --bert_model bert-base-uncased --from_pretrained save/refco

## Visiolinguistic Pre-training

Once you have extracted all the image features, train a 6-layer ViLBERT model on Conceptual Captions with:

```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train_concap.py --from_pretrained bert-base-uncased --bert_model bert-base-uncased --config_file config/bert_base_6layer_6conect.json --learning_rate 1e-4 --train_batch_size 512 --save_name pretrained
```



## Training Down-Stream Tasks

### VQA

To finetune a 6-layer ViLBERT model for VQA with 8 GPUs, run the command below. `--tasks 0` selects the VQA task; check `vlbert_tasks.yml` for more VQA settings.

```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train_tasks.py --bert_model bert-base-uncased --from_pretrained save/bert_base_6_layer_6_connect_freeze_0/pytorch_model_8.bin --config_file config/bert_base_6layer_6conect.json --learning_rate 4e-5 --num_workers 16 --tasks 0 --save_name pretrained
```

### VCR

Similarly, to finetune a 6-layer ViLBERT model for VCR, run the following command. Here we jointly train the `Q->A` and `QA->R` tasks, so the tasks are specified as `--tasks 1-2`.

```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train_tasks.py --bert_model bert-base-uncased --from_pretrained save/bert_base_6_layer_6_connect_freeze_0/pytorch_model_8.bin --config_file config/bert_base_6layer_6conect.json --learning_rate 2e-5 --num_workers 16 --tasks 1-2 --save_name pretrained
```

### Image Retrieval

```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train_tasks.py --bert_model bert-base-uncased --from_pretrained save/bert_base_6_layer_6_connect_freeze_0/pytorch_model_8.bin --config_file config/bert_base_6layer_6conect.json --learning_rate 4e-5 --num_workers 9 --tasks 3 --save_name pretrained
```

### Refer Expression
```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train_tasks.py --bert_model bert-base-uncased --from_pretrained save/bert_base_6_layer_6_connect_freeze_0/pytorch_model_8.bin --config_file config/bert_base_6layer_6conect.json --learning_rate 4e-5 --num_workers 16 --tasks 4 --save_name pretrained
```

- For single-GPU training, use a smaller batch size and simply remove `-m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0` from the commands above (see the sketch below).
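
For example, a hedged single-GPU variant of the Conceptual Captions pre-training command above; the reduced `--train_batch_size` here is only an illustration of "smaller", not a tuned value:

```bash
# Sketch: same train_concap.py run without the distributed launcher.
# --train_batch_size 64 is an illustrative placeholder, not a recommended setting.
python train_concap.py --from_pretrained bert-base-uncased --bert_model bert-base-uncased --config_file config/bert_base_6layer_6conect.json --learning_rate 1e-4 --train_batch_size 64 --save_name pretrained_single_gpu
```
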

## References

If you find this code useful for your research, please cite our paper:
2 changes: 0 additions & 2 deletions train_baseline.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

import argparse
import json
@@ -42,7 +41,6 @@
from vilbert.datasets import ConceptCapLoaderTrain, ConceptCapLoaderVal
from vilbert.basebert import BertForMultiModalPreTraining
from pytorch_pretrained_bert.modeling import BertConfig
import pdb

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1 change: 0 additions & 1 deletion train_concap.py
@@ -16,7 +16,6 @@
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
# from parallel.data_parallel import DataParallel
from tensorboardX import SummaryWriter

from pytorch_pretrained_bert.tokenization import BertTokenizer
