Skip to content

Commit

Permalink
first
Browse files Browse the repository at this point in the history
  • Loading branch information
LeslieZhoa committed Jul 31, 2022
1 parent 0927651 commit b303f41
Show file tree
Hide file tree
Showing 35 changed files with 4,069 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*__pycache__*
pretrain_models
*checkpoint*
116 changes: 116 additions & 0 deletions ReadMe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# DCT-NET.Pytorch
unofficial implementation of DCT-Net: Domain-Calibrated Translation for Portrait Stylization.<br>
you can find official version [here](https://github.com/menyifang/DCT-Net)
![](assets/net.png)

## show
![img](assets/ldh.png)
![video](assets/xcaq.gif)

## environment
you can build your environment by following [this](https://github.com/rosinality/stylegan2-pytorch)<br>
```pip install tensorboardX``` for visualization

## how to run
### train
download pretrain weights
#### CCN
1. prepare the style pictures and align them<br>
the image path is like this<br>
style-photos/<br>
|-- 000000.png<br>
|-- 000006.png<br>
|-- 000010.png<br>
|-- 000011.png<br>
|-- 000015.png<br>
|-- 000028.png<br>
|-- 000039.png<br>
2. change your own path in [ccn_config](./model/styleganModule/config.py#L7)
3. train ccn<br>

```shell
# single gpu
python train.py \
--model ccn \
--batch_size 16 \
--checkpoint_path checkpoint \
--lr 0.002 \
--print_interval 100 \
--save_interval 100 --dist
```

```shell
# multi gpu
python -m torch.distributed.launch train.py \
--model ccn \
--batch_size 16 \
--checkpoint_path checkpoint \
--lr 0.002 \
--print_interval 100 \
--save_interval 100
```
after about 1000 steps, you can stop training
#### TTN
1. prepare expression information<br>
you can follow [LVT](https://github.com/LeslieZhoa/LVT) to estimate facial landmark<br>
```shell
cd utils
# --img_base: your real image path base, like ffhq
# --pool_num: multiprocess number
# --LVT: the LVT path you put
# --train: train data or val data
python get_face_expression.py \
--img_base '' \
--pool_num 2 \
--LVT '' \
--train
```
2. prepare your generated images<br>
```shell
cd utils
# --model_path: ccn model path
# --output_path: save path
python get_tcc_input.py \
--model_path '' \
--output_path ''
```
__manually select about 5k~10k good images__
3. change your own path in [ttn_config](./model/Pix2PixModule/config.py#L21)
```shell
# like
self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img'
self.train_tgt_root = '/StyleTransform/DATA/select-style-gan'
self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img'
self.val_tgt_root = '/StyleTransform/DATA/select-style-gan'
```
4. train ttn
```shell
# like ccn single and multi gpus
python train.py \
--model ttn \
--batch_size 64 \
--checkpoint_path checkpoint \
--lr 2e-4 \
--print_interval 100 \
--save_interval 100 \
--dist
```
## inference
you can follow inference.py to put your own ttn model path and image path<br>
```python inference.py```

## Credits
SEAN model and implementation:<br>
https://github.com/ZPdesu/SEAN Copyright © 2020, ZPdesu.<br>
License https://github.com/ZPdesu/SEAN/blob/master/LICENSE.md

stylegan2-pytorch model and implementation:<br>
https://github.com/rosinality/stylegan2-pytorch Copyright © 2019, rosinality.<br>
License https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE

White-box-Cartoonization model and implementation:<br>
https://github.com/SystemErrorWang/White-box-Cartoonization Copyright © 2020, SystemErrorWang.<br>

White-box-Cartoonization PyTorch model and implementation:<br>
https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch Copyright © 2022, vinesmsuic.<br>
License https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch/blob/main/LICENSE

arcface PyTorch model and implementation:<br>
https://github.com/ronghuaiyang/arcface-pytorch Copyright © 2018, ronghuaiyang.<br>



Binary file added assets/ldh.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/net.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/xcaq.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
45 changes: 45 additions & 0 deletions data/CCNLoader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
@author LeslieZhao
@date 20220721
'''

import os

from torchvision import transforms
import PIL.Image as Image
from data.DataLoader import DatasetBase
import random


class CCNData(DatasetBase):
    """Dataset that serves every file found directly under ``kwargs['root']``.

    Each item is the image loaded with PIL, converted to RGB, randomly
    flipped horizontally, turned into a tensor and normalized with mean 0.5
    and std 0.5 on each of the three channels.

    Args:
        slice_id: forwarded to ``DatasetBase``.
        slice_count: forwarded to ``DatasetBase``.
        dist: forwarded to ``DatasetBase`` as its third positional argument.
        **kwargs: must contain ``root``, the directory holding the images.

    Raises:
        ValueError: if ``root`` contains no entries.
    """

    def __init__(self, slice_id=0, slice_count=1, dist=False, **kwargs):
        super().__init__(slice_id, slice_count, dist, **kwargs)

        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        root = kwargs['root']
        # NOTE: every directory entry is used; nothing is filtered by extension.
        self.paths = [os.path.join(root, f) for f in os.listdir(root)]
        self.length = len(self.paths)
        if self.length == 0:
            # Fail here with a clear message instead of a ZeroDivisionError
            # from the modulo in __getitem__.
            raise ValueError('no files found in root directory: {!r}'.format(root))
        random.shuffle(self.paths)

    def __getitem__(self, i):
        """Return the transformed image at index ``i`` (wrapped modulo the file count)."""
        img_path = self.paths[i % self.length]

        with Image.open(img_path) as img:
            # Normalize is configured for exactly three channels, so RGBA,
            # palette or grayscale files must be converted before transforming.
            tensor = self.transform(img.convert('RGB'))

        return tensor

    def __len__(self):
        """Report at least 100000 items; __getitem__ wraps indices with a modulo."""
        return max(100000, self.length)

31 changes: 31 additions & 0 deletions data/DataLoader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
@author LeslieZhao
@date 20220721
'''


from torch.utils.data import Dataset
import torch.distributed as dist


class DatasetBase(Dataset):
    """Common base for the project's datasets.

    Stores a slice index and slice count on the instance. When ``use_dist``
    is truthy both values are taken from ``torch.distributed`` (rank and
    world size) instead of the arguments.

    Args:
        slice_id: slice index used when ``use_dist`` is falsy.
        slice_count: number of slices used when ``use_dist`` is falsy.
        use_dist: query ``torch.distributed`` for rank / world size.
        **kwargs: accepted and ignored here.
    """

    def __init__(self, slice_id=0, slice_count=1, use_dist=False, **kwargs):
        rank, world = slice_id, slice_count
        if use_dist:
            rank = dist.get_rank()
            world = dist.get_world_size()

        self.id = rank
        self.count = world

    def __getitem__(self, i):
        """Placeholder; returns None. Subclasses provide the real lookup."""
        return None

    def __len__(self):
        """Default length reported by the base class."""
        return 1000

84 changes: 84 additions & 0 deletions data/TTNLoader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
@author LeslieZhao
@date 20220721
'''
import os
from torchvision import transforms
import PIL.Image as Image
from data.DataLoader import DatasetBase
import random
import numpy as np
import torch


class TTNData(DatasetBase):
    """Dataset yielding (source image, target image, expression score) triples.

    Source and target images are the ``.png`` files found directly under
    ``kwargs['src_root']`` and ``kwargs['tgt_root']``. Both lists are shuffled
    independently and indexed modulo their own length, so the two images of
    an item are not matched to each other by file name.

    Args:
        slice_id: forwarded to ``DatasetBase``.
        slice_count: forwarded to ``DatasetBase``.
        dist: forwarded to ``DatasetBase`` as its third positional argument.
        **kwargs: must contain
            ``eval``       - truthy selects the augmentation-free transform
                             and fixes the reported length to 100;
            ``src_root``   - directory with the source ``.png`` files;
            ``tgt_root``   - directory with the target ``.png`` files;
            ``score_info`` - ``.npy`` file that unpacks into six values
                             (max/min for left eye, right eye and lip).

    Raises:
        ValueError: if either directory contains no ``.png`` file.
    """

    def __init__(self, slice_id=0, slice_count=1, dist=False, **kwargs):
        super().__init__(slice_id, slice_count, dist, **kwargs)

        if kwargs['eval']:
            # Evaluation: deterministic resize only, fixed reported length.
            self.transform = transforms.Compose([
                transforms.Resize([256, 256]),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            self.length = 100
        else:
            # Training: resize followed by random crop / rotation / flip.
            self.transform = transforms.Compose([
                transforms.Resize([256, 256]),
                transforms.RandomResizedCrop(256, scale=(0.8, 1.2)),
                transforms.RandomRotation(degrees=(-90, 90)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

        src_root = kwargs['src_root']
        tgt_root = kwargs['tgt_root']

        self.src_paths = [os.path.join(src_root, f)
                          for f in os.listdir(src_root) if f.endswith('.png')]
        self.tgt_paths = [os.path.join(tgt_root, f)
                          for f in os.listdir(tgt_root) if f.endswith('.png')]
        self.src_length = len(self.src_paths)
        self.tgt_length = len(self.tgt_paths)
        if self.src_length == 0:
            # Fail here instead of with a ZeroDivisionError in __getitem__.
            raise ValueError('no .png files found in src_root: {!r}'.format(src_root))
        if self.tgt_length == 0:
            raise ValueError('no .png files found in tgt_root: {!r}'.format(tgt_root))
        random.shuffle(self.src_paths)
        random.shuffle(self.tgt_paths)

        # Global max / min statistics used to rescale the per-image scores.
        (self.mx_left_eye_all,
         self.mn_left_eye_all,
         self.mx_right_eye_all,
         self.mn_right_eye_all,
         self.mx_lip_all,
         self.mn_lip_all) = np.load(kwargs['score_info'])

    def __getitem__(self, i):
        """Return ``(srcImg, tgtImg, score)`` for index ``i``.

        ``score`` is a float32 tensor whose first three entries are rescaled
        with the global min/max statistics loaded in ``__init__``.
        """
        src_path = self.src_paths[i % self.src_length]
        tgt_path = self.tgt_paths[i % self.tgt_length]
        # NOTE: str.replace rewrites *every* 'img' substring of the path, and
        # the last three characters (the 'png' extension) are swapped for 'npy'.
        exp_path = src_path.replace('img', 'express')[:-3] + 'npy'

        # Normalize is configured for exactly three channels, so RGBA,
        # palette or grayscale files must be converted before transforming.
        with Image.open(src_path) as img:
            srcImg = self.transform(img.convert('RGB'))

        with Image.open(tgt_path) as img:
            tgtImg = self.transform(img.convert('RGB'))

        score = np.load(exp_path)
        # NOTE(review): in-place assignment assumes the stored array is a
        # floating dtype - confirm against the script that writes these files.
        score[0] = (score[0] - self.mn_left_eye_all) / (self.mx_left_eye_all - self.mn_left_eye_all)
        score[1] = (score[1] - self.mn_right_eye_all) / (self.mx_right_eye_all - self.mn_right_eye_all)
        score[2] = (score[2] - self.mn_lip_all) / (self.mx_lip_all - self.mn_lip_all)
        score = torch.from_numpy(score.astype(np.float32))

        return srcImg, tgtImg, score

    def __len__(self):
        """Return the fixed eval length when set, otherwise 10000."""
        return getattr(self, 'length', 10000)

58 changes: 58 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
import os
import cv2
import torch
from model.Pix2PixModule.model import Generator
from utils.utils import convert_img

class Infer:
    """Thin inference wrapper around the Pix2Pix ``Generator``.

    Builds the generator, loads the ``'netG'`` weights from a checkpoint and
    exposes ``run`` to push a single image through the network.
    """

    def __init__(self, model_path):
        """Create the generator and load weights from ``model_path``."""
        self.net = Generator(img_channels=3)
        self.load_checkpoint(model_path)

    def run(self, img):
        """Run the network on ``img`` and return the post-processed output.

        ``img`` may be a path string (read with ``cv2.imread``) or an
        already-loaded array.
        """
        if isinstance(img, str):
            img = cv2.imread(img)

        net_input = self.preprocess(img)
        with torch.no_grad():
            net_output = self.net(net_input)

        # The network output is batched; only the first element is kept.
        return self.postprocess(net_output[0])

    def load_checkpoint(self, path):
        """Load ``ckpt['netG']`` (non-strict) and put the network in eval mode."""
        # map_location keeps every storage where it was deserialized (CPU).
        state = torch.load(path, map_location=lambda storage, loc: storage)
        self.net.load_state_dict(state['netG'], strict=False)
        if torch.cuda.is_available():
            self.net.cuda()
        self.net.eval()

    def preprocess(self, img):
        """Turn an H x W x C array into a 1 x C x H x W float32 tensor.

        The last axis is reversed (presumably BGR -> RGB for cv2 input) and
        values are mapped with ``(x / 255 - 0.5) * 2``.
        """
        reversed_channels = img[..., ::-1]
        scaled = (reversed_channels / 255.0 - 0.5) * 2
        channels_first = scaled.transpose(2, 0, 1)
        batched = channels_first[np.newaxis, :].astype(np.float32)

        tensor = torch.from_numpy(batched)
        if torch.cuda.is_available():
            tensor = tensor.cuda()
        return tensor

    def postprocess(self, img):
        """Convert a network output tensor to a channels-last numpy array.

        The tensor goes through ``convert_img(..., unit=True)``, its first
        axis is moved to the end, and the last axis is reversed.
        """
        converted = convert_img(img, unit=True)
        channels_last = converted.permute(1, 2, 0)
        as_numpy = channels_last.cpu().numpy()
        return as_numpy[..., ::-1]



if __name__ == "__main__":
    # Example usage: stylize one image and write the result to output.png.

    path = 'pretrain_models/final.pth'
    model = Infer(path)

    # Set this to the image you want to process.
    img_path = ''
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # stop here with a clear message instead of failing on ``img.shape``.
        raise OSError('cv2.imread could not read image: {!r}'.format(img_path))

    # Round both sides down to a multiple of 8 before feeding the network.
    img_h, img_w, _ = img.shape
    n_h, n_w = img_h // 8 * 8, img_w // 8 * 8
    img = cv2.resize(img, (n_w, n_h))

    # NOTE(review): Infer.postprocess already reverses the channel axis and
    # this reverses it again - confirm the channel order written is intended.
    oup = model.run(img)[..., ::-1]
    cv2.imwrite('output.png', oup)


28 changes: 28 additions & 0 deletions model/Pix2PixModule/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
class Params:
    """Configuration values for the 'Pix2Pix' model.

    All settings are plain instance attributes assigned in ``__init__``;
    edit the paths below to point at your own data.
    """

    def __init__(self):
        # --- identity ---------------------------------------------------
        self.name = 'Pix2Pix'

        # --- weights and optimizer settings -----------------------------
        self.pretrain_path = None
        self.vgg_model = 'pretrain_models/vgg19-dcbb9e9d.pth'
        self.lr = 2e-4
        self.beta1 = 0.5
        self.beta2 = 0.99

        # --- presumably weights of the individual loss terms ------------
        self.use_exp = True
        self.lambda_surface = 2.0
        self.lambda_texture = 2.0
        self.lambda_content = 200
        self.lambda_tv = 1e4
        self.lambda_exp = 1.0

        # --- dataset locations ------------------------------------------
        self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img'
        self.train_tgt_root = '/StyleTransform/DATA/select-style-gan'
        self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img'
        self.val_tgt_root = '/StyleTransform/DATA/select-style-gan'
        self.score_info = 'pretrain_models/all_express_mean.npy'

        # --- inference ---------------------------------------------------
        self.infer_batch_size = 2

Loading

0 comments on commit b303f41

Please sign in to comment.