1
+ import torch
2
+ import torch .nn .functional as F
3
+ import torch .nn as nn
4
+ import torchvision .transforms as transforms
5
+
6
+ import numpy as np
7
+ import json
8
+ import os
9
+ import os .path as osp
10
+ from PIL import Image
11
+ from PIL import ImageDraw
12
+
13
+ import time
14
+ import warnings
15
+ from tqdm import tqdm
16
+ from predict_pose import generate_pose_keypoints
17
+ from visualization import load_checkpoint , save_images
18
+ from gmm import GMM
19
+ from unet import UnetGenerator
20
+ from config import parser
21
+
22
+ warnings .filterwarnings ("ignore" )
23
+
24
+ device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
25
+
26
+ def tensortoimage (t , path ):
27
+ im = transforms .ToPILImage ()(t ).convert ("RGB" )
28
+ im .save (path )
29
+
30
+ def generate_data (opt , im_path , cloth_path , pose_path , segm_path ):
31
+
32
+ transform = transforms .Compose ([ \
33
+ transforms .ToTensor (), \
34
+ transforms .Normalize ((0.5 ,), (0.5 ,))])
35
+
36
+ c = Image .open (cloth_path )
37
+ c = transform (c )
38
+
39
+ im = Image .open (osp .join (im_path ))
40
+ im = transform (im )
41
+
42
+ im_parse = Image .open (segm_path )
43
+ parse_array = np .array (im_parse )
44
+
45
+ parse_shape = (parse_array > 0 ).astype (np .float32 )
46
+
47
+ parse_head = (parse_array == 1 ).astype (np .float32 ) + \
48
+ (parse_array == 2 ).astype (np .float32 ) + \
49
+ (parse_array == 4 ).astype (np .float32 ) + \
50
+ (parse_array == 13 ).astype (np .float32 )
51
+
52
+ parse_ttp = (parse_array == 1 ).astype (np .float32 ) + \
53
+ (parse_array == 2 ).astype (np .float32 ) + \
54
+ (parse_array == 4 ).astype (np .float32 ) + \
55
+ (parse_array == 13 ).astype (np .float32 ) + \
56
+ (parse_array == 3 ).astype (np .float32 ) + \
57
+ (parse_array == 8 ).astype (np .float32 ) + \
58
+ (parse_array == 9 ).astype (np .float32 ) + \
59
+ (parse_array == 10 ).astype (np .float32 ) + \
60
+ (parse_array == 11 ).astype (np .float32 ) + \
61
+ (parse_array == 12 ).astype (np .float32 ) + \
62
+ (parse_array == 14 ).astype (np .float32 ) + \
63
+ (parse_array == 3 ).astype (np .float32 ) + \
64
+ (parse_array >= 15 ).astype (np .float32 )
65
+
66
+ phead = torch .from_numpy (parse_head )
67
+ ptexttp = torch .from_numpy (parse_ttp )
68
+
69
+ # shape downsample
70
+ parse_shape = Image .fromarray ((parse_shape * 255 ).astype (np .uint8 ))
71
+ parse_shape = parse_shape .resize ((opt .fine_width // 16 , opt .fine_height // 16 ), Image .BILINEAR )
72
+ parse_shape = parse_shape .resize ((opt .fine_width , opt .fine_height ), Image .BILINEAR )
73
+ shape = transform (parse_shape ) # [-1,1]
74
+
75
+ im_h = im * phead - (1 - phead ) # [-1,1], fill 0 for other parts
76
+ im_ttp = im * ptexttp - (1 - ptexttp )
77
+
78
+ with open (pose_path , 'r' ) as f :
79
+ pose_label = json .load (f )
80
+ pose_data = pose_label ['people' ][0 ]['pose_keypoints' ]
81
+ pose_data = np .array (pose_data )
82
+ pose_data = pose_data .reshape ((- 1 ,3 ))
83
+
84
+ point_num = pose_data .shape [0 ]
85
+ pose_map = torch .zeros (point_num , opt .fine_height , opt .fine_width )
86
+ r = opt .radius
87
+ im_pose = Image .new ('L' , (opt .fine_width , opt .fine_height ))
88
+ pose_draw = ImageDraw .Draw (im_pose )
89
+ for i in range (point_num ):
90
+ one_map = Image .new ('L' , (opt .fine_width , opt .fine_height ))
91
+ draw = ImageDraw .Draw (one_map )
92
+ pointx = pose_data [i ,0 ]
93
+ pointy = pose_data [i ,1 ]
94
+ if pointx > 1 and pointy > 1 :
95
+ draw .rectangle ((pointx - r , pointy - r , pointx + r , pointy + r ), 'white' , 'white' )
96
+ pose_draw .rectangle ((pointx - r , pointy - r , pointx + r , pointy + r ), 'white' , 'white' )
97
+ one_map = transform (one_map )
98
+ pose_map [i ] = one_map [0 ]
99
+
100
+ # cloth-agnostic representation
101
+ agnostic = torch .cat ([shape , im_h , pose_map ], 0 )
102
+
103
+ return (torch .unsqueeze (agnostic , 0 ), torch .unsqueeze (c ,0 ), torch .unsqueeze (im_ttp ,0 ))
104
+
105
+ def main ():
106
+ opt = parser ()
107
+
108
+ im_path = opt .input_image_path #person image path
109
+ cloth_path = opt .cloth_image_path #cloth image path
110
+ pose_path = opt .input_image_path .replace ('.jpg' , '_keypoints.json' ) #pose keypoint path
111
+ generate_pose_keypoints (im_path ) #generating pose keypoints
112
+ segm_path = opt .human_parsing_image_path #segemented mask path
113
+ img_name = im_path .split ('/' )[- 1 ].split ('.' )[0 ] + '_'
114
+
115
+ agnostic , c , im_ttp = generate_data (opt , im_path , cloth_path , pose_path , segm_path )
116
+
117
+ agnostic = agnostic .to (device )
118
+ c = c .to (device )
119
+ im_ttp = im_ttp .to (device )
120
+
121
+ gmm = GMM (opt )
122
+ load_checkpoint (gmm , os .path .join (opt .checkpoint_dir , 'GMM' , 'gmm_final.pth' ))
123
+ gmm .to (device )
124
+ gmm .eval ()
125
+
126
+ unet_mask = UnetGenerator (25 , 20 , ngf = 64 )
127
+ load_checkpoint (unet_mask , os .path .join (opt .checkpoint_dir , 'SEG' , 'segm_final.pth' ))
128
+ unet_mask .to (device )
129
+ unet_mask .eval ()
130
+
131
+ tom = UnetGenerator (26 , 4 , ngf = 64 )
132
+ load_checkpoint (tom , os .path .join (opt .checkpoint_dir , 'TOM' , 'tom_final.pth' ))
133
+ tom .to (device )
134
+ tom .eval ()
135
+
136
+ with torch .no_grad ():
137
+ output_segm = unet_mask (torch .cat ([agnostic , c ], 1 ))
138
+ grid_zero , theta , grid_one , delta_theta = gmm (agnostic , c )
139
+ c_warp = F .grid_sample (c , grid_one , padding_mode = 'border' )
140
+ output_segm = F .log_softmax (output_segm , dim = 1 )
141
+
142
+ output_argm = torch .max (output_segm , dim = 1 , keepdim = True )[1 ]
143
+ final_segm = torch .zeros (output_segm .shape ).to (device ).scatter (1 , output_argm , 1.0 )
144
+
145
+ input_tom = torch .cat ([final_segm , c_warp , im_ttp ], 1 )
146
+
147
+ with torch .no_grad ():
148
+ output_tom = tom (input_tom )
149
+ person_r = torch .tanh (output_tom [:,:3 ,:,:])
150
+ mask_c = torch .sigmoid (output_tom [:,3 :,:,:])
151
+ mask_c = (mask_c >= 0.5 ).type (torch .float )
152
+ img_tryon = mask_c * c_warp + (1 - mask_c ) * person_r
153
+ print ('Output generated!' )
154
+
155
+ c_warp = c_warp * 0.5 + 0.5
156
+ output_argm = output_argm .type (torch .float )
157
+ person_r = person_r * 0.5 + 0.5
158
+ img_tryon = img_tryon * 0.5 + 0.5
159
+
160
+ tensortoimage (c_warp [0 ].cpu (), osp .join (opt .save_dir , img_name + 'w_cloth.png' ))
161
+ tensortoimage (output_argm [0 ][0 ].cpu (), osp .join (opt .save_dir , img_name + 'seg_mask.png' ))
162
+ tensortoimage (mask_c [0 ].cpu (), osp .join (opt .save_dir , img_name + 'c_mask.png' ))
163
+ tensortoimage (person_r [0 ].cpu (), osp .join (opt .save_dir , img_name + 'ren_person.png' ))
164
+ tensortoimage (img_tryon [0 ].cpu (), osp .join (opt .save_dir , img_name + 'final_output.png' ))
165
+ print ('Output saved at {}' .format (opt .save_dir ))
166
+
167
+ if __name__ == "__main__" :
168
+ main ()
0 commit comments