Skip to content
This repository was archived by the owner on Sep 24, 2022. It is now read-only.

Commit 587ffe7

Browse files
committedDec 21, 2018
succeeded to reproduce mocogan 🎉
1 parent cf3496c commit 587ffe7

File tree

5 files changed

+26
-26
lines changed

5 files changed

+26
-26
lines changed
 

‎.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ __pycache__
33
deploy.sh
44

55
data/*
6+
result/*
67
notebooks/*
78
!.gitkeep

‎result/.gitkeep

Whitespace-only changes.

‎src/dataset.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
class VideoDataset(Dataset):
1616
def __init__(self, dataset_path, preprocess_func, video_length=16, image_size=64, \
17-
mode="train"):
17+
number_limit=-1, mode="train"):
1818
# TODO: currently, mode only support 'train'
1919

2020
root_path = dataset_path / 'preprocessed' / mode
@@ -28,16 +28,21 @@ def __init__(self, dataset_path, preprocess_func, video_length=16, image_size=64
2828
raise e
2929

3030
# collect video folder paths
31-
video_list = []
3231
with open(root_path/"list.txt") as f:
33-
for line in f.readlines():
34-
# append [color_path, depth_path, n_frames]
35-
color_path, depth_path, n_frames = line.strip().split(" ")
36-
video_list.append([
37-
root_path / color_path,
38-
root_path / depth_path,
39-
int(n_frames)
40-
])
32+
lines = f.readlines()
33+
34+
if number_limit != -1:
35+
lines = lines[:number_limit]
36+
37+
video_list = []
38+
for line in lines:
39+
# append [color_path, depth_path, n_frames]
40+
color_path, depth_path, n_frames = line.strip().split(" ")
41+
video_list.append([
42+
root_path / color_path,
43+
root_path / depth_path,
44+
int(n_frames)
45+
])
4146

4247
self.dataset_path = dataset_path
4348
self.root_path = root_path

‎src/train.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def prepare_dataset(configs):
2121
eval(f'preprocess_{configs["dataset"]["name"]}_dataset'),
2222
configs['video_length'],
2323
configs['image_size'],
24+
configs["dataset"]['number_limit'],
2425
)
2526

2627
def main():

‎src/trainer.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ class Trainer(object):
1515
def __init__(self, dataloader, configs):
1616

1717
self.batchsize = configs["batchsize"]
18+
self.epoch_iters = len(dataloader)
1819
self.max_iteration = configs["iterations"]
1920
self.video_length = configs["video_length"]
2021

2122
self.dataloader = dataloader
22-
self.dataiter = iter(dataloader)
2323

2424
self.log_dir = Path(configs["log_dir"]) / configs["experiment_name"]
2525
self.log_dir.mkdir(parents=True, exist_ok=True)
@@ -28,7 +28,7 @@ def __init__(self, dataloader, configs):
2828
self.tensorboard_dir.mkdir(parents=True, exist_ok=True)
2929

3030
self.logger = Logger(self.log_dir, self.tensorboard_dir,\
31-
configs["log_interval"], len(dataloader))
31+
configs["log_interval"], self.epoch_iters)
3232

3333
self.evaluation_interval = configs["evaluation_interval"]
3434
self.log_samples_interval = configs["log_samples_interval"]
@@ -38,18 +38,6 @@ def __init__(self, dataloader, configs):
3838
self.device = self.use_cuda and torch.device('cuda') or torch.device('cpu')
3939
self.configs = configs
4040

41-
def sample_real_batch(self):
42-
try:
43-
batch = next(self.dataiter)
44-
except StopIteration:
45-
self.data_iter = iter(self.dataloader)
46-
batch = next(self.dataiter)
47-
48-
if self.use_cuda:
49-
batch = batch.cuda()
50-
51-
return batch.float()
52-
5341
def create_optimizer(self, model, lr, decay):
5442
return optim.Adam(
5543
model.parameters(),
@@ -90,11 +78,11 @@ def train(self, gen, idis, vdis):
9078

9179
# training loop
9280
logger = self.logger
81+
dataiter = iter(self.dataloader)
9382
while True:
9483
#--------------------
9584
# phase generator
9685
#--------------------
97-
9886
gen.train(); opt_gen.zero_grad()
9987

10088
# fake batch
@@ -118,7 +106,9 @@ def train(self, gen, idis, vdis):
118106
vdis.train(); opt_vdis.zero_grad()
119107

120108
# real batch
121-
x_real = Variable(self.sample_real_batch())
109+
x_real = next(dataiter).float()
110+
x_real = x_real.cuda() if self.use_cuda else x_fake
111+
x_real = Variable(x_real)
122112

123113
y_real_i = idis(x_real[:,:,t_rand])
124114
y_real_v = vdis(x_real)
@@ -143,6 +133,9 @@ def train(self, gen, idis, vdis):
143133

144134
iteration = self.logger.metrics["iteration"]
145135

136+
if iteration % (self.epoch_iters-1) == 0:
137+
dataiter = iter(self.dataloader)
138+
146139
# snapshot models
147140
if iteration % configs["snapshot_interval"] == 0:
148141
torch.save( gen, str(self.log_dir/'gen_{:05d}.pytorch'.format(iteration)))

0 commit comments

Comments
 (0)
This repository has been archived.