diff --git a/.gitignore b/.gitignore
index e0e1cd7..f4d899c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@ __pycache__
.python-version
build
-dist
\ No newline at end of file
+dist
+pytorch_gan_metrics.egg-info
\ No newline at end of file
diff --git a/README.md b/README.md
index cdcbd6a..ace4d14 100644
--- a/README.md
+++ b/README.md
@@ -5,31 +5,33 @@
## Notes
The FID implementation is inspired from [pytorch-fid](https://github.com/mseitzer/pytorch-fid).
+This repository is developed for personal research. If you think this package can also benefit your life, please feel free to open issues.
+
+## Install
+```
+pip install pytorch-gan-metrics
+```
+
## Feature
- Currently, this package supports following metrics:
- [Inception Score](https://github.com/openai/improved-gan) (IS)
- [Fréchet Inception Distance](https://github.com/bioinf-jku/TTUR) (FID)
- The computation processes of IS and FID are integrated to avoid multiple forward propagations.
-- Read image on the fly for both metrics.
+- Support reading image on the fly to avoid out of memory especially for large scale images.
+- Support computation on GPU to speed up some cpu operations such as `np.cov` and `scipy.linalg.sqrtm`.
-## Reproducing Results of Official Implementations
+## Reproducing Results of Official Implementations on CIFAR-10
-- CIFAR-10
-
- | |Train IS |Test IS |Train(50k) vs Test(10k)
FID|
- |-------------------|:--------:|:--------:|:----------------------------:|
- |Official |11.24±0.20|10.98±0.22|3.1508 |
- |pytorch-gan-metrics|11.26±0.27|10.97±0.33|3.1517 |
+| |Train IS |Test IS |Train(50k) vs Test(10k)
FID|
+|-------------------|:--------:|:--------:|:----------------------------:|
+|Official |11.24±0.20|10.98±0.22|3.1508 |
+|pytorch-gan-metrics|11.26±0.27|10.97±0.33|3.1517 |
+|pytorch-gan-metrics
`use_torch=True`|11.26±0.21|10.97±0.34|3.1377 |
- Due to the framework difference between PyTorch and TensorFlow, the results are slightly different from official implementations.
-
-## Install
-```
-pip install pytorch-gan-metrics
-```
+The results are slightly different from official implementations due to the framework difference between PyTorch and TensorFlow.
## Prepare Statistics for FID
-- [Download](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing) precalculated statistics for dataset or
+- [Download](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing) precalculated statistics or
- Calculate statistics for your custom dataset using command line tool
```bash
python -m pytorch_gan_metrics.calc_fid_stats --path path/to/images --output name.npz
@@ -37,6 +39,10 @@ pip install pytorch-gan-metrics
See [calc_fid_stats.py](./pytorch_gan_metrics/calc_fid_stats.py) for implementation details.
### Documentation
+
+#### How to use GPU?
+`pytorch_gan_metrics` default uses `torch.device('cuda:0')` if GPU is available; Otherwise, it uses `cpu` to calculate inception feature.
+
#### Using `torch.Tensor` as images
- Prepare images in type `torch.float32` with shape `[N, 3, H, W]` and normalized to `[0,1]`.
```python
@@ -54,8 +60,9 @@ pip install pytorch-gan-metrics
images, 'path/to/statistics.npz')
```
-#### Using PyTorch DataLoader
-- Use `pytorch_gan_metrics.ImageDataset` to collect images on disk or use custom dataset which should only return an image in `__getitem__`.
+
+#### Using PyTorch DataLoader to Provide Images
+- Use `pytorch_gan_metrics.ImageDataset` to collect images on disk or use custom `torch.utils.data.Dataset` which should only return an image in the end of `__getitem__`.
```python
from pytorch_gan_metrics import ImageDataset
@@ -92,7 +99,7 @@ pip install pytorch-gan-metrics
loader, 'path/to/statistics.npz')
```
-#### From directory
+#### Specify Images by a Directory Path
- Calculate metrics for images in the directory.
```python
from pytorch_gan_metrics import (
@@ -106,6 +113,11 @@ pip install pytorch-gan-metrics
'path/to/images', fid_stats_path)
```
+#### Set PyTorch as backend
+- Set `use_torch=True` when calling functions `get_*` such as `get_inception_score`, `get_fid`, etc.
+- **WARNING** when set `use_torch=True`, the FID might be `nan` due to the unstable implementation of matrix sqrt.
+- This option is recommended to be used when evaluate generative models on a server machine which is equipped with high efficiency GPUs while the cpu frequency is low.
+
## License
This implementation is licensed under the Apache License 2.0.
diff --git a/pytorch_gan_metrics/__init__.py b/pytorch_gan_metrics/__init__.py
index b2956b7..75a61d3 100644
--- a/pytorch_gan_metrics/__init__.py
+++ b/pytorch_gan_metrics/__init__.py
@@ -17,4 +17,4 @@
get_inception_score_and_fid_from_directory
]
-__version__ = '0.1.0'
+__version__ = '0.2.0'
diff --git a/pytorch_gan_metrics/calc_fid_stats.py b/pytorch_gan_metrics/calc_fid_stats.py
index 5545c79..f14cf4b 100644
--- a/pytorch_gan_metrics/calc_fid_stats.py
+++ b/pytorch_gan_metrics/calc_fid_stats.py
@@ -16,8 +16,6 @@
help="output path")
parser.add_argument("--batch_size", type=int, default=50,
help="batch size (default=50)")
- parser.add_argument("--inception_dir", type=str, default='/tmp',
- help='path to inception model dir')
args = parser.parse_args()
dataset = ImageDataset(args.path, exts=['png', 'jpg'])
diff --git a/pytorch_gan_metrics/calc_metrics.py b/pytorch_gan_metrics/calc_metrics.py
index 57cd6a9..0f827d7 100644
--- a/pytorch_gan_metrics/calc_metrics.py
+++ b/pytorch_gan_metrics/calc_metrics.py
@@ -16,5 +16,5 @@
dataset = ImageDataset(args.path, exts=['png', 'jpg'])
loader = DataLoader(dataset, batch_size=50, num_workers=4)
(IS, IS_std), FID = get_inception_score_and_fid(
- loader, args.stats, use_torch=True, verbose=True)
+ loader, args.stats, verbose=True)
print(IS, IS_std, FID)
diff --git a/pytorch_gan_metrics/inception.py b/pytorch_gan_metrics/inception.py
index 6d1b781..fef8256 100644
--- a/pytorch_gan_metrics/inception.py
+++ b/pytorch_gan_metrics/inception.py
@@ -6,8 +6,8 @@
# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
-FID_WEIGHTS_URL = ('https://github.com/mseitzer/pytorch-fid/releases/download/'
- 'fid_weights/pt_inception-2015-12-05-6726825d.pth')
+FID_WEIGHTS_URL = ('https://github.com/w86763777/pytorch-gan-metrics/releases/'
+ 'download/v0.1.0/pt_inception-2015-12-05-6726825d.pth')
class InceptionV3(nn.Module):
diff --git a/setup.py b/setup.py
index cdfc434..176ce08 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,10 @@
import os
-
import setuptools
+import pytorch_gan_metrics
+
+
def read(rel_path):
base_path = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(base_path, rel_path), 'r') as f:
@@ -12,11 +14,10 @@ def read(rel_path):
if __name__ == '__main__':
setuptools.setup(
name='pytorch_gan_metrics',
- version='0.1.0',
+ version=pytorch_gan_metrics.__version__,
author='Yi-Lun Wu',
author_email='w86763777@gmail.com',
- description=(
- 'Package for calculating GAN metrics using Pytorch'),
+ description=('Package for calculating GAN metrics using Pytorch'),
long_description=read('README.md'),
long_description_content_type='text/markdown',
url='https://github.com/w86763777/pytorch-gan-metrics',
@@ -24,8 +25,10 @@ def read(rel_path):
keywords=[
'PyTorch',
'GAN',
- 'Inception Score', 'IS',
- 'Frechet Inception Distance', 'FID'],
+ 'Inception Score',
+ 'IS',
+ 'Frechet Inception Distance',
+ 'FID'],
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',