
Commit 774df81

Author: Jiahui Li
Message: first commit git status 😄
1 parent 65fee80 commit 774df81


52 files changed: +5595 −0 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
*__pycache__*
*.pt*
pretrained_models

ReadMe.md

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# HeadSwap

This is a HeadSwap project, mainly inspired by [HeSer.Pytorch](https://github.com/LeslieZhoa/HeSer.Pytorch)<br>
It consists of two stages:<br>
1. The first stage is copied directly from [PIRender](https://github.com/RenYurui/PIRender)<br>
2. The second stage is from [HeSer.Pytorch](https://github.com/LeslieZhoa/HeSer.Pytorch)

## Some examples
The pictures are from [小红书 (Xiaohongshu)](https://www.xiaohongshu.com/) and [Baidu](https://www.baidu.com)
![](assets/1-e6de879a-f3c2-47f8-9588-feab95df7b9e.png)
![](assets/2-6.png)
![](assets/3c7b5309-37dc-4cd4-abba-dab4b2482285-7ed01838-3959-48c4-a946-f7bb0232b8cc.png)
![](assets/3c7b5309-37dc-4cd4-abba-dab4b2482285-ce07dc0f-db6c-44a6-80f2-ac9cc5b6d21a.png)
![](assets/5-279fd7f6-1039-44eb-b328-4a5b5c12dc59.png)
![](assets/5ededbf7-0bc9-4af8-aec2-e9af939a0c60-5177e7a3-f5dc-4e11-8a96-79bb54b06ced.png)

## Reference
super resolution -> [CF-Net](https://github.com/ytZhang99/CF-Net)<br>
face parsing -> [face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)<br>
3DMM -> [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch)

## How to Run
1. Environment<br>
Follow PIRender and HeSer.<br>
The LVT used in this project follows [LVT](https://github.com/LeslieZhoa/LVT).
2. Download models (a sketch of the expected pretrained_models layout follows this list)<br>
a. Following [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch#prepare-prerequisite-models), download the BFM files and epoch_20.pth into pretrained_models.<br>
b. Following [PIRender](https://github.com/RenYurui/PIRender/blob/main/scripts/download_weights.sh), put epoch_00190_iteration_000400000_checkpoint.pt into pretrained_models.<br>
c.
```
cd process
bash download_weight.sh
```
3. Run
```
# set your own image path in inference.py
python inference.py
```
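After step 2, the checkpoint directory should look roughly like the sketch below. This layout is an assumption pieced together from the filenames named in the steps above; the exact BFM files and the extra weights fetched by process/download_weight.sh depend on the linked instructions.

```
pretrained_models/
├── BFM files                                       # from Deep3DFaceRecon_pytorch (contents per its docs)
├── epoch_20.pth                                    # Deep3DFaceRecon_pytorch checkpoint
├── epoch_00190_iteration_000400000_checkpoint.pt   # PIRender checkpoint
└── ...                                             # whatever download_weight.sh fetches
```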

## Credits
HeSer.Pytorch model and implementation:
https://github.com/LeslieZhoa/HeSer.Pytorch Copyright © 2022, LeslieZhoa. License: https://github.com/LeslieZhoa/HeSer.Pytorch/blob/main/LICENSE

PIRender model and implementation:
https://github.com/RenYurui/PIRender Copyright © 2021, RenYurui. License: https://github.com/RenYurui/PIRender/blob/main/LICENSE.md

CF-Net model and implementation:
https://github.com/ytZhang99/CF-Net Copyright © 2021, ytZhang99.

Deep3DFaceRecon_pytorch model and implementation:
https://github.com/sicxu/Deep3DFaceRecon_pytorch Copyright © 2021, sicxu. License: https://github.com/sicxu/Deep3DFaceRecon_pytorch/blob/master/LICENSE

arcface-pytorch model and implementation:
https://github.com/ronghuaiyang/arcface-pytorch Copyright © 2018, ronghuaiyang.

LVT model and implementation:
https://github.com/LeslieZhoa/LVT Copyright © 2022, LeslieZhoa.

face-parsing model and implementation:
https://github.com/zllrunning/face-parsing.PyTorch Copyright © 2019, zllrunning. License: https://github.com/zllrunning/face-parsing.PyTorch/blob/master/LICENSE
assets/ — six binary image files added (the example images referenced in ReadMe.md, including assets/2-6.png; 1.03–5.17 MB each)
dataloader/AlignLoader.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
@author LeslieZhao
@date 20221221
'''

import os
from torchvision import transforms
import PIL.Image as Image
from dataloader.DataLoader import DatasetBase
import random
import math
import torch
import numpy as np


class AlignData(DatasetBase):
    def __init__(self, slice_id=0, slice_count=1, dist=False, **kwargs):
        super().__init__(slice_id, slice_count, dist, **kwargs)

        self.transform = transforms.Compose([
            transforms.Resize((kwargs['size'], kwargs['size'])),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        self.resize = transforms.Compose([
            transforms.Resize((256, 256))])

        # source root
        root = kwargs['root']
        self.idinfo = np.load(root, allow_pickle=True).item()
        keys = list(self.idinfo.keys())

        # split the identity keys evenly across data slices
        dis = math.floor(len(keys) / self.count)
        self.keys = keys[self.id * dis:(self.id + 1) * dis]
        self.length = len(self.keys)
        random.shuffle(self.keys)
        self.eval = kwargs['eval']
        self.size = kwargs['size']
        self.params_w0 = self.params_h0 = 256
        self.params_target_size = 224

    def __getitem__(self, i):

        src_img_path, \
        tgt_img_path, \
        src_param_path, \
        tgt_param_path, \
        src_box_path, \
        tgt_box_path = self.get_path(i)

        tube_box_path = os.path.join(os.path.split(src_img_path)[0].replace('crop', 'img'), 'box.npy')
        tube_box = np.load(tube_box_path)
        with Image.open(src_img_path) as img:
            xs = self.transform(img.convert('RGB'))
        xs_params = torch.from_numpy(np.load(src_param_path).astype(np.float32))
        xs_bbox = np.load(src_box_path)
        xs_bbox = torch.from_numpy(
            np.concatenate([self.fix_bbox(xs_bbox, tube_box),
                            self.get_params_box(xs_params.numpy())], -1).astype(np.float32))

        flag = 1
        # with probability 0.5, draw the target from another identity (cross-identity pair)
        if random.random() > 0.5:
            tgt_img_path, tgt_param_path, tgt_box_path = self.get_another_tgt(i)
            tube_box_path = os.path.join(os.path.split(tgt_img_path)[0].replace('crop', 'img'), 'box.npy')
            tube_box = np.load(tube_box_path)
            flag = 0

        with Image.open(tgt_img_path) as img:
            xt = self.transform(img.convert('RGB'))

        xt_params = torch.from_numpy(np.load(tgt_param_path).astype(np.float32))
        xt_bbox = np.load(tgt_box_path)
        xt_bbox = torch.from_numpy(
            np.concatenate([self.fix_bbox(xt_bbox, tube_box),
                            self.get_params_box(xt_params.numpy())], -1).astype(np.float32))

        return self.resize(xs), self.resize(xt), xs, xt, xs_params, xt_params, xs_bbox, xt_bbox, flag

    def get_path(self, i):
        idx = i % self.length
        video_paths = self.idinfo[self.keys[idx]]

        if len(video_paths) == 1:
            vIdx = 0
        else:
            vIdx = random.randint(0, len(video_paths) - 1)
        img_paths = video_paths[vIdx]

        src_idx, tgt_idx = self.select_path(img_paths)

        src_img_path = img_paths[src_idx].replace('id', 'crop').replace('.npy', '.png')
        tgt_img_path = img_paths[tgt_idx].replace('id', 'crop').replace('.npy', '.png')

        src_param_path = img_paths[src_idx].replace('id', '3dmm')
        tgt_param_path = img_paths[tgt_idx].replace('id', '3dmm')

        src_box_path = img_paths[src_idx].replace('id', 'bbox')
        tgt_box_path = img_paths[tgt_idx].replace('id', 'bbox')
        return src_img_path, tgt_img_path, src_param_path, tgt_param_path, src_box_path, tgt_box_path

    def get_another_tgt(self, i):
        # re-draw a target frame from a randomly chosen identity
        idx = (i + random.randint(0, self.length - 1)) % self.length
        video_paths = self.idinfo[self.keys[idx]]

        if len(video_paths) == 1:
            vIdx = 0
        else:
            vIdx = random.randint(0, len(video_paths) - 1)
        img_paths = video_paths[vIdx]

        tgt_idx = random.randint(0, len(img_paths) - 1)

        tgt_img_path = img_paths[tgt_idx].replace('id', 'crop').replace('.npy', '.png')
        tgt_param_path = img_paths[tgt_idx].replace('id', '3dmm')
        tgt_box_path = img_paths[tgt_idx].replace('id', 'bbox')
        return tgt_img_path, tgt_param_path, tgt_box_path

    def fix_bbox(self, bbox, tube_bbox):
        # map the face bbox into the coordinates of the 1.8x-enlarged square
        # crop centered on tube_bbox, rescaled to self.size
        x_min, y_min, x_max, y_max = tube_bbox[:4]

        center_x = (x_min + x_max) / 2.0
        center_y = (y_min + y_max) / 2.0
        bbox_size = int(max(y_max - y_min, x_max - x_min) * 1.8)

        x_min = int(center_x - bbox_size * 0.5)
        y_min = int(center_y - bbox_size * 0.5)
        scale = self.size * 1. / bbox_size

        return np.array([(bbox[0] - x_min) * scale,
                         (bbox[1] - y_min) * scale,
                         (bbox[2] - x_min) * scale,
                         (bbox[3] - y_min) * scale])

    def select_path(self, img_paths):
        # pick source/target frames at least 15 frames apart when possible
        length = len(img_paths)
        if length <= 15:
            src_idx, tgt_idx = 0, -1
        else:
            src_idx = random.randint(0, length - 15 - 1)
            tgt_idx = random.randint(min(src_idx + 15, length - 1), length - 1)
        return src_idx, tgt_idx

    def get_params_box(self, params):
        # recover the 224x224 crop box (in original-image coordinates) implied
        # by the trailing scale/translation entries of the 3DMM parameters
        s, t0, t1 = params.reshape(-1)[-3:]
        s = s + 1e-8
        w = (self.params_w0 * s)
        h = (self.params_h0 * s)

        left = max(0, w / 2 - self.params_target_size / 2 + float((t0 - self.params_w0 / 2) * s))
        right = left + self.params_target_size
        up = max(0, h / 2 - self.params_target_size / 2 + float((self.params_h0 / 2 - t1) * s))
        below = up + self.params_target_size

        return np.array([left / s, up / s, right / s, below / s])

    def __len__(self):
        if self.eval:
            return max(self.length, 1000)
        else:
            # return self.length
            return max(self.length, 100000)
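For orientation, here is a minimal usage sketch of this loader. The `root` path, `size` value, and DataLoader settings are illustrative assumptions, not values taken from this repository's configs:

```py
# Hypothetical usage of AlignData; root, size, and batch settings are assumptions.
from torch.utils.data import DataLoader
from dataloader.AlignLoader import AlignData

dataset = AlignData(
    slice_id=0, slice_count=1, dist=False,
    root='path/to/idinfo.npy',  # dict of identity -> list of per-video frame-path lists
    size=512,                   # side length of the full-resolution xs/xt tensors
    eval=False)

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# 256px source/target, full-size source/target, 3DMM params, bounding boxes,
# and flag (1 = same-identity pair, 0 = target re-drawn from a random identity)
xs_256, xt_256, xs, xt, xs_params, xt_params, xs_bbox, xt_bbox, flag = next(iter(loader))
```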
