Commit

fix seed for iterator + multi gpu fix
vince62s committed May 24, 2018
1 parent 31a8412 commit 1953d39
Showing 2 changed files with 13 additions and 4 deletions.
11 changes: 8 additions & 3 deletions onmt/trainer.py
@@ -215,22 +215,25 @@ def train_epoch(self, train_iter, epoch):
                     normalization = 0
                     idx += 1

-        # Make sure to process remaining batches
+        # Make sure to process remaining batches in the case of
+        # grad_accum_count > 1 but not enough batches to fill true_batchs
         if len(true_batchs) > 0:
             self._gradient_accumulation(
                 true_batchs, total_stats,
                 report_stats, normalization)

         # NOTE: In multi-gpu cases every processes needs to call reduce/gather
-        # There is a total of (i+1) iterations.
+        # There is a total of i+1 iterations.
         # If this number isn't divisible by `n_gpu`
         # then, there will be a point (last iterations) where only
         # `(i+1) % self.n_gpu` GPUs will effectively work
         # i.e. run _gradient_accumulation which will be blocking.
         # therefore, we run those operations that are awaited.
-        if self.gpu_rank >= self.n_gpu - ((i+1) % self.n_gpu):
+        # print("GPU: %d there was %d batches last_batch: %d" % (self.gpu_rank, i, len(true_batchs)))
+        if self.gpu_rank >= ((i+1) % self.n_gpu):
             if len(true_batchs) == 0:
                 # add dummy gradients, just to unlock
+                # print("GPU: %d orphan empty batch padding i: %d" % (self.gpu_rank, i))
                 grads = [p.grad.data.mul(0)
                          for p in self.model.parameters() if p.requires_grad]
                 onmt.utils.multi_utils.all_reduce_and_rescale_tensors(
@@ -241,6 +244,8 @@ def train_epoch(self, train_iter, epoch):
                 report_stats = self.maybe_report_training(
                     epoch, idx, num_batches, self.optim.learning_rate,
                     report_stats)
+        #else:
+        # print("GPU: %d orphan non-empty batch i: %d" % (self.gpu_rank, i))

         return total_stats
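Note on the condition change: batch i is dispatched to the GPU with rank i % n_gpu, so after i+1 batches a partial final round only feeds ranks 0 .. ((i+1) % n_gpu) - 1. The ranks left out would never reach the blocking all_reduce and every process would hang, and the old test `gpu_rank >= n_gpu - ((i+1) % n_gpu)` selected the wrong set of ranks whenever the remainder differed from n_gpu minus the remainder. A minimal standalone sketch of the arithmetic (not part of the commit; `ranks_needing_dummy_reduce` is a hypothetical helper, and leftover gradient-accumulation batches are ignored):

    def ranks_needing_dummy_reduce(n_batches, n_gpu):
        """Ranks that get no batch in the final round-robin pass and must
        join the blocking all_reduce with zeroed (dummy) gradients."""
        remainder = n_batches % n_gpu  # equals (i+1) % n_gpu after the loop
        # Batch i goes to rank i % n_gpu, so a partial last round covers
        # ranks 0 .. remainder-1; the idle ranks satisfy rank >= remainder.
        return [rank for rank in range(n_gpu) if remainder and rank >= remainder]

    # 10 batches over 4 GPUs: the last round feeds ranks 0 and 1 only,
    # so ranks 2 and 3 must push dummy gradients to unlock the collective.
    assert ranks_needing_dummy_reduce(10, 4) == [2, 3]
    # 9 batches over 4 GPUs: ranks 1-3 are idle; the old test
    # rank >= n_gpu - remainder would have selected only rank 3.
    assert ranks_needing_dummy_reduce(9, 4) == [1, 2, 3]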

6 changes: 5 additions & 1 deletion train_single.py
@@ -8,7 +8,7 @@
 import os
 import sys
 from datetime import datetime
-
+import random
 import torch
 import torch.nn as nn
 from torch import cuda
@@ -65,6 +65,10 @@ def training_opt_postprocessing(opt):
     if opt.gpuid:
         torch.cuda.set_device(opt.device_id)
     if opt.seed > 0:
+        # this one is needed for torchtext random call (shuffled iterator)
+        # in multi gpu it ensures datasets are read in the same order
+        random.seed(opt.seed)
+        # These ensure same initialization in multi gpu mode
         torch.manual_seed(opt.seed)
         torch.cuda.manual_seed(opt.seed)
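Note on the seeding: the three calls cover distinct RNG streams. Python's `random` module is what torchtext's shuffled iterator draws from, so seeding it makes every process read the dataset in the same order; `torch.manual_seed` and `torch.cuda.manual_seed` make CPU- and GPU-side weight initialization identical across processes. A standalone sketch of the same idea (`seed_everything` is a hypothetical name, not from this repo):

    import random

    import torch

    def seed_everything(seed):
        """Seed every RNG stream that affects data order and weight init."""
        random.seed(seed)        # torchtext's shuffled iterator uses `random`
        torch.manual_seed(seed)  # CPU parameter initialization
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)  # current-device CUDA generator

    seed_everything(1234)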

