Skip to content

Commit

Permalink
📝 Add examples in documentation
Browse files Browse the repository at this point in the history
The examples also act as tests.
  • Loading branch information
francois-rozet committed Dec 10, 2020
1 parent d264a22 commit d277888
Show file tree
Hide file tree
Showing 9 changed files with 289 additions and 39 deletions.
2 changes: 1 addition & 1 deletion piqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
specific image quality assessment metric.
"""

__version__ = '1.0.4'
__version__ = '1.0.5'
21 changes: 18 additions & 3 deletions piqa/gmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import build_reduce, prewitt_kernel, gradient2d, tensor_norm
from piqa.utils import build_reduce, prewitt_kernel, filter2d, tensor_norm

_L_WEIGHTS = torch.FloatTensor([0.2989, 0.587, 0.114])

Expand All @@ -32,6 +32,13 @@ def gmsd(
value_range: The value range of the inputs (usually 1. or 255).
For the remaining arguments, refer to [1].
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = gmsd(x, y)
>>> l.size()
torch.Size([5])
"""

_, _, h, w = x.size()
Expand All @@ -57,8 +64,8 @@ def gmsd(
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1).to(x.device)

gm_x = tensor_norm(gradient2d(x, kernel), dim=1)
gm_y = tensor_norm(gradient2d(y, kernel), dim=1)
gm_x = tensor_norm(filter2d(x, kernel, padding=1), dim=1)
gm_y = tensor_norm(filter2d(y, kernel, padding=1), dim=1)

# Gradient magnitude similarity
gms = (2. * gm_x * gm_y + c) / (gm_x ** 2 + gm_y ** 2 + c)
Expand All @@ -84,6 +91,14 @@ class GMSD(nn.Module):
* Input: (N, 3, H, W)
* Target: (N, 3, H, W)
* Output: (N,) or () depending on `reduction`
Example:
>>> criterion = GMSD()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def __init__(self, reduction: str = 'mean', **kwargs):
Expand Down
18 changes: 15 additions & 3 deletions piqa/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def get_weights(
`'alex'` | `'squeeze'` | `'vgg'`.
version: Specifies the official version release:
`'v0.0'` | `'v0.1'`.
Example:
>>> w = get_weights(network='alex')
>>> w.keys()
dict_keys(['0.1.weight', '1.1.weight', '2.1.weight', '3.1.weight', '4.1.weight'])
"""

# Load from URL
Expand All @@ -53,7 +58,7 @@ def get_weights(
# Format keys
weights = {
k.replace('lin', '').replace('.model', ''): v
for k, v in weights.items()
for (k, v) in weights.items()
}

return weights
Expand Down Expand Up @@ -81,6 +86,14 @@ class LPIPS(nn.Module):
Note:
`LPIPS` is a *trainable* metric. To prevent the weights from updating,
use the `torch.no_grad()` context or freeze the weights.
Example:
>>> criterion = LPIPS()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def __init__(
Expand Down Expand Up @@ -124,8 +137,7 @@ def __init__(
nn.Sequential(
nn.Dropout(inplace=True) if dropout else nn.Identity(),
nn.Conv2d(c, 1, kernel_size=1, stride=1, padding=0, bias=False),
)
for c in channels
) for c in channels
])

if pretrained:
Expand Down
26 changes: 22 additions & 4 deletions piqa/mdsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import build_reduce, prewitt_kernel, gradient2d, tensor_norm
from piqa.utils import build_reduce, prewitt_kernel, filter2d, tensor_norm

_LHM_WEIGHTS = torch.FloatTensor([
[0.2989, 0.587, 0.114],
Expand Down Expand Up @@ -48,6 +48,13 @@ def mdsi(
`'sum'` | `'prod'`.
For the remaining arguments, refer to [1].
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = mdsi(x, y)
>>> l.size()
torch.Size([5])
"""

_, _, h, w = x.size()
Expand All @@ -74,9 +81,12 @@ def mdsi(
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1).to(x.device)

gm_x = tensor_norm(gradient2d(x[:, :1], kernel), dim=1)
gm_y = tensor_norm(gradient2d(y[:, :1], kernel), dim=1)
gm_avg = tensor_norm(gradient2d((x + y)[:, :1] / 2., kernel), dim=1)
gm_x = tensor_norm(filter2d(x[:, :1], kernel, padding=1), dim=1)
gm_y = tensor_norm(filter2d(y[:, :1], kernel, padding=1), dim=1)
gm_avg = tensor_norm(
filter2d((x + y)[:, :1] / 2., kernel, padding=1),
dim=1,
)

gm_x_sq, gm_y_sq, gm_avg_sq = gm_x ** 2, gm_y ** 2, gm_avg ** 2

Expand Down Expand Up @@ -122,6 +132,14 @@ class MDSI(nn.Module):
* Input: (N, 3, H, W)
* Target: (N, 3, H, W)
* Output: (N,) or () depending on `reduction`
Example:
>>> criterion = MDSI()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def __init__(self, reduction: str = 'mean', **kwargs):
Expand Down
18 changes: 16 additions & 2 deletions piqa/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def psnr(
keepdim: Whether the output tensor has `dim` retained or not.
value_range: The value range of the inputs (usually 1. or 255).
epsilon: A numerical stability term.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = psnr(x, y)
>>> l.size()
torch.Size([])
"""

mse = ((x - y) ** 2).mean(dim=dim, keepdim=keepdim) + epsilon
Expand All @@ -52,6 +59,14 @@ class PSNR(nn.Module):
* Input: (N, *), where * means any number of additional dimensions
* Target: (N, *), same shape as the input
* Output: (N,) or () depending on `reduction`
Example:
>>> criterion = PSNR()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def __init__(self, reduction: str = 'mean', **kwargs):
Expand All @@ -60,8 +75,7 @@ def __init__(self, reduction: str = 'mean', **kwargs):

self.reduce = build_reduce(reduction)
self.kwargs = {
k: v for k, v in kwargs.items()
if k not in ['dim', 'keepdim']
k: v for k, v in kwargs.items() if k not in ['dim', 'keepdim']
}

def forward(
Expand Down
83 changes: 66 additions & 17 deletions piqa/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import build_reduce, gaussian_kernel
from piqa.utils import build_reduce, gaussian_kernel, filter2d

from typing import Tuple

Expand All @@ -35,12 +35,21 @@ def create_window(window_size: int, n_channels: int) -> torch.Tensor:
Args:
window_size: The size of the window.
n_channels: A number of channels.
Example:
>>> win = create_window(5, n_channels=3)
>>> win.size()
torch.Size([3, 1, 5, 5])
>>> win[0]
tensor([[[0.0144, 0.0281, 0.0351, 0.0281, 0.0144],
[0.0281, 0.0547, 0.0683, 0.0547, 0.0281],
[0.0351, 0.0683, 0.0853, 0.0683, 0.0351],
[0.0281, 0.0547, 0.0683, 0.0547, 0.0281],
[0.0144, 0.0281, 0.0351, 0.0281, 0.0144]]])
"""

kernel = gaussian_kernel(window_size, 1.5)

window = kernel.unsqueeze(0).unsqueeze(0)
window = window.expand(n_channels, 1, window_size, window_size)
window = kernel.repeat(n_channels, 1, 1, 1)

return window

Expand All @@ -63,28 +72,31 @@ def ssim_per_channel(
value_range: The value range of the inputs (usually 1. or 255).
For the remaining arguments, refer to [1].
"""
n_channels, _, window_size, _ = window.size()
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> window = create_window(7, 3)
>>> ss, cs = ssim_per_channel(x, y, window)
>>> ss.size(), cs.size()
(torch.Size([5, 3]), torch.Size([5, 3]))
"""

c1 = (k1 * value_range) ** 2
c2 = (k2 * value_range) ** 2

# Mean (mu)
mu_x = F.conv2d(x, window, padding=0, groups=n_channels)
mu_y = F.conv2d(y, window, padding=0, groups=n_channels)
mu_x = filter2d(x, window)
mu_y = filter2d(y, window)

mu_x_sq = mu_x ** 2
mu_y_sq = mu_y ** 2
mu_xy = mu_x * mu_y

# Variance (sigma)
sigma_x_sq = F.conv2d(x ** 2, window, padding=0, groups=n_channels)
sigma_x_sq -= mu_x_sq
sigma_y_sq = F.conv2d(y ** 2, window, padding=0, groups=n_channels)
sigma_y_sq -= mu_y_sq
sigma_xy = F.conv2d(x * y, window, padding=0, groups=n_channels)
sigma_xy -= mu_xy
sigma_x_sq = filter2d(x ** 2, window) - mu_x_sq
sigma_y_sq = filter2d(y ** 2, window) - mu_y_sq
sigma_xy = filter2d(x * y, window) - mu_xy

# Contrast sensitivity
cs = (2. * sigma_xy + c2) / (sigma_x_sq + sigma_y_sq + c2)
Expand All @@ -109,6 +121,13 @@ def ssim(
window_size: The size of the window.
`**kwargs` are transmitted to `ssim_per_channel`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = ssim(x, y)
>>> l.size()
torch.Size([5])
"""

n_channels = x.size(1)
Expand All @@ -131,6 +150,14 @@ def msssim_per_channel(
window: A convolution window.
`**kwargs` are transmitted to `ssim_per_channel`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> window = create_window(7, 3)
>>> l = msssim_per_channel(x, y, window)
>>> l.size()
torch.Size([5, 3])
"""

weights = _WEIGHTS.to(x.device)
Expand All @@ -146,9 +173,8 @@ def msssim_per_channel(
ss, cs = ssim_per_channel(x, y, window, **kwargs)
mcs.append(torch.relu(cs))

msss = torch.stack(mcs[:-1] + [ss], dim=0)
msss = msss ** weights.view(-1, 1, 1)
msss = msss.prod(dim=0)
msss = torch.stack(mcs[:-1] + [ss], dim=-1)
msss = (msss ** weights).prod(dim=-1)

return msss

Expand All @@ -167,6 +193,13 @@ def msssim(
window_size: The size of the window.
`**kwargs` are transmitted to `msssim_per_channel`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = msssim(x, y)
>>> l.size()
torch.Size([5])
"""

n_channels = x.size(1)
Expand All @@ -191,6 +224,14 @@ class SSIM(nn.Module):
* Input: (N, C, H, W)
* Target: (N, C, H, W), same shape as the input
* Output: (N,) or () depending on `reduction`
Example:
>>> criterion = SSIM()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def __init__(
Expand Down Expand Up @@ -239,6 +280,14 @@ class MSSSIM(SSIM):
* Input: (N, C, H, W)
* Target: (N, C, H, W), same shape as the input
* Output: (N,) or () depending on `reduction`
Example:
>>> criterion = MSSSIM()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def forward(
Expand Down
13 changes: 13 additions & 0 deletions piqa/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor:
x: An input tensor, (*, C, H, W).
norm: Specifies the norm funcion to apply:
`'L1'` | `'L2'` | `'L2_squared'`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> l = tv(x)
>>> l.size()
torch.Size([5])
"""

w_var = x[..., :, 1:] - x[..., :, :-1]
Expand Down Expand Up @@ -51,6 +57,13 @@ class TV(nn.Module):
Shape:
* Input: (N, C, H, W)
* Output: (N,) or () depending on `reduction`
Example:
>>> criterion = TV()
>>> x = torch.rand(5, 3, 256, 256)
>>> l = criterion(x)
>>> l.size()
torch.Size([])
"""

def __init__(self, reduction: str = 'mean', **kwargs):
Expand Down
Loading

0 comments on commit d277888

Please sign in to comment.