Skip to content

Commit c84ea2a

Browse files
committed
Testing with defined pairs from files list!
1 parent 3c51dcb commit c84ea2a

File tree

9 files changed

+359
-254
lines changed

9 files changed

+359
-254
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
.idea/
12
results/
23
runs/
34
sample/

data/aligned_dataset.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@ def initialize(self, opt):
1515
self.root = opt.dataroot
1616
self.diction = {}
1717

18+
self.fine_height = 256
19+
self.fine_width = 192
20+
self.radius = 5
21+
22+
# load data list from pairs file
23+
human_names = []
24+
cloth_names = []
25+
with open(os.path.join(opt.dataroot, opt.datapairs), 'r') as f:
26+
for line in f.readlines():
27+
h_name, c_name = line.strip().split()
28+
human_names.append(h_name)
29+
cloth_names.append(c_name)
30+
self.human_names = human_names
31+
self.cloth_names = cloth_names
32+
self.dataset_size = len(human_names)
33+
1834
# input A (label maps)
1935
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
2036
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
@@ -99,7 +115,13 @@ def __getitem__(self, index):
99115
# if '006581' in s:
100116
# test = k
101117
# break
102-
A_path = self.A_paths[index]
118+
119+
# get names from the pairs file
120+
c_name = self.cloth_names[index]
121+
h_name = self.human_names[index]
122+
123+
# A_path = self.A_paths[index]
124+
A_path = osp.join(self.dir_A, h_name.replace(".jpg", ".png"))
103125
A = Image.open(A_path).convert('L')
104126

105127
params = get_params(self.opt, A.size)
@@ -114,7 +136,8 @@ def __getitem__(self, index):
114136
B_tensor = inst_tensor = feat_tensor = 0
115137
# input B (real images)
116138

117-
B_path = self.B_paths[index]
139+
# B_path = self.B_paths[index]
140+
B_path = osp.join(self.dir_B, h_name)
118141
name = B_path.split('/')[-1]
119142

120143
B = Image.open(B_path).convert('RGB')
@@ -136,12 +159,14 @@ def __getitem__(self, index):
136159

137160
### input_C (color)
138161
# print(self.C_paths)
139-
C_path = self.C_paths[test]
162+
# C_path = self.C_paths[test]
163+
C_path = osp.join(self.dir_C, c_name)
140164
C = Image.open(C_path).convert('RGB')
141165
C_tensor = transform_B(C)
142166

143167
# Edge
144-
E_path = self.E_paths[test]
168+
# E_path = self.E_paths[test]
169+
E_path = osp.join(self.dir_E, c_name)
145170
# print(E_path)
146171
E = Image.open(E_path).convert('L')
147172
E_tensor = transform_A(E)

models/base_model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2-
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
1+
# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2+
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
33
import os
44
import torch
55
import sys
66

7+
78
class BaseModel(torch.nn.Module):
89
def name(self):
910
return 'BaseModel'
@@ -49,18 +50,18 @@ def save_network(self, network, network_label, epoch_label, gpu_ids):
4950
# network.cuda()
5051

5152
# helper loading function that can be used by subclasses
52-
def load_network(self, network, network_label, epoch_label, save_dir=''):
53+
def load_network(self, network, network_label, epoch_label, save_dir=''):
5354
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
54-
print (save_filename)
55+
print(save_filename)
5556
if not save_dir:
5657
save_dir = self.save_dir
57-
save_path = os.path.join(save_dir, save_filename)
58+
save_path = os.path.join(save_dir, save_filename)
5859
if not os.path.isfile(save_path):
5960
print('%s not exists yet!' % save_path)
6061
if network_label == 'G':
6162
raise('Generator must exist!')
6263
else:
63-
#network.load_state_dict(torch.load(save_path))
64+
# network.load_state_dict(torch.load(save_path))
6465

6566
network.load_state_dict(torch.load(save_path))
6667
# except:

models/models.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2-
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
1+
# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2+
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
33
import torch
44
import ipdb
55

6+
67
def create_model(opt):
78
if opt.model == 'pix2pixHD':
89
from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
910
if opt.isTrain:
1011
model = Pix2PixHDModel()
11-
#ipdb.set_trace()
12+
# ipdb.set_trace()
1213
else:
1314
model = InferenceModel()
1415

0 commit comments

Comments
 (0)