Commit b303f41 (1 parent: 0927651)
Showing 35 changed files with 4,069 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
*__pycache__* | ||
pretrain_models | ||
*checkpoint* |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# DCT-NET.Pytorch | ||
An unofficial PyTorch implementation of DCT-Net: Domain-Calibrated Translation for Portrait Stylization.<br> | ||
You can find the official version [here](https://github.com/menyifang/DCT-Net) | ||
![](assets/net.png) | ||
|
||
## show | ||
![img](assets/ldh.png) | ||
![video](assets/xcaq.gif) | ||
|
||
## environment | ||
Set up your environment by following [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).<br> | ||
```pip install tensorboardX``` for visualization | ||
|
||
## how to run | ||
### train | ||
Download the pretrained weights first. | ||
#### CCN | ||
1. Prepare the style images and align them.<br> | ||
The image directory should look like this:<br> | ||
style-photos/<br> | ||
|-- 000000.png<br> | ||
|-- 000006.png<br> | ||
|-- 000010.png<br> | ||
|-- 000011.png<br> | ||
|-- 000015.png<br> | ||
|-- 000028.png<br> | ||
|-- 000039.png<br> | ||
2. Set your own paths in [ccn_config](./model/styleganModule/config.py#L7) | ||
3. Train CCN.<br> | ||
|
||
```shell | ||
# single gpu | ||
python train.py \ | ||
--model ccn \ | ||
--batch_size 16 \ | ||
--checkpoint_path checkpoint \ | ||
--lr 0.002 \ | ||
--print_interval 100 \ | ||
--save_interval 100 --dist | ||
``` | ||
|
||
```shell | ||
# multi gpu | ||
python -m torch.distributed.launch train.py \ | ||
--model ccn \ | ||
--batch_size 16 \ | ||
--checkpoint_path checkpoint \ | ||
--lr 0.002 \ | ||
--print_interval 100 \ | ||
--save_interval 100 | ||
``` | ||
After about 1000 steps, you can stop training. | ||
#### TTN | ||
1. Prepare the expression information.<br> | ||
You can use [LVT](https://github.com/LeslieZhoa/LVT) to estimate facial landmarks.<br> | ||
```shell | ||
cd utils | ||
# --img_base: base path of your real images, e.g. FFHQ | ||
# --pool_num: number of worker processes | ||
# --LVT: path where you put LVT | ||
# --train: train data or val data | ||
python get_face_expression.py \ | ||
--img_base '' \ | ||
--pool_num 2 \ | ||
--LVT '' \ | ||
--train | ||
``` | ||
2. Prepare your generated images.<br> | ||
```shell | ||
cd utils | ||
# --model_path: path to the trained CCN model | ||
# --output_path: where to save the generated images | ||
python get_tcc_input.py \ | ||
--model_path '' \ | ||
--output_path '' | ||
``` | ||
__Manually select roughly 5k to 10k good images.__ | ||
3. Set your own paths in [ttn_config](./model/Pix2PixModule/config.py#L21) | ||
```python | ||
# for example | ||
self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img' | ||
self.train_tgt_root = '/StyleTransform/DATA/select-style-gan' | ||
self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img' | ||
self.val_tgt_root = '/StyleTransform/DATA/select-style-gan' | ||
``` | ||
4. Train TTN. | ||
```shell | ||
# same single-GPU and multi-GPU options as CCN | ||
python train.py \ | ||
--model ttn \ | ||
--batch_size 64 \ | ||
--checkpoint_path checkpoint \ | ||
--lr 2e-4 \ | ||
--print_interval 100 \ | ||
--save_interval 100 \ | ||
--dist | ||
``` | ||
## inference | ||
Set your own TTN model path and image path in inference.py, then run:<br> | ||
```python inference.py``` | ||
|
||
## Credits | ||
SEAN model and implementation:<br> | ||
https://github.com/ZPdesu/SEAN Copyright © 2020, ZPdesu.<br> | ||
License https://github.com/ZPdesu/SEAN/blob/master/LICENSE.md | ||
|
||
stylegan2-pytorch model and implementation:<br> | ||
https://github.com/rosinality/stylegan2-pytorch Copyright © 2019, rosinality.<br> | ||
License https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE | ||
|
||
White-box-Cartoonization model and implementation:<br> | ||
https://github.com/SystemErrorWang/White-box-Cartoonization Copyright © 2020, SystemErrorWang.<br> | ||
|
||
White-box-Cartoonization PyTorch model and implementation:<br> | ||
https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch Copyright © 2022, vinesmsuic.<br> | ||
License https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch/blob/main/LICENSE | ||
|
||
arcface-pytorch model and implementation:<br> | ||
https://github.com/ronghuaiyang/arcface-pytorch Copyright © 2018, ronghuaiyang.<br> | ||
|
||
|
||
|
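
Before training CCN, it can help to sanity-check the style image folder described in step 1 of the CCN section. The sketch below is not part of the repository; `list_style_images` is a hypothetical helper, and `style-photos` is only the example folder name used in the README.

```python
import os


def list_style_images(root, ext=".png"):
    """Return sorted paths of the image files directly under `root`."""
    names = sorted(f for f in os.listdir(root) if f.lower().endswith(ext))
    return [os.path.join(root, f) for f in names]


if __name__ == "__main__":
    # "style-photos" is the example folder name from the README.
    root = "style-photos"
    if os.path.isdir(root):
        print(f"found {len(list_style_images(root))} style images")
    else:
        print(f"{root} does not exist yet")
```

Filtering by extension keeps stray files such as notes or thumbnails out of the list that would be handed to the data loader.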
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#! /usr/bin/python | ||
# -*- encoding: utf-8 -*- | ||
''' | ||
@author LeslieZhao | ||
@date 20220721 | ||
''' | ||
|
||
import os | ||
|
||
from torchvision import transforms | ||
import PIL.Image as Image | ||
from data.DataLoader import DatasetBase | ||
import random | ||
|
||
|
||
class CCNData(DatasetBase): | ||
    def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs): | ||
        super().__init__(slice_id, slice_count,dist, **kwargs) | ||
|
||
|
||
        self.transform = transforms.Compose([ | ||
            transforms.RandomHorizontalFlip(), | ||
            transforms.ToTensor(), | ||
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
        ]) | ||
|
||
        root = kwargs['root'] | ||
        self.paths = [os.path.join(root,f) for f in os.listdir(root)] | ||
        self.length = len(self.paths) | ||
        random.shuffle(self.paths) | ||
|
||
    def __getitem__(self,i): | ||
        idx = i % self.length | ||
        img_path = self.paths[idx] | ||
|
||
        with Image.open(img_path) as img: | ||
            Img = self.transform(img) | ||
|
||
        return Img | ||
|
||
|
||
    def __len__(self): | ||
        return max(100000,self.length) | ||
        # return 4 | ||
|
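
The dataset above reports at least 100000 samples and wraps the index with a modulo, so a small style folder can still feed a long training run. A minimal sketch of that pattern without any PyTorch dependency follows; `WrappedList` and `min_len` are illustrative names, not part of the repository.

```python
class WrappedList:
    """Toy stand-in for CCNData: a short list exposed as a long dataset."""

    def __init__(self, items, min_len=100000):
        self.items = list(items)
        self.length = len(self.items)
        self.min_len = min_len

    def __getitem__(self, i):
        # Same wrap-around indexing as CCNData.__getitem__.
        return self.items[i % self.length]

    def __len__(self):
        # Same rule as CCNData.__len__: report at least min_len samples.
        return max(self.min_len, self.length)


data = WrappedList(["a.png", "b.png", "c.png"])
print(len(data))  # 100000
print(data[4])    # b.png
```

Note that an empty folder gives a length of 0, in which case the modulo raises `ZeroDivisionError` on the first access.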
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#! /usr/bin/python | ||
# -*- encoding: utf-8 -*- | ||
''' | ||
@author LeslieZhao | ||
@date 20220721 | ||
''' | ||
|
||
|
||
from torch.utils.data import Dataset | ||
import torch.distributed as dist | ||
|
||
|
||
class DatasetBase(Dataset): | ||
    def __init__(self,slice_id=0,slice_count=1,use_dist=False,**kwargs): | ||
|
||
        if use_dist: | ||
            slice_id = dist.get_rank() | ||
            slice_count = dist.get_world_size() | ||
        self.id = slice_id | ||
        self.count = slice_count | ||
|
||
|
||
    def __getitem__(self,i): | ||
        pass | ||
|
||
|
||
|
||
|
||
    def __len__(self): | ||
        return 1000 | ||
|
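
`DatasetBase` only stores `self.id` and `self.count`; the files shown here do not use them further. One common use for such a pair is a strided split of sample indices across workers. The sketch below illustrates that idea under this assumption; `shard_indices` is a hypothetical helper and not the repository's code.

```python
def shard_indices(total, slice_id=0, slice_count=1):
    """Indices one worker would read under a strided split (assumed scheme)."""
    return list(range(slice_id, total, slice_count))


# Four workers sharing ten samples: each index goes to exactly one worker.
print(shard_indices(10, 0, 4))  # [0, 4, 8]
print(shard_indices(10, 3, 4))  # [3, 7]
```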
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#! /usr/bin/python | ||
# -*- encoding: utf-8 -*- | ||
''' | ||
@author LeslieZhao | ||
@date 20220721 | ||
''' | ||
import os | ||
from torchvision import transforms | ||
import PIL.Image as Image | ||
from data.DataLoader import DatasetBase | ||
import random | ||
import numpy as np | ||
import torch | ||
|
||
|
||
class TTNData(DatasetBase): | ||
    def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs): | ||
        super().__init__(slice_id, slice_count,dist, **kwargs) | ||
|
||
|
||
        self.transform = transforms.Compose([ | ||
            transforms.Resize([256,256]), | ||
            transforms.RandomResizedCrop(256,scale=(0.8,1.2)), | ||
            transforms.RandomRotation(degrees=(-90,90)), | ||
            transforms.RandomHorizontalFlip(), | ||
            transforms.ToTensor(), | ||
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
        ]) | ||
|
||
        if kwargs['eval']: | ||
            self.transform = transforms.Compose([ | ||
                transforms.Resize([256,256]), | ||
                transforms.ToTensor(), | ||
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | ||
            self.length = 100 | ||
|
||
        src_root = kwargs['src_root'] | ||
        tgt_root = kwargs['tgt_root'] | ||
|
||
        self.src_paths = [os.path.join(src_root,f) for f in os.listdir(src_root) if f.endswith('.png')] | ||
        self.tgt_paths = [os.path.join(tgt_root,f) for f in os.listdir(tgt_root) if f.endswith('.png')] | ||
        self.src_length = len(self.src_paths) | ||
        self.tgt_length = len(self.tgt_paths) | ||
        random.shuffle(self.src_paths) | ||
        random.shuffle(self.tgt_paths) | ||
|
||
        self.mx_left_eye_all,\ | ||
        self.mn_left_eye_all,\ | ||
        self.mx_right_eye_all,\ | ||
        self.mn_right_eye_all,\ | ||
        self.mx_lip_all,\ | ||
        self.mn_lip_all = \ | ||
            np.load(kwargs['score_info']) | ||
|
||
    def __getitem__(self,i): | ||
        src_idx = i % self.src_length | ||
        tgt_idx = i % self.tgt_length | ||
|
||
        src_path = self.src_paths[src_idx] | ||
        tgt_path = self.tgt_paths[tgt_idx] | ||
        exp_path = src_path.replace('img','express')[:-3] + 'npy' | ||
|
||
        with Image.open(src_path) as img: | ||
            srcImg = self.transform(img) | ||
|
||
        with Image.open(tgt_path) as img: | ||
            tgtImg = self.transform(img) | ||
|
||
        score = np.load(exp_path) | ||
        score[0] = (score[0] - self.mn_left_eye_all) / (self.mx_left_eye_all - self.mn_left_eye_all) | ||
        score[1] = (score[1] - self.mn_right_eye_all) / (self.mx_right_eye_all - self.mn_right_eye_all) | ||
        score[2] = (score[2] - self.mn_lip_all) / (self.mx_lip_all - self.mn_lip_all) | ||
        score = torch.from_numpy(score.astype(np.float32)) | ||
|
||
        return srcImg,tgtImg,score | ||
|
||
|
||
    def __len__(self): | ||
        # return max(self.src_length,self.tgt_length) | ||
        if hasattr(self,'length'): | ||
            return self.length | ||
        else: | ||
            return 10000 | ||
|
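
Two small pieces of `TTNData.__getitem__` are easy to check in isolation: how the expression file path is derived from the source image path, and how each score is min-max scaled. The sketch below copies both expressions into standalone functions; the function names and the sample numbers are illustrative only.

```python
def expression_path(src_path):
    # Mirrors TTNData: swap 'img' for 'express', then the 3-char extension for 'npy'.
    return src_path.replace('img', 'express')[:-3] + 'npy'


def min_max(value, mn, mx):
    # Mirrors the score scaling in TTNData.__getitem__.
    return (value - mn) / (mx - mn)


print(expression_path('/StyleTransform/DATA/ffhq-2w/img/000001.png'))
# /StyleTransform/DATA/ffhq-2w/express/000001.npy

# str.replace rewrites every 'img' in the path, not only the folder name.
print(expression_path('/data/img_sets/img/a.png'))
# /data/express_sets/express/a.npy

print(min_max(3.0, 1.0, 5.0))  # 0.5
```

Because every occurrence of `img` is replaced, source roots whose parent folders or file names also contain `img` will map to unexpected expression paths.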
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import numpy as np | ||
import os | ||
import cv2 | ||
import torch | ||
from model.Pix2PixModule.model import Generator | ||
from utils.utils import convert_img | ||
|
||
class Infer: | ||
    def __init__(self,model_path): | ||
        self.net = Generator(img_channels=3) | ||
        self.load_checkpoint(model_path) | ||
|
||
|
||
    def run(self,img): | ||
        if isinstance(img,str): | ||
            img = cv2.imread(img) | ||
        inp = self.preprocess(img) | ||
        with torch.no_grad(): | ||
            xg = self.net(inp) | ||
        oup = self.postprocess(xg[0]) | ||
        return oup | ||
|
||
    def load_checkpoint(self,path): | ||
        ckpt = torch.load(path, map_location=lambda storage, loc: storage) | ||
        self.net.load_state_dict(ckpt['netG'],strict=False) | ||
        if torch.cuda.is_available(): | ||
            self.net.cuda() | ||
        self.net.eval() | ||
|
||
    def preprocess(self,img): | ||
|
||
        img = (img[...,::-1] / 255.0 - 0.5) * 2 | ||
        img = img.transpose(2,0,1)[np.newaxis,:].astype(np.float32) | ||
        img = torch.from_numpy(img) | ||
        if torch.cuda.is_available(): | ||
            img = img.cuda() | ||
        return img | ||
    def postprocess(self,img): | ||
        img = convert_img(img,unit=True) | ||
        return img.permute(1,2,0).cpu().numpy()[...,::-1] | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
|
||
    path = 'pretrain_models/final.pth' | ||
    model = Infer(path) | ||
|
||
    img = cv2.imread('') | ||
|
||
    img_h,img_w,_ = img.shape | ||
    n_h,n_w = img_h // 8 * 8,img_w // 8 * 8 | ||
    img = cv2.resize(img,(n_w,n_h)) | ||
|
||
    oup = model.run(img)[...,::-1] | ||
    cv2.imwrite('output.png',oup) | ||
|
||
|
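
The inference script relies on two small pieces of arithmetic: the input size is rounded down to a multiple of 8 before resizing, and pixel values are mapped from [0, 255] to [-1, 1] in `preprocess`. The sketch below restates both with plain numbers; the function names are illustrative and not part of the repository.

```python
def floor_to_multiple(x, base=8):
    # Same arithmetic as `img_h // 8 * 8` in the script.
    return x // base * base


def to_model_range(pixel):
    # Same mapping as Infer.preprocess: [0, 255] -> [-1, 1].
    return (pixel / 255.0 - 0.5) * 2


print(floor_to_multiple(517))  # 512
print(to_model_range(0))       # -1.0
print(to_model_range(255))     # 1.0
```

This matches the `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` step applied after `ToTensor()` in the training datasets, so inference inputs land in the same value range as training inputs.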
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
class Params: | ||
    def __init__(self): | ||
|
||
        self.name = 'Pix2Pix' | ||
|
||
        self.pretrain_path = None | ||
        self.vgg_model = 'pretrain_models/vgg19-dcbb9e9d.pth' | ||
        self.lr = 2e-4 | ||
        self.beta1 = 0.5 | ||
        self.beta2 = 0.99 | ||
|
||
        self.use_exp = True | ||
        self.lambda_surface = 2.0 | ||
        self.lambda_texture = 2.0 | ||
        self.lambda_content = 200 | ||
        self.lambda_tv = 1e4 | ||
|
||
        self.lambda_exp = 1.0 | ||
|
||
|
||
        self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img' | ||
        self.train_tgt_root = '/StyleTransform/DATA/select-style-gan' | ||
        self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img' | ||
        self.val_tgt_root = '/StyleTransform/DATA/select-style-gan' | ||
        self.score_info = 'pretrain_models/all_express_mean.npy' | ||
|
||
        self.infer_batch_size = 2 | ||
|
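
The `lambda_*` fields in `Params` look like per-term loss weights, although the training code that consumes them is not shown here. Assuming they scale individual loss terms that are then summed, the combination could be sketched as follows; `total_loss` and the loss values are hypothetical placeholders, not the repository's implementation.

```python
def total_loss(terms, weights):
    """Weighted sum of named loss terms; a missing weight defaults to 1.0."""
    return sum(weights.get(name, 1.0) * value for name, value in terms.items())


# Weights copied from the Params class above.
weights = {"surface": 2.0, "texture": 2.0, "content": 200, "tv": 1e4, "exp": 1.0}

# Placeholder loss values, for illustration only.
terms = {"surface": 0.5, "texture": 0.25, "content": 0.01, "tv": 0.0001, "exp": 0.5}
print(total_loss(terms, weights))
```

Under this reading, the large `lambda_tv` and `lambda_content` values would compensate for loss terms that are numerically much smaller than the others.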