Commit e8f77b2

Add files via upload
1 parent 0a9e133 commit e8f77b2

File tree

3 files changed: 322 additions & 0 deletions

config.py

Lines changed: 49 additions & 0 deletions
import os.path as osp
import os


class parser:
    def __init__(self):
        self.dataroot = 'dataset'
        self.datamode = 'train'  # train, test
        self.stage = 'TOM'  # GMM, SEG, TOM
        self.runmode = self.datamode
        self.name = self.stage
        if self.datamode == 'train':
            self.data_list = 'train_pairs.txt'
        elif self.datamode == 'test':
            self.data_list = 'test_pairs.txt'
        self.fine_width = 192
        self.fine_height = 256
        self.radius = 4
        self.grid_path = osp.join(self.dataroot, 'grid.png')
        if self.datamode == 'train':  # shuffle while training, not while testing
            self.shuffle = True
        else:
            self.shuffle = False
        self.batch_size = 16
        self.workers = 1
        self.grid_size = 5

        self.lr = 0.002
        self.keep_step = 8000
        self.decay_step = 5500
        self.previous_step = 0  # to resume training, set this to the last saved checkpoint step
        self.save_count = 200
        self.display_count = 50

        self.tensorboard_dir = osp.join(os.getcwd(), 'tensorboard')
        self.checkpoint_dir = osp.join(os.getcwd(), 'checkpoints')
        self.save_dir = osp.join(os.getcwd(), 'outputs')  # for saving outputs during inference
        if not osp.exists(self.save_dir):
            os.makedirs(self.save_dir)
        if self.previous_step == 0:
            self.checkpoint = ''
        else:
            self.checkpoint = osp.join(self.checkpoint_dir, self.name, 'step_%06d.pth' % self.previous_step)

        self.input_image_path = 'custom/input/019579_0.jpg'
        self.cloth_image_path = 'custom/input/017575_1.jpg'
        self.human_parsing_image_path = 'custom/input/019579_0.png'
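
Since every option is a plain attribute set in __init__ (no argparse involved), the other scripts consume the config by instantiating the class; a minimal usage sketch, with outputs that follow from the defaults above:

    from config import parser

    opt = parser()                   # creates outputs/ as a side effect
    print(opt.stage, opt.data_list)  # TOM train_pairs.txt
    print(opt.checkpoint or 'training from scratch (previous_step == 0)')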

inference.py

Lines changed: 168 additions & 0 deletions
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms

import numpy as np
import json
import os
import os.path as osp
from PIL import Image
from PIL import ImageDraw

import time
import warnings
from tqdm import tqdm
from predict_pose import generate_pose_keypoints
from visualization import load_checkpoint, save_images
from gmm import GMM
from unet import UnetGenerator
from config import parser

warnings.filterwarnings("ignore")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def tensortoimage(t, path):
    im = transforms.ToPILImage()(t).convert("RGB")
    im.save(path)


def generate_data(opt, im_path, cloth_path, pose_path, segm_path):

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    c = Image.open(cloth_path)
    c = transform(c)

    im = Image.open(im_path)
    im = transform(im)

    im_parse = Image.open(segm_path)
    parse_array = np.array(im_parse)

    parse_shape = (parse_array > 0).astype(np.float32)

    parse_head = (parse_array == 1).astype(np.float32) + \
                 (parse_array == 2).astype(np.float32) + \
                 (parse_array == 4).astype(np.float32) + \
                 (parse_array == 13).astype(np.float32)

    # (parse_array == 3) was summed twice in the uploaded version; the
    # duplicate is dropped here so the mask stays binary
    parse_ttp = (parse_array == 1).astype(np.float32) + \
                (parse_array == 2).astype(np.float32) + \
                (parse_array == 4).astype(np.float32) + \
                (parse_array == 13).astype(np.float32) + \
                (parse_array == 3).astype(np.float32) + \
                (parse_array == 8).astype(np.float32) + \
                (parse_array == 9).astype(np.float32) + \
                (parse_array == 10).astype(np.float32) + \
                (parse_array == 11).astype(np.float32) + \
                (parse_array == 12).astype(np.float32) + \
                (parse_array == 14).astype(np.float32) + \
                (parse_array >= 15).astype(np.float32)

    phead = torch.from_numpy(parse_head)
    ptexttp = torch.from_numpy(parse_ttp)

    # shape downsample: blur the silhouette into a coarse body-shape prior
    parse_shape = Image.fromarray((parse_shape * 255).astype(np.uint8))
    parse_shape = parse_shape.resize((opt.fine_width//16, opt.fine_height//16), Image.BILINEAR)
    parse_shape = parse_shape.resize((opt.fine_width, opt.fine_height), Image.BILINEAR)
    shape = transform(parse_shape)  # [-1,1]

    im_h = im * phead - (1 - phead)  # [-1,1], fill -1 for other parts
    im_ttp = im * ptexttp - (1 - ptexttp)

    with open(pose_path, 'r') as f:
        pose_label = json.load(f)
        pose_data = pose_label['people'][0]['pose_keypoints']
        pose_data = np.array(pose_data)
        pose_data = pose_data.reshape((-1, 3))

    point_num = pose_data.shape[0]
    pose_map = torch.zeros(point_num, opt.fine_height, opt.fine_width)
    r = opt.radius
    im_pose = Image.new('L', (opt.fine_width, opt.fine_height))
    pose_draw = ImageDraw.Draw(im_pose)
    for i in range(point_num):
        one_map = Image.new('L', (opt.fine_width, opt.fine_height))
        draw = ImageDraw.Draw(one_map)
        pointx = pose_data[i, 0]
        pointy = pose_data[i, 1]
        if pointx > 1 and pointy > 1:
            draw.rectangle((pointx-r, pointy-r, pointx+r, pointy+r), 'white', 'white')
            pose_draw.rectangle((pointx-r, pointy-r, pointx+r, pointy+r), 'white', 'white')
        one_map = transform(one_map)
        pose_map[i] = one_map[0]

    # cloth-agnostic representation
    agnostic = torch.cat([shape, im_h, pose_map], 0)

    return (torch.unsqueeze(agnostic, 0), torch.unsqueeze(c, 0), torch.unsqueeze(im_ttp, 0))


def main():
    opt = parser()

    im_path = opt.input_image_path     # person image path
    cloth_path = opt.cloth_image_path  # cloth image path
    pose_path = opt.input_image_path.replace('.jpg', '_keypoints.json')  # pose keypoints path
    generate_pose_keypoints(im_path)   # writes the keypoints JSON next to the person image
    segm_path = opt.human_parsing_image_path  # segmentation mask path
    img_name = im_path.split('/')[-1].split('.')[0] + '_'

    agnostic, c, im_ttp = generate_data(opt, im_path, cloth_path, pose_path, segm_path)

    agnostic = agnostic.to(device)
    c = c.to(device)
    im_ttp = im_ttp.to(device)

    gmm = GMM(opt)
    load_checkpoint(gmm, os.path.join(opt.checkpoint_dir, 'GMM', 'gmm_final.pth'))
    gmm.to(device)
    gmm.eval()

    unet_mask = UnetGenerator(25, 20, ngf=64)
    load_checkpoint(unet_mask, os.path.join(opt.checkpoint_dir, 'SEG', 'segm_final.pth'))
    unet_mask.to(device)
    unet_mask.eval()

    tom = UnetGenerator(26, 4, ngf=64)
    load_checkpoint(tom, os.path.join(opt.checkpoint_dir, 'TOM', 'tom_final.pth'))
    tom.to(device)
    tom.eval()

    with torch.no_grad():
        output_segm = unet_mask(torch.cat([agnostic, c], 1))
        grid_zero, theta, grid_one, delta_theta = gmm(agnostic, c)
        c_warp = F.grid_sample(c, grid_one, padding_mode='border')
        output_segm = F.log_softmax(output_segm, dim=1)

    # one-hot encode the predicted segmentation
    output_argm = torch.max(output_segm, dim=1, keepdim=True)[1]
    final_segm = torch.zeros(output_segm.shape).to(device).scatter(1, output_argm, 1.0)

    input_tom = torch.cat([final_segm, c_warp, im_ttp], 1)

    with torch.no_grad():
        output_tom = tom(input_tom)
        person_r = torch.tanh(output_tom[:, :3, :, :])
        mask_c = torch.sigmoid(output_tom[:, 3:, :, :])
        mask_c = (mask_c >= 0.5).type(torch.float)
        img_tryon = mask_c * c_warp + (1 - mask_c) * person_r
    print('Output generated!')

    # rescale from [-1,1] to [0,1] before saving
    c_warp = c_warp * 0.5 + 0.5
    output_argm = output_argm.type(torch.float)
    person_r = person_r * 0.5 + 0.5
    img_tryon = img_tryon * 0.5 + 0.5

    tensortoimage(c_warp[0].cpu(), osp.join(opt.save_dir, img_name + 'w_cloth.png'))
    tensortoimage(output_argm[0][0].cpu(), osp.join(opt.save_dir, img_name + 'seg_mask.png'))
    tensortoimage(mask_c[0].cpu(), osp.join(opt.save_dir, img_name + 'c_mask.png'))
    tensortoimage(person_r[0].cpu(), osp.join(opt.save_dir, img_name + 'ren_person.png'))
    tensortoimage(img_tryon[0].cpu(), osp.join(opt.save_dir, img_name + 'final_output.png'))
    print('Output saved at {}'.format(opt.save_dir))


if __name__ == "__main__":
    main()
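
As a usage note, a minimal sketch of driving the script end to end; the checkpoint filenames are the ones load_checkpoint receives above, while the surrounding directory layout is an assumption about the rest of the repo:

    # assumed layout (not created by this commit):
    #   checkpoints/GMM/gmm_final.pth
    #   checkpoints/SEG/segm_final.pth
    #   checkpoints/TOM/tom_final.pth
    #   pose/pose_deploy_linevec.prototxt and pose/pose_iter_440000.caffemodel
    #   custom/input/ with the person image, cloth image, and parsing mask
    import inference

    inference.main()  # equivalent to `python inference.py`; results land in outputs/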

predict_pose.py

Lines changed: 105 additions & 0 deletions
import cv2
import numpy as np
import os
import json


class general_pose_model(object):
    def __init__(self, modelpath):
        # Available OpenPose model variants:
        #   Body25: 25 points
        #   COCO:   18 points
        #   MPI:    15 points
        # This class uses the COCO model.
        self.inWidth = 368
        self.inHeight = 368
        self.threshold = 0.05
        self.pose_net = self.general_coco_model(modelpath)

    def general_coco_model(self, modelpath):
        self.points_name = {
            "Nose": 0, "Neck": 1,
            "RShoulder": 2, "RElbow": 3, "RWrist": 4,
            "LShoulder": 5, "LElbow": 6, "LWrist": 7,
            "RHip": 8, "RKnee": 9, "RAnkle": 10,
            "LHip": 11, "LKnee": 12, "LAnkle": 13,
            "REye": 14, "LEye": 15,
            "REar": 16, "LEar": 17,
            "Background": 18}
        self.num_points = 18
        self.point_pairs = [[1, 0], [1, 2], [1, 5],
                            [2, 3], [3, 4], [5, 6],
                            [6, 7], [1, 8], [8, 9],
                            [9, 10], [1, 11], [11, 12],
                            [12, 13], [0, 14], [0, 15],
                            [14, 16], [15, 17]]
        prototxt = os.path.join(modelpath, 'pose_deploy_linevec.prototxt')
        caffemodel = os.path.join(modelpath, 'pose_iter_440000.caffemodel')
        coco_model = cv2.dnn.readNetFromCaffe(prototxt, caffemodel)

        return coco_model

    def predict(self, imgfile):
        img_cv2 = cv2.imread(imgfile)
        img_height, img_width, _ = img_cv2.shape
        inpBlob = cv2.dnn.blobFromImage(img_cv2,
                                        1.0 / 255,
                                        (self.inWidth, self.inHeight),
                                        (0, 0, 0),
                                        swapRB=False,
                                        crop=False)
        self.pose_net.setInput(inpBlob)
        self.pose_net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
        self.pose_net.setPreferableTarget(cv2.dnn.DNN_TARGET_OPENCL)

        output = self.pose_net.forward()

        H = output.shape[2]
        W = output.shape[3]

        # flat [x, y, confidence] triples, one per keypoint
        points = []
        for idx in range(self.num_points):
            probMap = output[0, idx, :, :]  # confidence map

            # Find the global maximum of the probMap.
            minVal, prob, minLoc, point = cv2.minMaxLoc(probMap)

            # Scale the point back to the original image size.
            x = (img_width * point[0]) / W
            y = (img_height * point[1]) / H

            if prob > self.threshold:
                points.append(x)
                points.append(y)
                points.append(prob)
            else:
                points.append(0)
                points.append(0)
                points.append(0)

        return points


def generate_pose_keypoints(img_file):

    modelpath = 'pose'
    pose_model = general_pose_model(modelpath)

    res_points = pose_model.predict(img_file)

    pose_data = {"version": 1,
                 "people": [
                     {"pose_keypoints": res_points}
                 ]}

    pose_keypoints_path = img_file.replace('.jpg', '_keypoints.json')

    json_object = json.dumps(pose_data, indent=4)

    # write the keypoints JSON next to the input image
    with open(pose_keypoints_path, "w") as outfile:
        outfile.write(json_object)
    print('File saved at {}'.format(pose_keypoints_path))
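
For reference, generate_pose_keypoints is the only entry point inference.py uses; a minimal usage sketch (the image path is the default from config.py, and with 18 COCO keypoints the flat list holds 18 × 3 = 54 numbers):

    from predict_pose import generate_pose_keypoints

    # expects pose/pose_deploy_linevec.prototxt and pose/pose_iter_440000.caffemodel
    generate_pose_keypoints('custom/input/019579_0.jpg')
    # writes custom/input/019579_0_keypoints.json shaped like
    # {"version": 1, "people": [{"pose_keypoints": [x0, y0, c0, x1, y1, c1, ...]}]}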
