Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenRocks committed Aug 7, 2020
1 parent f728eb1 commit cd54665
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 59 deletions.
2 changes: 1 addition & 1 deletion data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
from .data import (TxtTokLmdb, DetectFeatLmdb,
ImageLmdbGroup, ConcatDatasetWithLens)
from .sampler import TokenBucketSampler, DistributedTokenBucketSampler
from .sampler import TokenBucketSampler
from .loader import PrefetchLoader
from .vqa import VqaDataset, VqaEvalDataset, vqa_collate, vqa_eval_collate
from .nlvr2 import (Nlvr2PairedDataset, Nlvr2PairedEvalDataset,
Expand Down
11 changes: 1 addition & 10 deletions data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,13 @@ def compute_num_bb(confs, conf_th, min_bb, max_bb):

def _check_distributed():
    """Return whether Horovod's global size differs from its local size.

    Returns:
        bool: result of ``hvd.size() != hvd.local_size()``, or ``False``
        when Horovod raises ``ValueError`` (treated as "not using horovod").
    """
    try:
        # local_size must be *called*: comparing the integer size against
        # the bound method object is always unequal, which would report
        # "distributed" unconditionally.
        dist = hvd.size() != hvd.local_size()
    except ValueError:
        # not using horovod
        dist = False
    return dist


def _null():
    """Default factory used by ``default_none_dict``; always returns None.

    Kept as a named module-level function so that a defaultdict using it
    can be pickled by reference.
    """
    return None


def default_none_dict(dictionary):
    """ make picklable default dict

    Args:
        dictionary: mapping whose items seed the result.

    Returns:
        collections.defaultdict: holds the items of ``dictionary``; looking
        up a missing key yields ``None`` (and, as with any defaultdict,
        stores that key).
    """
    return defaultdict(_null, dictionary)


class DetectFeatLmdb(object):
def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
compress=True):
Expand Down
5 changes: 1 addition & 4 deletions data/nlvr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from cytoolz import concat

from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
get_ids_and_lens, pad_tensors, get_gather_index,
default_none_dict)
get_ids_and_lens, pad_tensors, get_gather_index)


class Nlvr2PairedDataset(DetectFeatTxtTokDataset):
Expand Down Expand Up @@ -93,7 +92,6 @@ def nlvr2_paired_collate(inputs):
'gather_index': gather_index,
'img_type_ids': img_type_ids,
'targets': targets}
batch = default_none_dict(batch)
return batch


Expand Down Expand Up @@ -200,7 +198,6 @@ def nlvr2_triplet_collate(inputs):
'gather_index': gather_index,
'img_type_ids': img_type_ids,
'targets': targets}
batch = default_none_dict(batch)
return batch


Expand Down
10 changes: 0 additions & 10 deletions data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,3 @@ def __iter__(self):
def __len__(self):
    """Length is deliberately unsupported for this sampler.

    Raises:
        ValueError: always; the message states that there is some
        randomness across epochs.
    """
    raise ValueError("NOT supported. "
                     "This has some randomness across epochs")


class DistributedTokenBucketSampler(TokenBucketSampler):
    """TokenBucketSampler that keeps only one replica's share of the ids.

    Args:
        num_replicas: stride used when slicing the id list.
        rank: starting offset of the slice for this replica.
        *args, **kwargs: forwarded unchanged to ``TokenBucketSampler``.
    """

    def __init__(self, num_replicas, rank, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._rank = rank
        self._num_replicas = num_replicas

    def _create_ids(self):
        """Return every ``num_replicas``-th id, starting at index ``rank``."""
        # Strided slice: elements rank, rank + num_replicas, ... of the
        # parent's id list.
        return super()._create_ids()[self._rank::self._num_replicas]
41 changes: 32 additions & 9 deletions model/nlvr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Uniter for NLVR2 model
"""
from collections import defaultdict

import torch
from torch import nn
from torch.nn import functional as F
Expand Down Expand Up @@ -31,9 +33,15 @@ def init_type_embedding(self):
new_emb.weight.data[2, :].copy_(emb)
self.uniter.embeddings.token_type_embeddings = new_emb

def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
attn_masks, gather_index,
img_type_ids, targets, compute_loss=True):
def forward(self, batch, compute_loss=True):
batch = defaultdict(lambda: None, batch)
input_ids = batch['input_ids']
position_ids = batch['position_ids']
img_feat = batch['img_feat']
img_pos_feat = batch['img_pos_feat']
attn_masks = batch['attn_masks']
gather_index = batch['gather_index']
img_type_ids = batch['img_type_ids']
sequence_output = self.uniter(input_ids, position_ids,
img_feat, img_pos_feat,
attn_masks, gather_index,
Expand All @@ -46,6 +54,7 @@ def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
answer_scores = self.nlvr2_output(reshaped_output)

if compute_loss:
targets = batch['targets']
nlvr2_loss = F.cross_entropy(
answer_scores, targets, reduction='none')
return nlvr2_loss
Expand All @@ -72,9 +81,15 @@ def init_type_embedding(self):
new_emb.weight.data[2, :].copy_(emb)
self.uniter.embeddings.token_type_embeddings = new_emb

def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
attn_masks, gather_index,
img_type_ids, targets, compute_loss=True):
def forward(self, batch, compute_loss=True):
batch = defaultdict(lambda: None, batch)
input_ids = batch['input_ids']
position_ids = batch['position_ids']
img_feat = batch['img_feat']
img_pos_feat = batch['img_pos_feat']
attn_masks = batch['attn_masks']
gather_index = batch['gather_index']
img_type_ids = batch['img_type_ids']
sequence_output = self.uniter(input_ids, position_ids,
img_feat, img_pos_feat,
attn_masks, gather_index,
Expand All @@ -84,6 +99,7 @@ def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
answer_scores = self.nlvr2_output(pooled_output)

if compute_loss:
targets = batch['targets']
nlvr2_loss = F.cross_entropy(
answer_scores, targets, reduction='none')
return nlvr2_loss
Expand Down Expand Up @@ -141,9 +157,15 @@ def init_type_embedding(self):
new_emb.weight.data[2, :].copy_(emb)
self.uniter.embeddings.token_type_embeddings = new_emb

def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
attn_masks, gather_index,
img_type_ids, targets, compute_loss=True):
def forward(self, batch, compute_loss=True):
batch = defaultdict(lambda: None, batch)
input_ids = batch['input_ids']
position_ids = batch['position_ids']
img_feat = batch['img_feat']
img_pos_feat = batch['img_pos_feat']
attn_masks = batch['attn_masks']
gather_index = batch['gather_index']
img_type_ids = batch['img_type_ids']
sequence_output = self.uniter(input_ids, position_ids,
img_feat, img_pos_feat,
attn_masks, gather_index,
Expand Down Expand Up @@ -174,6 +196,7 @@ def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
torch.cat([left_out, right_out], dim=-1))

if compute_loss:
targets = batch['targets']
nlvr2_loss = F.cross_entropy(
answer_scores, targets, reduction='none')
return nlvr2_loss
Expand Down
2 changes: 1 addition & 1 deletion model/vqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class UniterForVisualQuestionAnswering(UniterPreTrainedModel):
""" Finetune multi-modal BERT for VQA
""" Finetune UNITER for VQA
"""
def __init__(self, config, img_dim, num_answer):
super().__init__(config)
Expand Down
47 changes: 23 additions & 24 deletions train_nlvr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from tqdm import tqdm

from data import (DistributedTokenBucketSampler, DetectFeatLmdb, TxtTokLmdb,
from data import (TokenBucketSampler, DetectFeatLmdb, TxtTokLmdb,
Nlvr2PairedDataset, Nlvr2PairedEvalDataset,
Nlvr2TripletDataset, Nlvr2TripletEvalDataset,
nlvr2_paired_collate, nlvr2_paired_eval_collate,
Expand All @@ -44,9 +44,8 @@ def create_dataloader(img_path, txt_path, batch_size, is_train,
opts.num_bb, opts.compressed_db)
txt_db = TxtTokLmdb(txt_path, opts.max_txt_len if is_train else -1)
dset = dset_cls(txt_db, img_db, opts.use_img_type)
sampler = DistributedTokenBucketSampler(
hvd.size(), hvd.rank(), dset.lens,
bucket_size=BUCKET_SIZE, batch_size=batch_size, droplast=is_train)
sampler = TokenBucketSampler(dset.lens, bucket_size=BUCKET_SIZE,
batch_size=batch_size, droplast=is_train)
loader = DataLoader(dset, batch_sampler=sampler,
num_workers=opts.n_workers, pin_memory=opts.pin_mem,
collate_fn=collate_fn)
Expand All @@ -73,7 +72,7 @@ def main(opts):

# train_examples = None
LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, "
f"{opts.train_img_dir}")
f"{opts.train_img_db}")
if 'paired' in opts.model:
DatasetCls = Nlvr2PairedDataset
EvalDatasetCls = Nlvr2PairedEvalDataset
Expand Down Expand Up @@ -156,7 +155,7 @@ def main(opts):
targets = batch['targets']
n_examples += targets.size(0)

loss = model(**batch, compute_loss=True)
loss = model(batch, compute_loss=True)
loss = loss.mean()
delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
Expand All @@ -182,9 +181,7 @@ def main(opts):
TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

# log loss
losses = all_gather_list(running_loss)
running_loss = RunningMeter(
'loss', sum(l.val for l in losses)/len(losses))
# NOTE: not gathered across GPUs for efficiency
TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
TB_LOGGER.step()

Expand Down Expand Up @@ -226,17 +223,19 @@ def main(opts):
break
n_epoch += 1
LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")
for split, loader in [('val', val_dataloader), ('test', test_dataloader)]:
LOGGER.info(f"Step {global_step}: start running "
f"validation on {split} split...")
log, results = validate(model, loader, split)
with open(f'{opts.output_dir}/results/'
f'{split}_results_{global_step}_'
f'rank{rank}_final.csv', 'w') as f:
for id_, ans in results:
f.write(f'{id_},{ans}\n')
TB_LOGGER.log_scaler_dict(log)
model_saver.save(model, f'{global_step}_final')
if opts.num_train_steps % opts.valid_steps != 0:
for split, loader in [('val', val_dataloader),
('test', test_dataloader)]:
LOGGER.info(f"Step {global_step}: start running "
f"validation on {split} split...")
log, results = validate(model, loader, split)
with open(f'{opts.output_dir}/results/'
f'{split}_results_{global_step}_'
f'rank{rank}.csv', 'w') as f:
for id_, ans in results:
f.write(f'{id_},{ans}\n')
TB_LOGGER.log_scaler_dict(log)
model_saver.save(model, global_step)


@torch.no_grad()
Expand All @@ -252,7 +251,7 @@ def validate(model, val_loader, split):
targets = batch['targets']
del batch['targets']
del batch['qids']
scores = model(**batch, targets=None, compute_loss=False)
scores = model(batch, compute_loss=False)
loss = F.cross_entropy(scores, targets, reduction='sum')
val_loss += loss.item()
tot_score += (scores.max(dim=-1, keepdim=False)[1] == targets
Expand Down Expand Up @@ -284,19 +283,19 @@ def validate(model, val_loader, split):
parser.add_argument("--train_txt_db",
default=None, type=str,
help="The input train corpus. (LMDB)")
parser.add_argument("--train_img_dir",
parser.add_argument("--train_img_db",
default=None, type=str,
help="The input train images.")
parser.add_argument("--val_txt_db",
default=None, type=str,
help="The input validation corpus. (LMDB)")
parser.add_argument("--val_img_dir",
parser.add_argument("--val_img_db",
default=None, type=str,
help="The input validation images.")
parser.add_argument("--test_txt_db",
default=None, type=str,
help="The input test corpus. (LMDB)")
parser.add_argument("--test_img_dir",
parser.add_argument("--test_img_db",
default=None, type=str,
help="The input test images.")
parser.add_argument('--compressed_db', action='store_true',
Expand Down

0 comments on commit cd54665

Please sign in to comment.