@@ -15,11 +15,11 @@ class Trainer(object):
15
15
def __init__ (self , dataloader , configs ):
16
16
17
17
self .batchsize = configs ["batchsize" ]
18
+ self .epoch_iters = len (dataloader )
18
19
self .max_iteration = configs ["iterations" ]
19
20
self .video_length = configs ["video_length" ]
20
21
21
22
self .dataloader = dataloader
22
- self .dataiter = iter (dataloader )
23
23
24
24
self .log_dir = Path (configs ["log_dir" ]) / configs ["experiment_name" ]
25
25
self .log_dir .mkdir (parents = True , exist_ok = True )
@@ -28,7 +28,7 @@ def __init__(self, dataloader, configs):
28
28
self .tensorboard_dir .mkdir (parents = True , exist_ok = True )
29
29
30
30
self .logger = Logger (self .log_dir , self .tensorboard_dir ,\
31
- configs ["log_interval" ], len ( dataloader ) )
31
+ configs ["log_interval" ], self . epoch_iters )
32
32
33
33
self .evaluation_interval = configs ["evaluation_interval" ]
34
34
self .log_samples_interval = configs ["log_samples_interval" ]
@@ -38,18 +38,6 @@ def __init__(self, dataloader, configs):
38
38
self .device = self .use_cuda and torch .device ('cuda' ) or torch .device ('cpu' )
39
39
self .configs = configs
40
40
41
- def sample_real_batch (self ):
42
- try :
43
- batch = next (self .dataiter )
44
- except StopIteration :
45
- self .data_iter = iter (self .dataloader )
46
- batch = next (self .dataiter )
47
-
48
- if self .use_cuda :
49
- batch = batch .cuda ()
50
-
51
- return batch .float ()
52
-
53
41
def create_optimizer (self , model , lr , decay ):
54
42
return optim .Adam (
55
43
model .parameters (),
@@ -90,11 +78,11 @@ def train(self, gen, idis, vdis):
90
78
91
79
# training loop
92
80
logger = self .logger
81
+ dataiter = iter (self .dataloader )
93
82
while True :
94
83
#--------------------
95
84
# phase generator
96
85
#--------------------
97
-
98
86
gen .train (); opt_gen .zero_grad ()
99
87
100
88
# fake batch
@@ -118,7 +106,9 @@ def train(self, gen, idis, vdis):
118
106
vdis .train (); opt_vdis .zero_grad ()
119
107
120
108
# real batch
121
- x_real = Variable (self .sample_real_batch ())
109
+ x_real = next (dataiter ).float ()
110
+ x_real = x_real .cuda () if self .use_cuda else x_fake
111
+ x_real = Variable (x_real )
122
112
123
113
y_real_i = idis (x_real [:,:,t_rand ])
124
114
y_real_v = vdis (x_real )
@@ -143,6 +133,9 @@ def train(self, gen, idis, vdis):
143
133
144
134
iteration = self .logger .metrics ["iteration" ]
145
135
136
+ if iteration % (self .epoch_iters - 1 ) == 0 :
137
+ dataiter = iter (self .dataloader )
138
+
146
139
# snapshot models
147
140
if iteration % configs ["snapshot_interval" ] == 0 :
148
141
torch .save ( gen , str (self .log_dir / 'gen_{:05d}.pytorch' .format (iteration )))
0 commit comments