Skip to content

Commit

Permalink
Feature: Mean Deviation Similarity Index (MDSI) (#148)
Browse files Browse the repository at this point in the history
* enhance(brisque/ssim/vsi): minor enhancements

* docstr(vsi): minor formatting

* feature(mdsi): new feature

* feature(mdsi): changes proposed by @zakajd

* docstr(base): description fix

* feat(mdsi): change dims of complex numbers
  • Loading branch information
denproc authored Jul 23, 2020
1 parent d3a0540 commit f2c42cb
Show file tree
Hide file tree
Showing 13 changed files with 461 additions and 115 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,35 @@ Now LPIPS is supported only for VGG16 model. If you need other models, check [or
</p>
</details>

<!-- MDSI EXAMPLES -->
<details>
<summary>Mean Deviation Similarity Index (MDSI)</summary>
<p>

To compute MDSI as a measure, use lower case function from the library:
```python
import torch
from piq import mdsi

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
mdsi_score: torch.Tensor = mdsi(prediction, target, data_range=1.)
```

In order to use MDSI as a loss function, use corresponding PyTorch module:
```python
import torch
from piq import MDSILoss

loss = MDSILoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
```
</p>
</details>

<!-- MSID EXAMPLES -->
<details>
<summary>Multi-Scale Intrinsic Distance (MSID)</summary>
Expand Down
1 change: 1 addition & 0 deletions piq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .psnr import psnr
from .fsim import fsim, FSIMLoss
from .vsi import vsi, VSILoss
from .mdsi import mdsi, MDSILoss
2 changes: 1 addition & 1 deletion piq/brisque.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def brisque(x: torch.Tensor,
.. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
"""
_validate_input(input_tensors=x, allow_5d=False)
_validate_input(input_tensors=x, allow_5d=False, kernel_size=kernel_size)
x = _adjust_dimensions(input_tensors=x)

assert data_range >= x.max(), f'Expected data range greater or equal maximum value, got {data_range} and {x.max()}.'
Expand Down
8 changes: 4 additions & 4 deletions piq/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from piq.functional.base import ifftshift, get_meshgrid, similarity_map, gradient_map
from piq.functional.colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq
from piq.functional.base import ifftshift, get_meshgrid, similarity_map, gradient_map, pow_for_complex
from piq.functional.colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq, rgb2lhm
from piq.functional.filters import hann_filter, scharr_filter, prewitt_filter, gaussian_filter
from piq.functional.layers import L2Pool2d


__all__ = [
'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map',
'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq',
'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'pow_for_complex',
'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', 'rgb2lhm',
'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter',
'L2Pool2d',
]
30 changes: 29 additions & 1 deletion piq/functional/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
r"""General purpose functions"""
from typing import Tuple
from typing import Tuple, Union
import torch


Expand Down Expand Up @@ -54,3 +54,31 @@ def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
grads = torch.nn.functional.conv2d(x, kernels.to(x), padding=padding)

return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True))


def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor:
r""" Takes the power of each element in a 4D tensor with negative values or 5D tensor with complex values.
Complex numbers are represented by modulus and argument: r * \exp(i * \phi).
It will likely to be redundant with introduction of torch.ComplexTensor.
Args:
base: Tensor with shape (B x C x H x W) or (B x C x H x W x 2)
exp: Exponent
Returns:
Complex tensor with shape (B x C x H x W x 2)
"""
if base.dim() == 4:
x_complex_r = base.abs()
x_complex_phi = torch.atan2(torch.zeros_like(base), base)
elif base.dim() == 5 and base.size(-1) == 2:
x_complex_r = base.pow(2).sum(dim=-1).sqrt()
x_complex_phi = torch.atan2(base[..., 1], base[..., 0])
else:
raise ValueError(f'Expected real or complex tensor, got {base.size()}')

x_complex_pow_r = x_complex_r ** exp
x_complex_pow_phi = x_complex_phi * exp
x_real_pow = x_complex_pow_r * torch.cos(x_complex_pow_phi)
x_imag_pow = x_complex_pow_r * torch.sin(x_complex_pow_phi)
return torch.stack((x_real_pow, x_imag_pow), dim=-1)
20 changes: 20 additions & 0 deletions piq/functional/colour_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,23 @@ def rgb2yiq(x: torch.Tensor) -> torch.Tensor:
[0.2115, -0.5227, 0.3112]]).t().to(x)
x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2)
return x_yiq


def rgb2lhm(x: torch.Tensor) -> torch.Tensor:
r"""Convert a batch of RGB images to a batch of LHM images
Args:
x: Batch of 4D (N x 3 x H x W) images in RGB colour space.
Returns:
Batch of 4D (N x 3 x H x W) images in LHM colour space.
Reference:
https://arxiv.org/pdf/1608.07433.pdf
"""
lhm_weights = torch.tensor([
[0.2989, 0.587, 0.114],
[0.3, 0.04, -0.35],
[0.34, -0.6, 0.17]]).t().to(x)
x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2)
return x_lhm
186 changes: 186 additions & 0 deletions piq/mdsi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
r"""Implemetation of Mean Deviation Similarity Index (MDSI)
Code supports the functionality proposed with the original MATLAB version for computations in pixel domain
https://www.mathworks.com/matlabcentral/fileexchange/59809
References:
https://arxiv.org/pdf/1608.07433.pdf
"""
import warnings
import functools
import torch
from torch.nn.modules.loss import _Loss
from torch.nn.functional import pad, avg_pool2d
from typing import Union
from piq.functional import rgb2lhm, gradient_map, similarity_map, prewitt_filter, pow_for_complex
from piq.utils import _validate_input, _adjust_dimensions


def mdsi(prediction: torch.Tensor, target: torch.Tensor, data_range: Union[int, float] = 1., reduction: str = 'mean',
c1: float = 140., c2: float = 55., c3: float = 550., combination: str = 'sum', alpha: float = 0.6,
beta: float = 0.1, gamma: float = 0.2, rho: float = 1., q: float = 0.25, o: float = 0.25):
r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.
Note:
Both inputs are supposed to have RGB order in accordance with the original approach.
Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
channel 3 times.
Args:
prediction: Batch of predicted (distorted) images. Required to be 2D (H,W), 3D (C,H,W), 4D (N,C,H,W),
channels first.
target: Batch of target (reference) images. Required to be 2D (H,W), 3D (C,H,W), 4D (N,C,H,W), channels first.
data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
reduction: Reduction over samples in batch: "mean"|"sum"|"none"
c1: coefficient to calculate gradient similarity. Default: 140.
c2: coefficient to calculate gradient similarity. Default: 55.
c3: coefficient to calculate chromaticity similarity. Default: 550.
combination: mode to combine gradient similarity and chromaticity similarity: "sum"|"mult".
alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
beta: power to combine gradient similarity with chromaticity similarity using multiplication.
gamma: to combine gradient similarity and chromaticity similarity using multiplication.
rho: order of the Minkowski distance
q: coefficient to adjusts the emphasis of the values in image and MCT
o: the power pooling applied on the final value of the deviation
Returns:
torch.Tensor: the batch of Mean Deviation Similarity Index (MDSI) score reduced accordingly
Note:
The ratio between constants is usually equal c3 = 4c1 = 10c2
"""
_validate_input(input_tensors=(prediction, target), allow_5d=False)
prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

if prediction.size(1) == 1:
prediction = prediction.repeat(1, 3, 1, 1)
target = target.repeat(1, 3, 1, 1)
warnings.warn('The original MDSI supports only RGB images. The input images were converted to RGB by copying '
'the grey channel 3 times.')

prediction = prediction * 255. / data_range
target = target * 255. / data_range

# Averaging image if the size is large enough
kernel_size = max(1, round(min(prediction.size()[-2:]) / 256))
padding = kernel_size // 2

if padding:
up_pad = (kernel_size - 1) // 2
down_pad = padding
pad_to_use = [up_pad, down_pad, up_pad, down_pad]
prediction = pad(prediction, pad=pad_to_use)
target = pad(target, pad=pad_to_use)

prediction = avg_pool2d(prediction, kernel_size=kernel_size)
target = avg_pool2d(target, kernel_size=kernel_size)

prediction_lhm = rgb2lhm(prediction)
target_lhm = rgb2lhm(target)

kernels = torch.stack([prewitt_filter(), prewitt_filter().transpose(1, 2)]).to(prediction)
gm_prediction = gradient_map(prediction_lhm[:, :1], kernels)
gm_target = gradient_map(target_lhm[:, :1], kernels)
gm_avg = gradient_map((prediction_lhm[:, :1] + target_lhm[:, :1]) / 2., kernels)

gs_prediction_target = similarity_map(gm_prediction, gm_target, c1)
gs_prediction_average = similarity_map(gm_prediction, gm_avg, c2)
gs_target_average = similarity_map(gm_target, gm_avg, c2)

gs_total = gs_prediction_target + gs_prediction_average - gs_target_average

cs_total = (2 * (prediction_lhm[:, 1:2] * target_lhm[:, 1:2] +
prediction_lhm[:, 2:] * target_lhm[:, 2:]) + c3) / (prediction_lhm[:, 1:2] ** 2 +
target_lhm[:, 1:2] ** 2 +
prediction_lhm[:, 2:] ** 2 +
target_lhm[:, 2:] ** 2 + c3)

if combination == 'sum':
gcs = (alpha * gs_total + (1 - alpha) * cs_total)
elif combination == 'mult':
gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
gs_total_pow[..., 1] + cs_total_pow[..., 1]), dim=-1)
else:
raise ValueError(f'Expected combination method "sum" or "mult", got {combination}')

mct_complex = pow_for_complex(base=gcs, exp=q)
mct_complex = mct_complex.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) # split to increase precision
score = (pow_for_complex(base=gcs, exp=q) - mct_complex).pow(2).sum(dim=-1).sqrt()
score = ((score ** rho).mean(dim=(-1, -2)) ** (o / rho)).squeeze(1)
if reduction == 'none':
return score
return {'mean': score.mean,
'sum': score.sum}[reduction](dim=0)


class MDSILoss(_Loss):
r"""Creates a criterion that measures Mean Deviation Similarity Index (MDSI) error between the prediction and
target.
Args:
data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
reduction: Reduction over samples in batch: "mean"|"sum"|"none"
c1: coefficient to calculate gradient similarity. Default: 140.
c2: coefficient to calculate gradient similarity. Default: 55.
c3: coefficient to calculate chromaticity similarity. Default: 550.
combination: mode to combine gradient similarity and chromaticity similarity: "sum"|"mult".
alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
beta: power to combine gradient similarity with chromaticity similarity using multiplication.
gamma: to combine gradient similarity and chromaticity similarity using multiplication.
rho: order of the Minkowski distance
q: coefficient to adjusts the emphasis of the values in image and MCT
o: the power pooling applied on the final value of the deviation
Shape:
- Input: Required to be 2D (H,W), 3D (C,H,W), 4D (N,C,H,W), channels first.
- Target: Required to be 2D (H,W), 3D (C,H,W), 4D (N,C,H,W), channels first.
Both inputs are supposed to have RGB order in accordance with the original approach.
Nevertheless, the method supports greyscale images, which they are converted to RGB
by copying the grey channel 3 times.
Examples::
>>> loss = MDSILoss(data_range=1.)
>>> prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> target = torch.rand(3, 3, 256, 256)
>>> output = loss(prediction, target)
>>> output.backward()
References:
.. [1] Nafchi, Hossein Ziaei and Shahkolaei, Atena and Hedjam, Rachid and Cheriet, Mohamed
(2016). Mean deviation similarity index: Efficient and reliable full-reference image quality evaluator.
IEEE Ieee Access,
4, 5579--5590.
https://ieeexplore.ieee.org/abstract/document/7556976/,
:DOI:`10.1109/ACCESS.2016.2604042`
"""
def __init__(self, data_range: Union[int, float] = 1., reduction: str = 'mean',
c1: float = 140., c2: float = 55., c3: float = 550., alpha: float = 0.6,
rho: float = 1., q: float = 0.25, o: float = 0.25, combination: str = 'sum',
beta: float = 0.1, gamma: float = 0.2):
super().__init__()
self.reduction = reduction
self.data_range = data_range
self.mdsi = functools.partial(mdsi, c1=c1, c2=c2, c3=c3, alpha=alpha, rho=rho, q=q, o=o,
combination=combination, beta=beta, gamma=gamma, data_range=self.data_range,
reduction=self.reduction)

def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""Computation of Mean Deviation Similarity Index (MDSI) as a loss function.
Args:
prediction: Tensor of prediction of the network. Required to be 2D (H,W), 3D (C,H,W), 4D (N,C,H,W),
channels first.
target: Reference tensor. Required to be 2D (H,W), 3D (C,H,W), 4D (N,C,H,W), channels first.
Returns:
Value of MDSI loss to be minimized. 0 <= MDSI loss <= 1.
Note:
Both inputs are supposed to have RGB order in accordance with the original approach.
Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
channel 3 times.
"""
return 1. - torch.clamp(self.mdsi(prediction=prediction, target=target), min=0., max=1.)
16 changes: 8 additions & 8 deletions piq/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,9 @@ def _multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, f
ssim_val = None
for iteration in range(levels):
if iteration > 0:
padding = (x.shape[2] % 2, x.shape[3] % 2)
x = F.pad(x, pad=[padding[0], 0, padding[1], 0], mode='replicate')
y = F.pad(y, pad=[padding[0], 0, padding[1], 0], mode='replicate')
padding = max(x.shape[2] % 2, x.shape[3] % 2)
x = F.pad(x, pad=[padding, 0, padding, 0], mode='replicate')
y = F.pad(y, pad=[padding, 0, padding, 0], mode='replicate')
x = F.avg_pool2d(x, kernel_size=2, padding=0)
y = F.avg_pool2d(y, kernel_size=2, padding=0)

Expand Down Expand Up @@ -536,11 +536,11 @@ def _multi_scale_ssim_complex(x: torch.Tensor, y: torch.Tensor, data_range: Unio
y_real = y[..., 0]
y_imag = y[..., 1]
if iteration > 0:
padding = (x.size(2) % 2, x.size(3) % 2)
x_real = F.pad(x_real, pad=[padding[0], 0, padding[1], 0], mode='replicate')
x_imag = F.pad(x_imag, pad=[padding[0], 0, padding[1], 0], mode='replicate')
y_real = F.pad(y_real, pad=[padding[0], 0, padding[1], 0], mode='replicate')
y_imag = F.pad(y_imag, pad=[padding[0], 0, padding[1], 0], mode='replicate')
padding = max(x.size(2) % 2, x.size(3) % 2)
x_real = F.pad(x_real, pad=[padding, 0, padding, 0], mode='replicate')
x_imag = F.pad(x_imag, pad=[padding, 0, padding, 0], mode='replicate')
y_real = F.pad(y_real, pad=[padding, 0, padding, 0], mode='replicate')
y_imag = F.pad(y_imag, pad=[padding, 0, padding, 0], mode='replicate')

x_real = F.avg_pool2d(x_real, kernel_size=2, padding=0)
x_imag = F.avg_pool2d(x_imag, kernel_size=2, padding=0)
Expand Down
Loading

0 comments on commit f2c42cb

Please sign in to comment.