Skip to content

Commit

Permalink
📝 Global update of the documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Nov 2, 2020
1 parent 6f85612 commit fd8705f
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 191 deletions.
25 changes: 19 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Simple PyTorch Image Quality

Collection of measures and metrics for automatic image quality assessment in various image-to-image tasks such as denoising, super-resolution, image interpolation, etc.
This package is a collection of measures and metrics for image quality assessment in various image processing tasks such as denoising, super-resolution, image interpolation, etc. It relies heavily on [PyTorch](https://github.com/pytorch/pytorch) and takes advantage of its efficiency and automatic differentiation.

This library is directly inspired from [piq](https://github.com/photosynthesis-team/piq). However, it focuses on the simplicity, readability and understandability of its modules, such that anyone can freely and easily reuse and/or adapt them to its needs.
It should noted that `spiq` is directly inspired from the [`piq`](https://github.com/photosynthesis-team/piq) project. However, it focuses on the conciseness, readability and understandability of its (sub-)modules, such that anyone can freely and easily reuse and/or adapt them to its needs.

## Installation

Expand All @@ -14,7 +14,7 @@ cd spiq
python setup.py install
```

You can also copy the library directly to your project.
You can also copy the package directly to your project.

```bash
git clone https://github.com/francois-rozet/spiq
Expand All @@ -26,11 +26,24 @@ cp -R spiq <path/to/project>/spiq

```python
import torch
import spiq
import spiq.psnr as psnr
import spiq.ssim as ssim

x = torch.rand(3, 3, 256, 256)
y = torch.rand(3, 3, 256, 256)

a = spiq.psnr(x, y)
b = spiq.ssim(x, y)
# PSNR function
l = psnr.psnr(x, y)

# SSIM instantiable object
criterion = ssim.SSIM().cuda()
l = criterion(x, y)
```

## Documentation

The [documentation](https://francois-rozet.github.io/spiq/) of this package is generated automatically using [`pdoc`](https://github.com/pdoc3/pdoc).

```bash
pdoc spiq --html --config "git_link_template='https://github.com/francois-rozet/spiq/blob/{commit}/{path}#L{start_line}-L{end_line}'"
```
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

setuptools.setup(
name='spiq',
version='0.0.1',
description='Image quality metrics in PyTorch.',
version='0.0.2',
description='Image quality metrics in PyTorch',
long_description=readme,
long_description_content_type='text/markdown',
keywords='pytorch image processing metrics',
Expand Down
10 changes: 5 additions & 5 deletions spiq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__version__ = '0.0.1'
r"""Simple PyTorch Image Quality
from .psnr import psnr, PSNR
from .ssim import ssim, msssim, SSIM, MSSSIM
from .tv import tv, TV
from .lpips import LPIPS
The spiq package is divided in several submodules, each of which implements the functions and/or classes related to a specific image quality metric.
"""

__version__ = '0.0.2'
117 changes: 30 additions & 87 deletions spiq/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
This module implements the LPIPS in PyTorch.
Credits:
Inspired by lpips-pytorch
https://github.com/S-aiueo32/lpips-pytorch
Inspired by [lpips-pytorch](https://github.com/S-aiueo32/lpips-pytorch)
References:
[1] The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
Expand All @@ -20,89 +19,44 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models


#############
# Constants #
#############

_SHIFT = torch.Tensor([-.030, -.088, -.188])
_SCALE = torch.Tensor([.458, .448, .450])
from spiq.utils import normalize_tensor, Intermediary


#############
# Functions #
# Constants #
#############

def normalize(x: torch.Tensor, dim=(), norm='L2', epsilon: float=1e-8) -> torch.Tensor:
r"""Returns `x` normalized.
Args:
x: input tensor
dim: dimension(s) to normalize
norm: norm function name ('L1' or 'L2')
epsilon: numerical stability
Wikipedia:
https://en.wikipedia.org/wiki/Norm_(mathematics)
"""

if norm == 'L1':
norm = x.abs().sum(dim=dim, keepdim=True)
else: # norm == 'L2'
norm = torch.sqrt((x ** 2).sum(dim=dim, keepdim=True))

return x / (norm + epsilon)
_SHIFT = torch.Tensor([0.485, 0.456, 0.406])
_SCALE = torch.Tensor([0.229, 0.224, 0.225])


###########
# Classes #
###########

class Intermediate(nn.Module):
r"""Module that returns the outputs of target indermediate layers of a sequential module during its forward pass.
class LPIPS(nn.Module):
r"""Creates a criterion that measures the LPIPS between an input and a target.
Args:
layers: sequential module
targets: target layer indexes
"""

def __init__(self, layers: nn.Sequential, targets: list):
super().__init__()
network: perception network name (`'AlexNet'`, `'SqueezeNet'` or `'VGG16'`)
scaling: whether the input and target are sclaed w.r.t. ImageNet
reduction: reduction type (`'mean'`, `'sum'` or `'none'`)
self.layers = layers
self.targets = set(targets)

def forward(self, input: torch.Tensor) -> torch.Tensor:
r"""
Args:
input: input tensor
"""

output = []

for i, layer in enumerate(self.layers):
input = layer(input)

if i in self.targets:
output.append(input.clone())

if len(output) == len(self.targets):
break

return output


class LPIPS(nn.Module):
r"""Creates a criterion that measures the LPIPS between an input and a target.
Call:
The input and target tensors should be of shape (N, C, H, W).
"""

def __init__(self, network='AlexNet', normalize=False, reduction='mean'):
def __init__(self, network: str = 'AlexNet', scaling: bool = False, reduction: str = 'mean'):
super().__init__()

# ImageNet scaling
self.scaling = scaling
self.register_buffer('shift', _SHIFT.view(1, -1, 1, 1))
self.register_buffer('scale', _SCALE.view(1, -1, 1, 1))

# Perception layers
if network == 'AlexNet':
layers = models.alexnet(pretrained=True).features
targets = [1, 4, 7, 9, 11]
Expand All @@ -118,8 +72,9 @@ def __init__(self, network='AlexNet', normalize=False, reduction='mean'):
else:
raise ValueError('Unknown network architecture ' + network)

self.net = Intermediate(layers, targets)
self.net = Intermediary(layers, targets)

# Linear comparators
state_path = os.path.join(
os.path.dirname(inspect.getsourcefile(self.__init__)),
f'weights/lpips_{network}.pth'
Expand All @@ -131,43 +86,31 @@ def __init__(self, network='AlexNet', normalize=False, reduction='mean'):
])
self.lin.load_state_dict(torch.load(state_path))

self.register_buffer('shift', _SHIFT.view(1, -1, 1, 1))
self.register_buffer('scale', _SCALE.view(1, -1, 1, 1))

# Prevent gradients
for x in [self.parameters(), self.buffers()]:
for y in x:
y.requires_grad = False

self.normalize = normalize
self.reduction = reduction

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""
Args:
input: input tensor, (N, C, H, W)
target: target tensor, (N, C, H, W)
"""

if self.normalize:
input = input * 2 - 1
target = target * 2 - 1

input_features = self.net((input - self.shift) / self.scale)
target_features = self.net((target - self.shift) / self.scale)
if self.scaling:
input = (input - self.shift) / self.scale
target = (target - self.shift) / self.scale

residuals = []

for loss, (fx, fy) in zip(self.lin, zip(input_features, target_features)):
fx = normalize(fx, dim=1, norm='L2')
fy = normalize(fy, dim=1, norm='L2')
for loss, (fx, fy) in zip(self.lin, zip(self.net(input), self.net(target))):
fx = normalize_tensor(fx, dim=1, norm='L2')
fy = normalize_tensor(fy, dim=1, norm='L2')

residuals.append(loss((fx - fy) ** 2).mean(dim=(-1, -2)))

l = torch.cat(residuals, dim=-1).sum(dim=-1)

if self.reduction == 'mean':
return l.mean()
l = l.mean()
elif self.reduction == 'sum':
return l.sum()
l = l.sum()

return l
30 changes: 17 additions & 13 deletions spiq/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,26 @@
import torch
import torch.nn as nn

from typing import Tuple


#############
# Functions #
#############

def psnr(x: torch.Tensor, y: torch.Tensor, dim: tuple=(), value_range: float=1., epsilon: float=1e-8) -> torch.Tensor:
def psnr(x: torch.Tensor, y: torch.Tensor, dim: Tuple[int, ...] = (), keepdim: bool = False, value_range: float = 1., epsilon: float = 1e-8) -> torch.Tensor:
r"""Returns the PSNR between `x` and `y`.
Args:
x: input tensor
y: target tensor
dim: dimension(s) to reduce
dim: dimension(s) along which to average
keepdim: whether the output tensor has `dim` retained or not
value_range: value range of the inputs (usually 1. or 255)
epsilon: numerical stability
epsilon: numerical stability term
"""

mse = ((x - y) ** 2).mean(dim=dim) + epsilon
mse = ((x - y) ** 2).mean(dim=dim, keepdim=keepdim) + epsilon
return 10 * torch.log10(value_range ** 2 / mse)


Expand All @@ -39,30 +42,31 @@ def psnr(x: torch.Tensor, y: torch.Tensor, dim: tuple=(), value_range: float=1.,

class PSNR(nn.Module):
r"""Creates a criterion that measures the PSNR between an input and a target.
Args:
value_range: value range of the inputs (usually 1. or 255)
reduction: reduction type (`'mean'`, `'sum'` or `'none'`)
Call:
The input and target tensors should be of shape (N, ...).
"""

def __init__(self, value_range: float=1., reduction='mean'):
def __init__(self, value_range: float = 1., reduction: str = 'mean'):
super().__init__()

self.value_range = value_range
self.reduction = reduction

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""
Args:
input: input tensor, (N, ...)
target: target tensor, (N, ...)
"""

l = psnr(
input, target,
dim=tuple(range(1, input.ndimension())),
value_range=self.value_range
)

if self.reduction == 'mean':
return l.mean()
l = l.mean()
elif self.reduction == 'sum':
return l.sum()
l = l.sum()

return l
Loading

0 comments on commit fd8705f

Please sign in to comment.