Skip to content

Commit f8ed75b

Browse files
initial release
0 parents  commit f8ed75b

33 files changed

+1673
-0
lines changed

README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# GANimation: Anatomically-aware Facial Animation from a Single Image
2+
Official implementation of [GANimation](http://www.albertpumarola.com/research/GANimation/index.html). In this work we introduce a novel GAN conditioning scheme based on Action Units (AU) annotations, which describe in a continuous manifold the anatomical facial movements defining a human expression. Our approach permits controlling the magnitude of activation of each AU and combine several of them. For more information please refer to the [paper](http://www.albertpumarola.com/publications/files/pumarola2018ganimation.pdf).
3+
4+
![GANimation](http://www.albertpumarola.com/images/2018/GANimation/teaser.png)
5+
6+
## Prerequisites
7+
- Install PyTorch, Torch Vision and dependencies from http://pytorch.org
8+
- Install requirements.txt (```pip install -r requirements.txt```)
9+
10+
## Data Preparation
11+
The code requires a directory containing the following files:
12+
- `imgs/`: folder with all image
13+
- `aus_openpose.pkl`: dictionary containing the images action units.
14+
- `train_ids.csv`: file containing the images names to be used to train.
15+
- `test_ids.csv`: file containing the images names to be used to test.
16+
17+
An example of this directory is shown in `sample_dataset/`.
18+
19+
To generate the `aus_openface.pkl` extract each image Action Units with [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace/wiki/Action-Units) and store each output in a csv file the same name as the image. Then run:
20+
```
21+
python data/prepare_au_annotations.py
22+
```
23+
24+
## Run
25+
To train:
26+
```
27+
bash launch/run_train.sh
28+
```
29+
To test:
30+
```
31+
python test --input_path path/to/img
32+
```
33+
34+
## Citation
35+
If you use this code or ideas from the paper for your research, please cite our paper:
36+
```
37+
@inproceedings{pumarola2018ganimation,
38+
title={{GANimation: Anatomically-aware Facial Animation from a Single Image}},
39+
author={A. Pumarola and A. Agudo and A.M. Martinez and A. Sanfeliu and F. Moreno-Noguer},
40+
booktitle={ECCV},
41+
year={2018}
42+
}
43+
```

data/__init__.py

Whitespace-only changes.

data/custom_dataset_data_loader.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch.utils.data
2+
from data.dataset import DatasetFactory
3+
4+
5+
class CustomDatasetDataLoader:
6+
def __init__(self, opt, is_for_train=True):
7+
self._opt = opt
8+
self._is_for_train = is_for_train
9+
self._num_threds = opt.n_threads_train if is_for_train else opt.n_threads_test
10+
self._create_dataset()
11+
12+
def _create_dataset(self):
13+
self._dataset = DatasetFactory.get_by_name(self._opt.dataset_mode, self._opt, self._is_for_train)
14+
self._dataloader = torch.utils.data.DataLoader(
15+
self._dataset,
16+
batch_size=self._opt.batch_size,
17+
shuffle=not self._opt.serial_batches,
18+
num_workers=int(self._num_threds),
19+
drop_last=True)
20+
21+
def load_data(self):
22+
return self._dataloader
23+
24+
def __len__(self):
25+
return len(self._dataset)

data/dataset.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import torch.utils.data as data
2+
from PIL import Image
3+
import torchvision.transforms as transforms
4+
import os
5+
import os.path
6+
7+
8+
class DatasetFactory:
9+
def __init__(self):
10+
pass
11+
12+
@staticmethod
13+
def get_by_name(dataset_name, opt, is_for_train):
14+
if dataset_name == 'aus':
15+
from data.dataset_aus import AusDataset
16+
dataset = AusDataset(opt, is_for_train)
17+
else:
18+
raise ValueError("Dataset [%s] not recognized." % dataset_name)
19+
20+
print('Dataset {} was created'.format(dataset.name))
21+
return dataset
22+
23+
24+
class DatasetBase(data.Dataset):
25+
def __init__(self, opt, is_for_train):
26+
super(DatasetBase, self).__init__()
27+
self._name = 'BaseDataset'
28+
self._root = None
29+
self._opt = opt
30+
self._is_for_train = is_for_train
31+
self._create_transform()
32+
33+
self._IMG_EXTENSIONS = [
34+
'.jpg', '.JPG', '.jpeg', '.JPEG',
35+
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
36+
]
37+
38+
@property
39+
def name(self):
40+
return self._name
41+
42+
@property
43+
def path(self):
44+
return self._root
45+
46+
def _create_transform(self):
47+
self._transform = transforms.Compose([])
48+
49+
def get_transform(self):
50+
return self._transform
51+
52+
def _is_image_file(self, filename):
53+
return any(filename.endswith(extension) for extension in self._IMG_EXTENSIONS)
54+
55+
def _is_csv_file(self, filename):
56+
return filename.endswith('.csv')
57+
58+
def _get_all_files_in_subfolders(self, dir, is_file):
59+
images = []
60+
assert os.path.isdir(dir), '%s is not a valid directory' % dir
61+
62+
for root, _, fnames in sorted(os.walk(dir)):
63+
for fname in fnames:
64+
if is_file(fname):
65+
path = os.path.join(root, fname)
66+
images.append(path)
67+
68+
return images

data/dataset_aus.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import os.path
2+
import torchvision.transforms as transforms
3+
from data.dataset import DatasetBase
4+
from PIL import Image
5+
import random
6+
import numpy as np
7+
import pickle
8+
from utils import cv_utils
9+
10+
11+
class AusDataset(DatasetBase):
12+
def __init__(self, opt, is_for_train):
13+
super(AusDataset, self).__init__(opt, is_for_train)
14+
self._name = 'AusDataset'
15+
16+
# read dataset
17+
self._read_dataset_paths()
18+
19+
def __getitem__(self, index):
20+
assert (index < self._dataset_size)
21+
22+
# start_time = time.time()
23+
real_img = None
24+
real_cond = None
25+
while real_img is None or real_cond is None:
26+
# if sample randomly: overwrite index
27+
if not self._opt.serial_batches:
28+
index = random.randint(0, self._dataset_size - 1)
29+
30+
# get sample data
31+
sample_id = self._ids[index]
32+
33+
real_img, real_img_path = self._get_img_by_id(sample_id)
34+
real_cond = self._get_cond_by_id(sample_id)
35+
36+
if real_img is None:
37+
print 'error reading image %s, skipping sample' % sample_id
38+
if real_cond is None:
39+
print 'error reading aus %s, skipping sample' % sample_id
40+
41+
desired_cond = self._generate_random_cond()
42+
43+
# transform data
44+
img = self._transform(Image.fromarray(real_img))
45+
46+
# pack data
47+
sample = {'real_img': img,
48+
'real_cond': real_cond,
49+
'desired_cond': desired_cond,
50+
'sample_id': sample_id,
51+
'real_img_path': real_img_path
52+
}
53+
54+
# print (time.time() - start_time)
55+
56+
return sample
57+
58+
def __len__(self):
59+
return self._dataset_size
60+
61+
def _read_dataset_paths(self):
62+
self._root = self._opt.data_dir
63+
self._imgs_dir = os.path.join(self._root, self._opt.images_folder)
64+
65+
# read ids
66+
use_ids_filename = self._opt.train_ids_file if self._is_for_train else self._opt.test_ids_file
67+
use_ids_filepath = os.path.join(self._root, use_ids_filename)
68+
self._ids = self._read_ids(use_ids_filepath)
69+
70+
# read aus
71+
conds_filepath = os.path.join(self._root, self._opt.aus_file)
72+
self._conds = self._read_conds(conds_filepath)
73+
74+
self._ids = list(set(self._ids).intersection(set(self._conds.keys())))
75+
76+
# dataset size
77+
self._dataset_size = len(self._ids)
78+
79+
def _create_transform(self):
80+
if self._is_for_train:
81+
transform_list = [transforms.RandomHorizontalFlip(),
82+
transforms.ToTensor(),
83+
transforms.Normalize(mean=[0.5, 0.5, 0.5],
84+
std=[0.5, 0.5, 0.5]),
85+
]
86+
else:
87+
transform_list = [transforms.ToTensor(),
88+
transforms.Normalize(mean=[0.5, 0.5, 0.5],
89+
std=[0.5, 0.5, 0.5]),
90+
]
91+
self._transform = transforms.Compose(transform_list)
92+
93+
def _read_ids(self, file_path):
94+
ids = np.loadtxt(file_path, delimiter='\t', dtype=np.str)
95+
return [id[:-4] for id in ids]
96+
97+
def _read_conds(self, file_path):
98+
with open(file_path, 'rb') as f:
99+
return pickle.load(f)
100+
101+
def _get_cond_by_id(self, id):
102+
if id in self._conds:
103+
return self._conds[id]/5.0
104+
else:
105+
return None
106+
107+
def _get_img_by_id(self, id):
108+
filepath = os.path.join(self._root, self._imgs_dir, id+'.jpg')
109+
return cv_utils.read_cv2_img(filepath), filepath
110+
111+
def _generate_random_cond(self):
112+
cond = None
113+
while cond is None:
114+
rand_sample_id = self._ids[random.randint(0, self._dataset_size - 1)]
115+
cond = self._get_cond_by_id(rand_sample_id)
116+
cond += np.random.uniform(-0.1, 0.1, cond.shape)
117+
return cond

data/prepare_au_annotations.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
import os
3+
from tqdm import tqdm
4+
import argparse
5+
import glob
6+
import re
7+
import pickle
8+
9+
parser = argparse.ArgumentParser()
10+
parser.add_argument('-ia', '--input_aus_filesdir', type=str, help='Dir with imgs aus files')
11+
parser.add_argument('-op', '--output_path', type=str, help='Output path')
12+
args = parser.parse_args()
13+
14+
def get_data(filepaths):
15+
data = dict()
16+
for filepath in tqdm(filepaths):
17+
content = np.loadtxt(filepath, delimiter=', ', skiprows=1)
18+
data[os.path.basename(filepath[:-4])] = content[2:19]
19+
20+
return data
21+
22+
def save_dict(data, name):
23+
with open(name + '.pkl', 'wb') as f:
24+
pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
25+
26+
def main():
27+
filepaths = glob.glob(os.path.join(args.input_aus_filesdir, '*.csv'))
28+
filepaths.sort()
29+
30+
# create aus file
31+
data = get_data(filepaths)
32+
33+
if not os.path.isdir(args.output_path):
34+
os.makedirs(args.output_path)
35+
save_dict(data, os.path.join(args.output_path, "aus"))
36+
37+
38+
if __name__ == '__main__':
39+
main()

launch/run_train.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/usr/bin/env bash
2+
3+
python train.py \
4+
--data_dir path/to/dataset/ \
5+
--name experiment_1 \
6+
--batch_size 25 \

models/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)