Skip to content

Commit

Permalink
♻️ Remove duplicate definitions (functional and object-oriented)
Browse files Browse the repository at this point in the history
⚡️ Simplify (MS-)SSIM convolutions

📝 Add math in documentation
  • Loading branch information
francois-rozet committed Jan 15, 2021
1 parent c6d924d commit 785b449
Show file tree
Hide file tree
Showing 11 changed files with 537 additions and 466 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ l.backward()
| TV | [piqa.tv] | 1937 | [Total Variation](https://en.wikipedia.org/wiki/Total_variation) |
| PSNR | [piqa.psnr] | / | [Peak Signal-to-Noise Ratio](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio) |
| SSIM | [piqa.ssim] | 2004 | [Structural Similarity](https://en.wikipedia.org/wiki/Structural_similarity) |
| MS-SSIM | [piqa.ssim] | 2003 | [Multi-Scale Structural Similarity](https://ieeexplore.ieee.org/abstract/document/1292216/) |
| MS-SSIM | [piqa.ssim] | 2004 | [Multi-Scale Structural Similarity](https://ieeexplore.ieee.org/abstract/document/1292216/) |
| LPIPS | [piqa.lpips] | 2018 | [Learned Perceptual Image Patch Similarity](https://arxiv.org/abs/1801.03924) |
| GMSD | [piqa.gmsd] | 2013 | [Gradient Magnitude Similarity Deviation](https://arxiv.org/abs/1308.3052) |
| MS-GMSD | [piqa.gmsd] | 2017 | [Multi-Scale Gradient Magnitude Similiarity Deviation](https://ieeexplore.ieee.org/document/7952357) |
Expand Down
6 changes: 3 additions & 3 deletions piqa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
r"""PyTorch Image Quality Assessement
r"""PyTorch Image Quality Assessement (PIQA)
The piqa package is divided in several submodules, each of
The `piqa` package is divided in several submodules, each of
which implements the functions and/or classes related to a
specific image quality assessement metric.
"""

__version__ = '1.0.13'
__version__ = '1.1.0'
212 changes: 105 additions & 107 deletions piqa/gmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
https://ieeexplore.ieee.org/document/7952357
"""

__pdoc__ = {'_gmsd': True, '_msgmsd': True}
__pdoc__ = {'_gmsd': True, '_ms_gmsd': True}

import torch
import torch.nn as nn
Expand All @@ -30,7 +30,7 @@
tensor_norm,
)

_L_WEIGHTS = torch.FloatTensor([0.299, 0.587, 0.114])
_Y_WEIGHTS = torch.FloatTensor([0.299, 0.587, 0.114])
_MS_WEIGHTS = torch.FloatTensor([0.096, 0.596, 0.289, 0.019])


Expand All @@ -43,19 +43,31 @@ def _gmsd(
c: float = 0.00261, # 170. / (255. ** 2)
alpha: float = 0.,
) -> torch.Tensor:
r"""Returns the GMSD between `x` and `y`,
without downsampling and color space conversion.
r"""Returns the GMSD between \(x\) and \(y\),
without color space conversion and downsampling.
`_gmsd` is an auxiliary function for `gmsd` and `GMSD`.
\(\text{GMSD}(x, y)\) is the standard deviation of the Gradient
Magnitude Similarity (GMS).
$$ \text{GMS}(x, y) = \frac{(2 - \alpha) \text{GM}(x) \text{GM}(y)
+ C}{\text{GM}(x)^2 + \text{GM}(y)^2 - \alpha \text{GM}(x)
\text{GM}(y) + C} $$
$$ \text{GM}(z) = \left\| \nabla z \right\|_2 $$
where \(\nabla z\) is the result of a gradient convolution over \(z\).
Args:
x: An input tensor, (N, 1, H, W).
y: A target tensor, (N, 1, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
value_range: The value range of the inputs (usually 1. or 255).
x: An input tensor, \((N, 1, H, W)\).
y: A target tensor, \((N, 1, H, W)\).
kernel: A gradient kernel, \((2, 1, K, K)\).
value_range: The value range \(L\) of the inputs (usually 1. or 255).
For the remaining arguments, refer to [1].
Returns:
The GMSD vector, \((N,)\).
Example:
>>> x = torch.rand(5, 1, 256, 256)
>>> y = torch.rand(5, 1, 256, 256)
Expand Down Expand Up @@ -90,49 +102,8 @@ def _gmsd(
return gmsd


def gmsd(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""Returns the GMSD between `x` and `y`.
Args:
x: An input tensor, (N, 3, H, W).
y: A target tensor, (N, 3, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
If `None`, use the Prewitt kernel instead.
`**kwargs` are transmitted to `_gmsd`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = gmsd(x, y)
>>> l.size()
torch.Size([5])
"""

# Downsample
x = F.avg_pool2d(x, kernel_size=2, ceil_mode=True)
y = F.avg_pool2d(y, kernel_size=2, ceil_mode=True)

# RGB to luminance
l_weights = _L_WEIGHTS.to(x.device).view(1, 3, 1, 1)

x = F.conv2d(x, l_weights)
y = F.conv2d(y, l_weights)

# Kernel
if kernel is None:
kernel = gradient_kernel(prewitt_kernel(), device=x.device)

return _gmsd(x, y, kernel, **kwargs)


@_jit
def _msgmsd(
def _ms_gmsd(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor,
Expand All @@ -141,26 +112,33 @@ def _msgmsd(
c: float = 0.00261,
alpha: float = 0.5,
) -> torch.Tensor:
r"""Returns the MS-GMSD between `x` and `y`,
r"""Returns the MS-GMSD between \(x\) and \(y\),
without color space conversion.
`_msgmsd` is an auxiliary function for `msgmsd` and `MSGMSD`.
$$ \text{MS-GMSD}(x, y) = \sum^{M}_{i = 1}
w_i \text{GMSD}(x^i, y^i) $$
where \(x^i\) and \(y^i\) are obtained by downsampling
the original tensors by a factor \(2^{i - 1}\).
Args:
x: An input tensor, (N, 1, H, W).
y: A target tensor, (N, 1, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
weights: The weights of the scales, (M,).
value_range: The value range of the inputs (usually 1. or 255).
x: An input tensor, \((N, 1, H, W)\).
y: A target tensor, \((N, 1, H, W)\).
kernel: A gradient kernel, \((2, 1, K, K)\).
weights: The weights \(w_i\) of the scales, \((M,)\).
value_range: The value range \(L\) of the inputs (usually 1. or 255).
For the remaining arguments, refer to [2].
Returns:
The MS-GMSD vector, \((N,)\).
Example:
>>> x = torch.rand(5, 1, 256, 256)
>>> y = torch.rand(5, 1, 256, 256)
>>> kernel = gradient_kernel(prewitt_kernel())
>>> weights = torch.rand(4)
>>> l = _msgmsd(x, y, kernel, weights)
>>> l = _ms_gmsd(x, y, kernel, weights)
>>> l.size()
torch.Size([5])
"""
Expand All @@ -184,66 +162,83 @@ def _msgmsd(
return msgmsd


def msgmsd(
def gmsd(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor = None,
weights: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""Returns the MS-GMSD between `x` and `y`.
r"""Returns the GMSD between \(x\) and \(y\).
Args:
x: An input tensor, (N, 3, H, W).
y: A target tensor, (N, 3, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
If `None`, use the Prewitt kernel instead.
weights: The weights of the scales, (M,).
If `None`, use the official weights instead.
x: An input tensor, \((N, 3, H, W)\).
y: A target tensor, \((N, 3, H, W)\).
`**kwargs` are transmitted to `_msgmsd`.
`**kwargs` are transmitted to `GMSD`.
Returns:
The GMSD vector, \((N,)\).
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = msgmsd(x, y)
>>> l = gmsd(x, y)
>>> l.size()
torch.Size([5])
"""

# RGB to luminance
l_weights = _L_WEIGHTS.to(x.device).view(1, 3, 1, 1)
kwargs['reduction'] = 'none'

x = F.conv2d(x, l_weights)
y = F.conv2d(y, l_weights)
return GMSD(**kwargs).to(x.device)(x, y)

# Kernel
if kernel is None:
kernel = gradient_kernel(prewitt_kernel(), device=x.device)

# Weights
if weights is None:
weights = _MS_WEIGHTS.to(x.device)
def ms_gmsd(
x: torch.Tensor,
y: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""Returns the MS-GMSD between \(x\) and \(y\).
return _msgmsd(x, y, kernel, weights, **kwargs)
Args:
x: An input tensor, \((N, 3, H, W)\).
y: A target tensor, \((N, 3, H, W)\).
`**kwargs` are transmitted to `MS_GMSD`.
Returns:
The MS-GMSD vector, \((N,)\).
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = ms_gmsd(x, y)
>>> l.size()
torch.Size([5])
"""

kwargs['reduction'] = 'none'

return MS_GMSD(**kwargs).to(x.device)(x, y)


class GMSD(nn.Module):
r"""Creates a criterion that measures the GMSD
between an input and a target.
Before applying `_gmsd`, the input and target are converted from
RBG to Y, the luminance color space, and downsampled by a factor 2.
Args:
kernel: A 2D gradient kernel, (2, 1, K, K).
kernel: A gradient kernel, \((2, 1, K, K)\).
If `None`, use the Prewitt kernel instead.
reduction: Specifies the reduction to apply to the output:
`'none'` | `'mean'` | `'sum'`.
`**kwargs` are transmitted to `_gmsd`.
Shape:
* Input: (N, 3, H, W)
* Target: (N, 3, H, W)
* Output: (N,) or (1,) depending on `reduction`
Shapes:
* Input: \((N, 3, H, W)\)
* Target: \((N, 3, H, W)\)
* Output: \((N,)\) or \(()\) depending on `reduction`
Example:
>>> criterion = GMSD().cuda()
Expand All @@ -268,7 +263,7 @@ def __init__(
kernel = gradient_kernel(prewitt_kernel())

self.register_buffer('kernel', kernel)
self.register_buffer('l_weights', _L_WEIGHTS.view(1, 3, 1, 1))
self.register_buffer('y_weights', _Y_WEIGHTS.view(1, 3, 1, 1))

self.reduce = build_reduce(reduction)
self.kwargs = kwargs
Expand All @@ -285,37 +280,40 @@ def forward(
input = F.avg_pool2d(input, 2, ceil_mode=True)
target = F.avg_pool2d(target, 2, ceil_mode=True)

# RGB to luminance
input = F.conv2d(input, self.l_weights)
target = F.conv2d(target, self.l_weights)
# RGB to Y
input = F.conv2d(input, self.y_weights)
target = F.conv2d(target, self.y_weights)

# GMSD
l = _gmsd(input, target, self.kernel, **self.kwargs)

return self.reduce(l)


class MSGMSD(nn.Module):
r"""Creates a criterion that measures the MSGMSD
class MS_GMSD(nn.Module):
r"""Creates a criterion that measures the MS-GMSD
between an input and a target.
Before applying `_ms_gmsd`, the input and target are converted from
RBG to Y, the luminance color space.
Args:
kernel: A 2D gradient kernel, (2, 1, K, K).
kernel: A gradient kernel, \((2, 1, K, K)\).
If `None`, use the Prewitt kernel instead.
weights: The weights of the scales, (M,).
weights: The weights of the scales, \((M,)\).
If `None`, use the official weights instead.
reduction: Specifies the reduction to apply to the output:
`'none'` | `'mean'` | `'sum'`.
`**kwargs` are transmitted to `_msgmsd`.
`**kwargs` are transmitted to `_ms_gmsd`.
Shape:
* Input: (N, 3, H, W)
* Target: (N, 3, H, W)
* Output: (N,) or (1,) depending on `reduction`
Shapes:
* Input: \((N, 3, H, W)\)
* Target: \((N, 3, H, W)\)
* Output: \((N,)\) or \(()\) depending on `reduction`
Example:
>>> criterion = MSGMSD().cuda()
>>> criterion = MS_GMSD().cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
Expand All @@ -342,7 +340,7 @@ def __init__(

self.register_buffer('kernel', kernel)
self.register_buffer('weights', weights)
self.register_buffer('l_weights', _L_WEIGHTS.view(1, 3, 1, 1))
self.register_buffer('y_weights', _Y_WEIGHTS.view(1, 3, 1, 1))

self.reduce = build_reduce(reduction)
self.kwargs = kwargs
Expand All @@ -355,11 +353,11 @@ def forward(
r"""Defines the computation performed at every call.
"""

# RGB to luminance
input = F.conv2d(input, self.l_weights)
target = F.conv2d(target, self.l_weights)
# RGB to Y
input = F.conv2d(input, self.y_weights)
target = F.conv2d(target, self.y_weights)

# MSGMSD
l = _msgmsd(input, target, self.kernel, self.weights, **self.kwargs)
# MS-GMSD
l = _ms_gmsd(input, target, self.kernel, self.weights, **self.kwargs)

return self.reduce(l)
Loading

0 comments on commit 785b449

Please sign in to comment.