diff --git a/.gitignore b/.gitignore index e0e1cd7..f4d899c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ __pycache__ .python-version build -dist \ No newline at end of file +dist +pytorch_gan_metrics.egg-info \ No newline at end of file diff --git a/README.md b/README.md index cdcbd6a..ace4d14 100644 --- a/README.md +++ b/README.md @@ -5,31 +5,33 @@ ## Notes The FID implementation is inspired from [pytorch-fid](https://github.com/mseitzer/pytorch-fid). +This repository is developed for personal research. If you think this package can also benefit your life, please feel free to open issues. + +## Install +``` +pip install pytorch-gan-metrics +``` + ## Feature - Currently, this package supports following metrics: - [Inception Score](https://github.com/openai/improved-gan) (IS) - [Fréchet Inception Distance](https://github.com/bioinf-jku/TTUR) (FID) - The computation processes of IS and FID are integrated to avoid multiple forward propagations. -- Read image on the fly for both metrics. +- Support reading image on the fly to avoid out of memory especially for large scale images. +- Support computation on GPU to speed up some cpu operations such as `np.cov` and `scipy.linalg.sqrtm`. -## Reproducing Results of Official Implementations +## Reproducing Results of Official Implementations on CIFAR-10 -- CIFAR-10 - - | |Train IS |Test IS |Train(50k) vs Test(10k)
FID| - |-------------------|:--------:|:--------:|:----------------------------:| - |Official |11.24±0.20|10.98±0.22|3.1508 | - |pytorch-gan-metrics|11.26±0.27|10.97±0.33|3.1517 | +| |Train IS |Test IS |Train(50k) vs Test(10k)
FID| +|-------------------|:--------:|:--------:|:----------------------------:| +|Official |11.24±0.20|10.98±0.22|3.1508 | +|pytorch-gan-metrics|11.26±0.27|10.97±0.33|3.1517 | +|pytorch-gan-metrics
`use_torch=True`|11.26±0.21|10.97±0.34|3.1377 | - Due to the framework difference between PyTorch and TensorFlow, the results are slightly different from official implementations. - -## Install -``` -pip install pytorch-gan-metrics -``` +The results are slightly different from official implementations due to the framework difference between PyTorch and TensorFlow. ## Prepare Statistics for FID -- [Download](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing) precalculated statistics for dataset or +- [Download](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing) precalculated statistics or - Calculate statistics for your custom dataset using command line tool ```bash python -m pytorch_gan_metrics.calc_fid_stats --path path/to/images --output name.npz @@ -37,6 +39,10 @@ pip install pytorch-gan-metrics See [calc_fid_stats.py](./pytorch_gan_metrics/calc_fid_stats.py) for implementation details. ### Documentation + +#### How to use GPU? +`pytorch_gan_metrics` default uses `torch.device('cuda:0')` if GPU is available; Otherwise, it uses `cpu` to calculate inception feature. + #### Using `torch.Tensor` as images - Prepare images in type `torch.float32` with shape `[N, 3, H, W]` and normalized to `[0,1]`. ```python @@ -54,8 +60,9 @@ pip install pytorch-gan-metrics images, 'path/to/statistics.npz') ``` -#### Using PyTorch DataLoader -- Use `pytorch_gan_metrics.ImageDataset` to collect images on disk or use custom dataset which should only return an image in `__getitem__`. + +#### Using PyTorch DataLoader to Provide Images +- Use `pytorch_gan_metrics.ImageDataset` to collect images on disk or use custom `torch.utils.data.Dataset` which should only return an image in the end of `__getitem__`. ```python from pytorch_gan_metrics import ImageDataset @@ -92,7 +99,7 @@ pip install pytorch-gan-metrics loader, 'path/to/statistics.npz') ``` -#### From directory +#### Specify Images by a Directory Path - Calculate metrics for images in the directory. ```python from pytorch_gan_metrics import ( @@ -106,6 +113,11 @@ pip install pytorch-gan-metrics 'path/to/images', fid_stats_path) ``` +#### Set PyTorch as backend +- Set `use_torch=True` when calling functions `get_*` such as `get_inception_score`, `get_fid`, etc. +- **WARNING** when set `use_torch=True`, the FID might be `nan` due to the unstable implementation of matrix sqrt. +- This option is recommended to be used when evaluate generative models on a server machine which is equipped with high efficiency GPUs while the cpu frequency is low. + ## License This implementation is licensed under the Apache License 2.0. diff --git a/pytorch_gan_metrics/__init__.py b/pytorch_gan_metrics/__init__.py index b2956b7..75a61d3 100644 --- a/pytorch_gan_metrics/__init__.py +++ b/pytorch_gan_metrics/__init__.py @@ -17,4 +17,4 @@ get_inception_score_and_fid_from_directory ] -__version__ = '0.1.0' +__version__ = '0.2.0' diff --git a/pytorch_gan_metrics/calc_fid_stats.py b/pytorch_gan_metrics/calc_fid_stats.py index 5545c79..f14cf4b 100644 --- a/pytorch_gan_metrics/calc_fid_stats.py +++ b/pytorch_gan_metrics/calc_fid_stats.py @@ -16,8 +16,6 @@ help="output path") parser.add_argument("--batch_size", type=int, default=50, help="batch size (default=50)") - parser.add_argument("--inception_dir", type=str, default='/tmp', - help='path to inception model dir') args = parser.parse_args() dataset = ImageDataset(args.path, exts=['png', 'jpg']) diff --git a/pytorch_gan_metrics/calc_metrics.py b/pytorch_gan_metrics/calc_metrics.py index 57cd6a9..0f827d7 100644 --- a/pytorch_gan_metrics/calc_metrics.py +++ b/pytorch_gan_metrics/calc_metrics.py @@ -16,5 +16,5 @@ dataset = ImageDataset(args.path, exts=['png', 'jpg']) loader = DataLoader(dataset, batch_size=50, num_workers=4) (IS, IS_std), FID = get_inception_score_and_fid( - loader, args.stats, use_torch=True, verbose=True) + loader, args.stats, verbose=True) print(IS, IS_std, FID) diff --git a/pytorch_gan_metrics/inception.py b/pytorch_gan_metrics/inception.py index 6d1b781..fef8256 100644 --- a/pytorch_gan_metrics/inception.py +++ b/pytorch_gan_metrics/inception.py @@ -6,8 +6,8 @@ # Inception weights ported to Pytorch from # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz -FID_WEIGHTS_URL = ('https://github.com/mseitzer/pytorch-fid/releases/download/' - 'fid_weights/pt_inception-2015-12-05-6726825d.pth') +FID_WEIGHTS_URL = ('https://github.com/w86763777/pytorch-gan-metrics/releases/' + 'download/v0.1.0/pt_inception-2015-12-05-6726825d.pth') class InceptionV3(nn.Module): diff --git a/setup.py b/setup.py index cdfc434..176ce08 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,10 @@ import os - import setuptools +import pytorch_gan_metrics + + def read(rel_path): base_path = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(base_path, rel_path), 'r') as f: @@ -12,11 +14,10 @@ def read(rel_path): if __name__ == '__main__': setuptools.setup( name='pytorch_gan_metrics', - version='0.1.0', + version=pytorch_gan_metrics.__version__, author='Yi-Lun Wu', author_email='w86763777@gmail.com', - description=( - 'Package for calculating GAN metrics using Pytorch'), + description=('Package for calculating GAN metrics using Pytorch'), long_description=read('README.md'), long_description_content_type='text/markdown', url='https://github.com/w86763777/pytorch-gan-metrics', @@ -24,8 +25,10 @@ def read(rel_path): keywords=[ 'PyTorch', 'GAN', - 'Inception Score', 'IS', - 'Frechet Inception Distance', 'FID'], + 'Inception Score', + 'IS', + 'Frechet Inception Distance', + 'FID'], classifiers=[ 'Programming Language :: Python :: 3', 'License :: OSI Approved :: Apache Software License',