diff --git a/README.md b/README.md
index 3ff33eb5..f357c3ed 100644
--- a/README.md
+++ b/README.md
@@ -197,6 +197,18 @@ By default input images are normalized with ImageNet statistics before forwardin
 This is port of MATLAB version from the authors of original paper. It can be used both as a measure and as a loss function. In any case it should me minimized. Usually values of GMSD lie in [0, 0.35] interval.
+
+To compute GMSD as a measure, use the lower-case function from the library:
+```python
+import torch
+from piq import gmsd
+
+prediction = torch.rand(3, 3, 256, 256)
+target = torch.rand(3, 3, 256, 256)
+gmsd_index: torch.Tensor = gmsd(prediction, target, data_range=1.)
+```
+
+In order to use GMSD as a loss function, use the corresponding PyTorch module:
 ```python
 import torch
 from piq import GMSDLoss
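An illustrative aside for reviewers (not part of the patch): a minimal sketch of how the `GMSDLoss` module documented in the hunk above might sit inside a training step. The convolutional `model` and the optimizer here are hypothetical placeholders used only to exercise the backward pass.

```python
import torch
from piq import GMSDLoss

# Hypothetical one-layer "model" and optimizer, only to illustrate backprop through the loss.
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = GMSDLoss(data_range=1.)

target = torch.rand(4, 3, 96, 96)
noisy_input = (target + 0.1 * torch.randn_like(target)).clamp(0, 1)

prediction = torch.sigmoid(model(noisy_input))  # keep outputs in [0, 1]
loss = criterion(prediction, target)            # scalar, lower is better, 0 for identical images
loss.backward()
optimizer.step()
```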
@@ -384,11 +396,25 @@ Now LPIPS is supported only for VGG16 model. If you need other models, check [or
-MultiScale GMSD (MS-GMSD)
+Multi-Scale GMSD (MS-GMSD)

 It can be used both as a measure and as a loss function. In any case it should me minimized.
-By defualt scale weights are initialized with values from the paper. You can change them by passing a list of 4 variables to `scale_weights` argument during initialization. Both GMSD and MS-GMSD computed for greyscale images, but to take contrast changes into account authors propoced to also add chromatic component. Use flag `chromatic` to use MS-GMSDc version of the loss
+By default, scale weights are initialized with values from the paper. You can change them by passing a list of 4 variables to the `scale_weights` argument during initialization. Both GMSD and MS-GMSD are computed for greyscale images, but to take contrast changes into account the authors proposed to also add a chromatic component. Use the flag `chromatic` to use the MS-GMSDc version of the loss.
+
+Note that input tensors should contain images with height and width of at least `2 ** number_of_scales + 1`.
+
+To compute Multi-Scale GMSD as a measure, use the lower-case function from the library:
+```python
+import torch
+from piq import multi_scale_gmsd
+
+prediction = torch.rand(3, 3, 256, 256)
+target = torch.rand(3, 3, 256, 256)
+multi_scale_gmsd_index: torch.Tensor = multi_scale_gmsd(prediction, target, data_range=1.)
+```
+
+In order to use Multi-Scale GMSD as a loss function, use the corresponding PyTorch module:
 ```python
 import torch
 from piq import MultiScaleGMSDLoss
diff --git a/piq/__init__.py b/piq/__init__.py
index 6a1b9df3..24a3eea6 100644
--- a/piq/__init__.py
+++ b/piq/__init__.py
@@ -5,7 +5,7 @@ from .fid import FID
 from .kid import KID
 from .tv import TVLoss, total_variation
-from .gmsd import GMSDLoss, MultiScaleGMSDLoss
+from .gmsd import gmsd, multi_scale_gmsd, GMSDLoss, MultiScaleGMSDLoss
 from .gs import GS
 from .isc import IS, inception_score
 from .vif import VIFLoss, vif_p
diff --git a/piq/gmsd.py b/piq/gmsd.py
index fdcea0cd..97813d6f 100644
--- a/piq/gmsd.py
+++ b/piq/gmsd.py
@@ -15,11 +15,11 @@ from torch.nn.modules.loss import _Loss
 
 from piq.utils import _adjust_dimensions, _validate_input
-from piq.functional import similarity_map, gradient_map, prewitt_filter
+from piq.functional import similarity_map, gradient_map, prewitt_filter, rgb2yiq
 
 
-def _gmsd(prediction: torch.Tensor, target: torch.Tensor,
-          reduction: Optional[str] = 'mean') -> torch.Tensor:
+def gmsd(prediction: torch.Tensor, target: torch.Tensor, reduction: Optional[str] = 'mean',
+         data_range: Union[int, float] = 1., t: float = 170 / (255. ** 2)) -> torch.Tensor:
     r"""Compute Gradient Magnitude Similarity Deviation
     Both inputs supposed to be in range [0, 1] with RGB order.
     Args:
@@ -29,6 +29,10 @@ def _gmsd(prediction: torch.Tensor, target: torch.Tensor,
             ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
             ``'mean'``: the sum of the output will be divided by the number of
             elements in the output, ``'sum'``: the output will be summed.
+        data_range: The difference between the maximum and minimum of the pixel value,
+            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
+            The pixel value interval of both input and output should remain the same.
+        t: Constant from the reference paper for numerical stability of the similarity map.
     Returns:
         gmsd : Gradient Magnitude Similarity Deviation between given tensors.
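An illustrative aside on the `data_range` argument documented in the hunk above (not part of the patch): the same image pair should score the same whether it is stored as floats in [0, 1] or as 8-bit integers, as long as `data_range` matches the encoding. A minimal sketch with random stand-in tensors, mirroring the behaviour the new tests check:

```python
import torch
from piq import gmsd

# Random stand-ins for a pair of images, stored two ways.
prediction = torch.rand(2, 3, 96, 96)
target = torch.rand(2, 3, 96, 96)
prediction_255 = (prediction * 255).type(torch.uint8)
target_255 = (target * 255).type(torch.uint8)

# data_range tells gmsd how to rescale the inputs to [0, 1] internally,
# so the two calls should agree up to uint8 quantization error.
score_float = gmsd(prediction, target, data_range=1.)
score_uint8 = gmsd(prediction_255, target_255, data_range=255)
print(score_float, score_uint8)
```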
@@ -36,37 +40,61 @@ def _gmsd(prediction: torch.Tensor, target: torch.Tensor,
     References:
         https://arxiv.org/pdf/1308.3052.pdf
     """
-    # Constant for numerical stability
-    EPS: float = 0.0026
-    # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
+    _validate_input(input_tensors=(prediction, target), allow_5d=False, scale_weights=None)
+    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))
+
+    prediction = prediction / float(data_range)
+    target = target / float(data_range)
+
     num_channels = prediction.size(1)
     if num_channels == 3:
-        prediction = 0.299 * prediction[:, 0, :, :] + 0.587 * prediction[:, 1, :, :] + 0.114 * prediction[:, 2, :, :]
-        target = 0.299 * target[:, 0, :, :] + 0.587 * target[:, 1, :, :] + 0.114 * target[:, 2, :, :]
+        prediction = rgb2yiq(prediction)[:, :1]
+        target = rgb2yiq(target)[:, :1]
+    up_pad = 0
+    down_pad = max(prediction.shape[2] % 2, prediction.shape[3] % 2)
+    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
+    prediction = F.pad(prediction, pad=pad_to_use)
+    target = F.pad(target, pad=pad_to_use)
+
+    prediction = F.avg_pool2d(prediction, kernel_size=2, stride=2, padding=0)
+    target = F.avg_pool2d(target, kernel_size=2, stride=2, padding=0)
+
+    score = _gmsd(prediction=prediction, target=target, t=t)
+    if reduction == 'none':
+        return score
+
+    return {'mean': score.mean,
+            'sum': score.sum
+            }[reduction](dim=0)
+
+
+def _gmsd(prediction: torch.Tensor, target: torch.Tensor, t: float = 170 / (255. ** 2)) -> torch.Tensor:
+    r"""Compute Gradient Magnitude Similarity Deviation
+    Both inputs are supposed to be in range [0, 1].
+    Args:
+        prediction: Tensor of shape :math:`(N, 1, H, W)` holding a distorted grayscale image.
+        target: Tensor of shape :math:`(N, 1, H, W)` holding a target grayscale image.
+        t: Constant from the reference paper for numerical stability of the similarity map.
+
+    Returns:
+        gmsd : Gradient Magnitude Similarity Deviation between given tensors.
+
+    References:
+        https://arxiv.org/pdf/1308.3052.pdf
+    """
-    # Add channel dimension
-    prediction = prediction[:, None, :, :]
-    target = target[:, None, :, :]
-
     # Compute grad direction
     kernels = torch.stack([prewitt_filter(), prewitt_filter().transpose(-1, -2)])
     pred_grad = gradient_map(prediction, kernels)
     trgt_grad = gradient_map(target, kernels)
 
     # Compute GMS
-    gms = similarity_map(pred_grad, trgt_grad, EPS)
+    gms = similarity_map(pred_grad, trgt_grad, t)
     mean_gms = torch.mean(gms, dim=[1, 2, 3], keepdims=True)
-    # Compute GMSD along spatial dimensions. Shape (batch_size )
     score = torch.pow(gms - mean_gms, 2).mean(dim=[1, 2, 3]).sqrt()
-
-    if reduction == 'none':
-        return score
-
-    return {'mean': score.mean,
-            'sum': score.sum
-            }[reduction](dim=0)
+    return score
 
 
 class GMSDLoss(_Loss):
@@ -81,6 +109,7 @@ class GMSDLoss(_Loss):
         data_range: The difference between the maximum and minimum of the pixel value,
             i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
+        t: Constant from the reference paper for numerical stability of the similarity map.
 
     Reference:
         Wufeng Xue et al. Gradient Magnitude Similarity Deviation (2013)
@@ -88,7 +117,8 @@
     """
 
-    def __init__(self, reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None:
+    def __init__(self, reduction: str = 'mean', data_range: Union[int, float] = 1.,
+                 t: float = 170 / (255. ** 2)) -> None:
         super().__init__()
 
         # Generic loss parameters.
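An illustrative aside (not part of the patch) on the downsampling step introduced above: before the gradient maps are compared, odd-sized inputs are padded on the bottom/right edge and then reduced with a 2x2 average pool. A standalone sketch of that pre-processing, using plain torch calls:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 97, 95)  # odd height and width

# Pad the bottom/right edge so both spatial dims become even,
# mirroring the pad_to_use logic introduced in this diff.
down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
x_padded = F.pad(x, pad=[0, down_pad, 0, down_pad])

# 2x2 average pooling halves the resolution before gradient maps are computed.
x_small = F.avg_pool2d(x_padded, kernel_size=2, stride=2, padding=0)
print(x.shape, x_padded.shape, x_small.shape)
# torch.Size([1, 1, 97, 95]) torch.Size([1, 1, 98, 96]) torch.Size([1, 1, 49, 48])
```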
@@ -96,6 +126,7 @@ def __init__(self, reduction: str = 'mean', data_range: Union[int, float] = 1.)
 
         # Loss-specific parameters.
         self.data_range = data_range
+        self.t = t
 
     def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         r"""Computation of Gradient Magnitude Similarity Deviation (GMSD) as a loss function.
@@ -107,27 +138,110 @@ def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tenso
         Returns:
             Value of GMSD loss to be minimized. 0 <= GMSD loss <= 1.
         """
-        _validate_input(input_tensors=(prediction, target), allow_5d=False, scale_weights=None)
-        prediction, target = _adjust_dimensions(input_tensors=(prediction, target))
-        return self.compute_metric(prediction, target)
-
-    def compute_metric(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-
-        if self.data_range == 255:
-            prediction = prediction / 255.
-            target = target / 255.
-
-        # Average by 2x2 filter and downsample
-        padding = (prediction.shape[2] % 2, prediction.shape[3] % 2)
-        prediction = F.avg_pool2d(prediction, kernel_size=2, stride=2, padding=padding)
-        target = F.avg_pool2d(target, kernel_size=2, stride=2, padding=padding)
-
-        score = _gmsd(
-            prediction, target, reduction=self.reduction)
+        return gmsd(prediction=prediction, target=target, reduction=self.reduction, data_range=self.data_range,
+                    t=self.t)
+
+
+def multi_scale_gmsd(prediction: torch.Tensor, target: torch.Tensor, data_range: Union[int, float] = 1.,
+                     reduction: str = 'mean',
+                     scale_weights: Optional[Union[torch.Tensor, Tuple[float, ...], List[float]]] = None,
+                     chromatic: bool = False, beta1: float = 0.01, beta2: float = 0.32, beta3: float = 15.,
+                     t: float = 170 / (255. ** 2)) -> torch.Tensor:
+    r"""Computation of Multi-Scale GMSD.
+
+    Args:
+        prediction: Tensor of prediction of the network. The height and width should be at least 2 ** scales + 1.
+        target: Reference tensor. The height and width should be at least 2 ** scales + 1.
+        data_range: The difference between the maximum and minimum of the pixel value,
+            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
+            The pixel value interval of both input and output should remain the same.
+        reduction: Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the sum of the output will be divided by the number of
+            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
+        scale_weights: Weights for different scales. Can contain any number of floating point values.
+        chromatic: Flag to use the MS-GMSDc algorithm from the paper.
+            It also evaluates chromatic components of the image. Default: False
+        beta1: Algorithm parameter. Weight of chromatic component in the loss.
+        beta2: Algorithm parameter. Small constant, see [1].
+        beta3: Algorithm parameter. Small constant, see [1].
+        t: Constant from the reference paper for numerical stability of the similarity map.
+
+    Returns:
+        Value of MS-GMSD. 0 <= MS-GMSD <= 1.
+ """ + _validate_input(input_tensors=(prediction, target), allow_5d=False, scale_weights=scale_weights) + prediction, target = _adjust_dimensions(input_tensors=(prediction, target)) + + # Values from the paper + if scale_weights is None: + scale_weights = torch.tensor([0.096, 0.596, 0.289, 0.019]) + elif isinstance(scale_weights, torch.Tensor): + scale_weights = scale_weights / scale_weights.sum() + else: + # Normalize scale weights + scale_weights = torch.tensor(scale_weights) / torch.tensor(scale_weights).sum() + + # Check that input is big enough + num_scales = scale_weights.size(0) + min_size = 2 ** num_scales + 1 + + if prediction.size(-1) < min_size or prediction.size(-2) < min_size: + raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.') + + prediction = prediction / float(data_range) + target = target / float(data_range) + + num_channels = prediction.size(1) + if num_channels == 3: + prediction = rgb2yiq(prediction) + target = rgb2yiq(target) + + scale_weights = scale_weights.to(prediction) + ms_gmds = [] + for scale in range(num_scales): + if scale > 0: + # Average by 2x2 filter and downsample + up_pad = 0 + down_pad = max(prediction.shape[2] % 2, prediction.shape[3] % 2) + pad_to_use = [up_pad, down_pad, up_pad, down_pad] + prediction = F.pad(prediction, pad=pad_to_use) + target = F.pad(target, pad=pad_to_use) + prediction = F.avg_pool2d(prediction, kernel_size=2, padding=0) + target = F.avg_pool2d(target, kernel_size=2, padding=0) + + score = _gmsd(prediction[:, :1], target[:, :1], t=t) + ms_gmds.append(score) + + # Stack results in different scales and multiply by weight + ms_gmds_val = scale_weights.view(1, num_scales) * (torch.stack(ms_gmds, dim=1) ** 2) + + # Sum and take sqrt per-image + ms_gmds_val = torch.sqrt(torch.sum(ms_gmds_val, dim=1)) + + # Shape: (batch_size, ) + score = ms_gmds_val + + if chromatic: + assert prediction.size(1) == 3, "Chromatic component can be computed only for RGB images!" + + prediction_iq = prediction[:, 1:] + target_iq = target[:, 1:] + + rmse_iq = torch.sqrt(torch.mean((prediction_iq - target_iq) ** 2, dim=[2, 3])) + rmse_chrome = torch.sqrt(torch.sum(rmse_iq ** 2, dim=1)) + gamma = 2 / (1 + beta2 * torch.exp(-beta3 * ms_gmds_val)) - 1 + score = gamma * ms_gmds_val + (1 - gamma) * beta1 * rmse_chrome + + if reduction == 'none': return score - + + return {'mean': score.mean, + 'sum': score.sum + }[reduction](dim=0) + class MultiScaleGMSDLoss(_Loss): r"""Creates a criterion that measures multi scale Gradient Magnitude Similarity Deviation @@ -148,6 +262,7 @@ class MultiScaleGMSDLoss(_Loss): beta1: Algorithm parameter. Weight of chromatic component in the loss. beta2: Algorithm parameter. Small constant, see [1]. beta3: Algorithm parameter. Small constant, see [1]. + t: Constant from the reference paper numerical stability of similarity map Reference: [1] GRADIENT MAGNITUDE SIMILARITY DEVIATION ON MULTIPLE SCALES (2017) @@ -155,9 +270,9 @@ class MultiScaleGMSDLoss(_Loss): """ def __init__(self, reduction: str = 'mean', data_range: Union[int, float] = 1., - scale_weights: Optional[Union[Tuple[float], List[float]]] = None, + scale_weights: Optional[Union[torch.Tensor, Tuple[float, ...], List[float]]] = None, chromatic: bool = False, beta1: float = 0.01, beta2: float = 0.32, - beta3: float = 15.) -> None: + beta3: float = 15., t: float = 170 / (255. ** 2)) -> None: super().__init__() # Generic loss parameters. 
@@ -165,84 +280,26 @@ def __init__(self, reduction: str = 'mean', data_range: Union[int, float] = 1.,
 
         # Loss-specific parameters.
         self.data_range = data_range
-
-        # Values from the paper
-        if scale_weights is None:
-            self.scale_weights = torch.tensor([0.096, 0.596, 0.289, 0.019])
-        else:
-            # Normalize scale weights
-            self.scale_weights = torch.tensor(scale_weights) / torch.tensor(scale_weights).sum()
+        self.scale_weights = scale_weights
         self.chromatic = chromatic
         self.beta1 = beta1
         self.beta2 = beta2
         self.beta3 = beta3
+        self.t = t
 
     def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        r"""Computation of Multi scale GMSD as a loss function.
+        r"""Computation of Multi-Scale GMSD as a loss function.
         Args:
-            prediction: Tensor of prediction of the network.
-            target: Reference tensor.
+            prediction: Tensor of prediction of the network. Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W),
+                channels first. The height and width should be at least 2 ** scales + 1.
+            target: Reference tensor. Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first.
+                The height and width should be at least 2 ** scales + 1.
         Returns:
-            Value of GMSD loss to be minimized. 0 <= GMSD loss <= 1.
+            Value of MS-GMSD loss to be minimized. 0 <= MS-GMSD loss <= 1.
         """
-        _validate_input(input_tensors=(prediction, target), allow_5d=False, scale_weights=self.scale_weights)
-        prediction, target = _adjust_dimensions(input_tensors=(prediction, target))
-
-        return self.compute_metric(prediction, target)
-
-    def compute_metric(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        # Check that input is big enough
-        num_scales = self.scale_weights.size(0)
-        min_size = 2 ** num_scales + 1
-
-        if prediction.size(-1) < min_size or prediction.size(-2) < min_size:
-            raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')
-
-        if self.data_range == 255:
-            prediction = prediction / 255.
-            target = target / 255.
-
-        scale_weights = self.scale_weights.to(prediction)
-        ms_gmds = []
-        for scale in range(num_scales):
-            if scale > 0:
-                # Average by 2x2 filter and downsample
-                padding = (prediction.shape[2] % 2, prediction.shape[3] % 2)
-                prediction = F.avg_pool2d(prediction, kernel_size=2, padding=padding)
-                target = F.avg_pool2d(target, kernel_size=2, padding=padding)
-
-            score = _gmsd(prediction, target, reduction='none')
-            ms_gmds.append(score)
-
-        # Stack results in different scales and multiply by weight
-        ms_gmds_val = scale_weights.view(1, num_scales) * (torch.stack(ms_gmds, dim=1) ** 2)
-
-        # Sum and take sqrt per-image
-        ms_gmds_val = torch.sqrt(torch.sum(ms_gmds_val, dim=1))
-
-        # Shape: (batch_size, )
-        score = ms_gmds_val
-
-        if self.chromatic:
-            assert prediction.size(1) == 3, "Chromatic component can be computed only for RGB images!"
-
-            # Convert to YIQ color space https://en.wikipedia.org/wiki/YIQ
-            iq_weights = torch.tensor([[0.5959, -0.2746, -0.3213], [0.2115, -0.5227, 0.3112]]).t().to(prediction)
-            prediction_iq = torch.matmul(prediction.permute(0, 2, 3, 1), iq_weights).permute(0, 3, 1, 2)
-            target_iq = torch.matmul(target.permute(0, 2, 3, 1), iq_weights).permute(0, 3, 1, 2)
-
-            rmse_iq = torch.sqrt(torch.mean((prediction_iq - target_iq) ** 2, dim=[2, 3]))
-            rmse_chrome = torch.sqrt(torch.sum(rmse_iq ** 2, dim=1))
-            gamma = 2 / (1 + self.beta2 * torch.exp(-self.beta3 * ms_gmds_val)) - 1
-
-            score = gamma * ms_gmds_val + (1 - gamma) * self.beta1 * rmse_chrome
-
-        if self.reduction == 'none':
-            return score
-
-        return {'mean': score.mean,
-                'sum': score.sum
-                }[self.reduction](dim=0)
+        return multi_scale_gmsd(prediction=prediction, target=target, data_range=self.data_range,
+                                reduction=self.reduction, chromatic=self.chromatic, beta1=self.beta1,
+                                beta2=self.beta2, beta3=self.beta3, scale_weights=self.scale_weights, t=self.t)
diff --git a/tests/test_gmsd.py b/tests/test_gmsd.py
index 756ae0b3..ba7652cd 100644
--- a/tests/test_gmsd.py
+++ b/tests/test_gmsd.py
@@ -1,175 +1,264 @@
 import torch
 import pytest
+from skimage.io import imread
+import numpy as np
+from typing import Any, Tuple
 
-from piq import GMSDLoss, MultiScaleGMSDLoss
+from piq import gmsd, multi_scale_gmsd, GMSDLoss, MultiScaleGMSDLoss
 
 LEAF_VARIABLE_ERROR_MESSAGE = 'Expected non None gradient of leaf variable'
 
 
 @pytest.fixture(scope='module')
 def prediction() -> torch.Tensor:
-    return torch.rand(2, 3, 128, 128)
+    return torch.rand(2, 3, 96, 96)
 
 
 @pytest.fixture(scope='module')
 def target() -> torch.Tensor:
-    return torch.rand(2, 3, 128, 128)
+    return torch.rand(2, 3, 96, 96)
 
 
-# ================== Test class: `GMSDLoss` ==================
-def test_gmsd_loss(prediction: torch.Tensor, target: torch.Tensor) -> None:
-    loss = GMSDLoss()
-    loss(prediction, target)
+prediction_image = [
+    torch.tensor(imread('tests/assets/goldhill_jpeg.gif'), dtype=torch.float32).unsqueeze(0).unsqueeze(0),
+    torch.tensor(imread('tests/assets/i01_01_5.bmp'), dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
+]
+target_image = [
+    torch.tensor(imread('tests/assets/goldhill.gif'), dtype=torch.float32).unsqueeze(0).unsqueeze(0),
+    torch.tensor(imread('tests/assets/I01.BMP'), dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
+]
+target_score = [
+    torch.tensor(0.138012587141798),
+    torch.tensor(0.094124655829098)
+]
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
-def test_gmsd_loss_on_gpu(prediction: torch.Tensor, target: torch.Tensor) -> None:
-    loss = GMSDLoss()
-    loss(prediction.cuda(), target.cuda())
+@pytest.fixture(params=zip(prediction_image, target_image, target_score))
+def input_images_score(request: Any) -> Any:
+    return request.param
 
 
-def test_gmsd_loss_backward(prediction: torch.Tensor, target: torch.Tensor) -> None:
-    prediction.requires_grad_()
-    loss_value = GMSDLoss()(prediction, target)
-    loss_value.backward()
-    assert prediction.grad is not None, LEAF_VARIABLE_ERROR_MESSAGE
+# ================== Test function: `gmsd` ==================
+def test_gmsd_forward(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
+    gmsd(prediction.to(device), target.to(device))
+
+
+def test_gmsd_zero_for_equal_tensors(prediction: torch.Tensor, device: str) -> None:
+    target = prediction.clone()
+    measure = gmsd(prediction.to(device), target.to(device))
+    assert measure.abs() <= 1e-6, f'GMSD for equal tensors must be 0, got {measure}'
+
+
+def test_gmsd_raises_if_tensors_have_different_types(target: torch.Tensor, device: str) -> None:
+    wrong_type_predictions = [list(range(10)), np.arange(10)]
+    for wrong_type_prediction in wrong_type_predictions:
+        with pytest.raises(AssertionError):
+            gmsd(wrong_type_prediction, target.to(device))
+
+
+def test_gmsd_supports_different_data_ranges(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
+    prediction_255 = (prediction * 255).type(torch.uint8)
+    target_255 = (target * 255).type(torch.uint8)
+    measure = gmsd(prediction.to(device), target.to(device))
+
+    measure_255 = gmsd(prediction_255.to(device), target_255.to(device), data_range=255)
+    diff = torch.abs(measure_255 - measure)
+    assert diff <= 1e-4, f'Result for same tensor with different data_range should be the same, got {diff}'
+
+
+def test_gmsd_supports_greyscale_tensors(device: str) -> None:
+    target = torch.ones(2, 1, 96, 96)
+    prediction = torch.zeros(2, 1, 96, 96)
+    gmsd(prediction.to(device), target.to(device))
+
+
+def test_gmsd_modes(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
+    for reduction in ['mean', 'sum', 'none']:
+        gmsd(prediction.to(device), target.to(device), reduction=reduction)
+
+    for reduction in ['DEADBEEF', 'random']:
+        with pytest.raises(KeyError):
+            gmsd(prediction.to(device), target.to(device), reduction=reduction)
+
+
+def test_gmsd_compare_with_matlab(input_images_score: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                                  device: str) -> None:
+    prediction, target, target_value = input_images_score
+    score = gmsd(prediction=prediction.to(device), target=target.to(device), data_range=255)
+    assert torch.isclose(score, target_value.to(score)), f'The estimated value must be equal to MATLAB provided one, ' \
+                                                         f'got {score.item():.8f}, while MATLAB equals {target_value}'
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
-def test_gmsd_loss_backward_on_gpu(prediction: torch.Tensor, target: torch.Tensor) -> None:
+# ================== Test class: `GMSDLoss` ==================
+def test_gmsd_loss_forward_backward(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
     prediction.requires_grad_()
-    loss_value = GMSDLoss()(prediction.cuda(), target.cuda())
+    loss_value = GMSDLoss()(prediction.to(device), target.to(device))
     loss_value.backward()
-    assert prediction.grad is not None, LEAF_VARIABLE_ERROR_MESSAGE
+    assert torch.isfinite(prediction.grad).all(), LEAF_VARIABLE_ERROR_MESSAGE
 
 
-def test_gmsd_zero_for_equal_tensors(prediction: torch.Tensor):
+def test_gmsd_loss_zero_for_equal_tensors(prediction: torch.Tensor, device: str) -> None:
     loss = GMSDLoss()
     target = prediction.clone()
-    measure = loss(prediction, target)
+    measure = loss(prediction.to(device), target.to(device))
     assert measure.abs() <= 1e-6, f'GMSD for equal tensors must be 0, got {measure}'
 
 
-def test_gmsd_loss_raises_if_tensors_have_different_types(target: torch.Tensor) -> None:
-    wrong_type_prediction = list(range(10))
-    with pytest.raises(AssertionError):
-        GMSDLoss()(wrong_type_prediction, target)
+def test_gmsd_loss_raises_if_tensors_have_different_types(target: torch.Tensor, device: str) -> None:
+    wrong_type_predictions = [list(range(10)), np.arange(10)]
+    for wrong_type_prediction in wrong_type_predictions:
+        with pytest.raises(AssertionError):
+            GMSDLoss()(wrong_type_prediction, target.to(device))
 
 
-def test_gmsd_loss_supports_different_data_ranges(prediction: torch.Tensor, target: torch.Tensor) -> None:
+def test_gmsd_loss_supports_different_data_ranges(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
     prediction_255 = (prediction * 255).type(torch.uint8)
     target_255 = (target * 255).type(torch.uint8)
     loss = GMSDLoss()
-    measure = loss(prediction, target)
+    measure = loss(prediction.to(device), target.to(device))
     loss_255 = GMSDLoss(data_range=255)
-    measure_255 = loss_255(prediction_255, target_255)
+    measure_255 = loss_255(prediction_255.to(device), target_255.to(device))
     diff = torch.abs(measure_255 - measure)
     assert diff <= 1e-4, f'Result for same tensor with different data_range should be the same, got {diff}'
 
 
-def test_gmsd_supports_greyscale_tensors():
+def test_gmsd_loss_supports_greyscale_tensors(device: str) -> None:
     loss = GMSDLoss()
-    target = torch.ones(2, 1, 128, 128)
-    prediction = torch.zeros(2, 1, 128, 128)
-    loss(prediction, target)
+    target = torch.ones(2, 1, 96, 96)
+    prediction = torch.zeros(2, 1, 96, 96)
+    loss(prediction.to(device), target.to(device))
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
-def test_gmsd_supports_greyscale_tensors_on_gpu():
-    loss = GMSDLoss()
-    target = torch.ones(2, 1, 128, 128).cuda()
-    prediction = torch.zeros(2, 1, 128, 128).cuda()
-    loss(prediction, target)
+def test_gmsd_loss_modes(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
+    for reduction in ['mean', 'sum', 'none']:
+        GMSDLoss(reduction=reduction)(prediction.to(device), target.to(device))
 
+    for reduction in ['DEADBEEF', 'random']:
+        with pytest.raises(KeyError):
+            GMSDLoss(reduction=reduction)(prediction.to(device), target.to(device))
 
-# ================== Test class: `MultiScaleGMSDLoss` ==================
-def test_multi_scale_gmsd_loss(prediction: torch.Tensor, target: torch.Tensor) -> None:
-    loss = MultiScaleGMSDLoss(chromatic=True)
-    loss(prediction, target)
 
+# ================== Test function: `multi_scale_gmsd` ==================
+def test_multi_scale_gmsd_forward_backward(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
+    multi_scale_gmsd(prediction.to(device), target.to(device), chromatic=True)
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
-def test_multi_scale_gmsd_loss_on_gpu(prediction: torch.Tensor, target: torch.Tensor) -> None:
-    prediction = prediction.cuda()
-    target = target.cuda()
-    loss = MultiScaleGMSDLoss(chromatic=True)
-    loss(prediction, target)
 
+def test_multi_scale_gmsd_zero_for_equal_tensors(prediction: torch.Tensor, device: str) -> None:
+    target = prediction.clone()
+    measure = multi_scale_gmsd(prediction.to(device), target.to(device))
+    assert measure.abs() <= 1e-6, f'MultiScaleGMSD for equal tensors must be 0, got {measure}'
+
+
+def test_multi_scale_gmsd_supports_different_data_ranges(prediction: torch.Tensor, target: torch.Tensor,
+                                                         device: str) -> None:
+    prediction_255 = (prediction * 255).type(torch.uint8)
+    target_255 = (target * 255).type(torch.uint8)
+
+    measure = multi_scale_gmsd(prediction.to(device), target.to(device))
+    measure_255 = multi_scale_gmsd(prediction_255.to(device), target_255.to(device), data_range=255)
+    diff = torch.abs(measure_255 - measure)
+    assert diff <= 1e-4, f'Result for same tensor with different data_range should be the same, got {diff}'
 
-def test_multi_scale_gmsd_loss_backward(prediction: torch.Tensor, target: torch.Tensor) -> None:
-    prediction.requires_grad_()
-    loss_value = MultiScaleGMSDLoss(chromatic=True)(prediction, target)
-    loss_value.backward()
-    assert prediction.grad is not None, LEAF_VARIABLE_ERROR_MESSAGE
+def test_multi_scale_gmsd_supports_greyscale_tensors(device: str) -> None:
+    target = torch.ones(2, 1, 96, 96)
+    prediction = torch.zeros(2, 1, 96, 96)
+    multi_scale_gmsd(prediction.to(device), target.to(device))
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
-def test_multi_scale_gmsd_loss_backward_on_gpu(prediction: torch.Tensor, target: torch.Tensor) -> None:
+
+def test_multi_scale_gmsd_fails_for_greyscale_tensors_chromatic_flag(device: str) -> None:
+    target = torch.ones(2, 1, 96, 96)
+    prediction = torch.zeros(2, 1, 96, 96)
+    with pytest.raises(AssertionError):
+        multi_scale_gmsd(prediction.to(device), target.to(device), chromatic=True)
+
+
+def test_multi_scale_gmsd_supports_custom_weights(prediction: torch.Tensor, target: torch.Tensor,
+                                                  device: str) -> None:
+    multi_scale_gmsd(prediction.to(device), target.to(device), scale_weights=[3., 4., 2., 1., 2.])
+    multi_scale_gmsd(prediction.to(device), target.to(device), scale_weights=torch.tensor([3., 4., 2., 1., 2.]))
+
+
+def test_multi_scale_gmsd_raise_exception_for_small_images(device: str) -> None:
+    target = torch.ones(3, 1, 32, 32)
+    prediction = torch.zeros(3, 1, 32, 32)
+    with pytest.raises(ValueError):
+        multi_scale_gmsd(prediction.to(device), target.to(device), scale_weights=[3., 4., 2., 1., 1.])
+
+
+def test_multi_scale_gmsd_modes(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
+    for reduction in ['mean', 'sum', 'none']:
+        multi_scale_gmsd(prediction.to(device), target.to(device), reduction=reduction)
+
+    for reduction in ['DEADBEEF', 'random']:
+        with pytest.raises(KeyError):
+            multi_scale_gmsd(prediction.to(device), target.to(device), reduction=reduction)
+
+
+# ================== Test class: `MultiScaleGMSDLoss` ==================
+def test_multi_scale_gmsd_loss_forward_backward(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
     prediction.requires_grad_()
-    loss_value = MultiScaleGMSDLoss(chromatic=True)(prediction, target)
+    loss_value = MultiScaleGMSDLoss(chromatic=True)(prediction.to(device), target.to(device))
     loss_value.backward()
-    assert prediction.grad is not None, LEAF_VARIABLE_ERROR_MESSAGE
+    assert torch.isfinite(prediction.grad).all(), LEAF_VARIABLE_ERROR_MESSAGE
 
 
-def test_multi_scale_gmsd_zero_for_equal_tensors(prediction: torch.Tensor):
+def test_multi_scale_gmsd_loss_zero_for_equal_tensors(prediction: torch.Tensor, device: str) -> None:
     loss = MultiScaleGMSDLoss()
     target = prediction.clone()
-    measure = loss(prediction, target)
+    measure = loss(prediction.to(device), target.to(device))
     assert measure.abs() <= 1e-6, f'MultiScaleGMSD for equal tensors must be 0, got {measure}'
 
 
-def test_multi_scale_gmsd_loss_supports_different_data_ranges(
-        prediction: torch.Tensor, target: torch.Tensor) -> None:
+def test_multi_scale_gmsd_loss_supports_different_data_ranges(prediction: torch.Tensor, target: torch.Tensor,
+                                                              device: str) -> None:
     prediction_255 = (prediction * 255).type(torch.uint8)
     target_255 = (target * 255).type(torch.uint8)
     loss = MultiScaleGMSDLoss()
-    measure = loss(prediction, target)
+    measure = loss(prediction.to(device), target.to(device))
     loss_255 = MultiScaleGMSDLoss(data_range=255)
-    measure_255 = loss_255(prediction_255, target_255)
+    measure_255 = loss_255(prediction_255.to(device), target_255.to(device))
     diff = torch.abs(measure_255 - measure)
     assert diff <= 1e-4, f'Result for same tensor with different data_range should be the same, got {diff}'
 
 
-def test_multi_scale_gmsd_supports_greyscale_tensors():
-    loss = MultiScaleGMSDLoss()
-    target = torch.ones(2, 1, 128, 128)
-    prediction = torch.zeros(2, 1, 128, 128)
-    loss(prediction, target)
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
-def test_multi_scale_gmsd_supports_greyscale_tensors_on_gpu():
+def test_multi_scale_gmsd_loss_supports_greyscale_tensors(device: str) -> None:
     loss = MultiScaleGMSDLoss()
-    target = torch.ones(2, 1, 128, 128).cuda()
-    prediction = torch.zeros(2, 1, 128, 128).cuda()
-    loss(prediction.cuda(), target.cuda())
+    target = torch.ones(2, 1, 96, 96)
+    prediction = torch.zeros(2, 1, 96, 96)
+    loss(prediction.to(device), target.to(device))
 
 
-def test_multi_scale_gmsd_fails_for_greyscale_tensors_chromatic_flag():
+def test_multi_scale_gmsd_loss_fails_for_greyscale_tensors_chromatic_flag(device: str) -> None:
     loss = MultiScaleGMSDLoss(chromatic=True)
-    target = torch.ones(2, 1, 128, 128)
-    prediction = torch.zeros(2, 1, 128, 128)
+    target = torch.ones(2, 1, 96, 96)
+    prediction = torch.zeros(2, 1, 96, 96)
     with pytest.raises(AssertionError):
-        loss(prediction, target)
-
-
-def test_multi_scale_gmsd_supports_custom_weights(
-        prediction: torch.Tensor, target: torch.Tensor):
-    loss = MultiScaleGMSDLoss(scale_weights=[3., 4., 2., 1., 2.])
-    loss(prediction, target)
+        loss(prediction.to(device), target.to(device))
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
-def test_multi_scale_gmsd_supports_custom_weights_on_gpu(
-        prediction: torch.Tensor, target: torch.Tensor):
+def test_multi_scale_gmsd_loss_supports_custom_weights(prediction: torch.Tensor, target: torch.Tensor,
+                                                       device: str) -> None:
     loss = MultiScaleGMSDLoss(scale_weights=[3., 4., 2., 1., 2.])
-    loss(prediction, target)
+    loss(prediction.to(device), target.to(device))
+    loss = MultiScaleGMSDLoss(scale_weights=torch.tensor([3., 4., 2., 1., 2.]))
+    loss(prediction.to(device), target.to(device))
 
 
-def test_multi_scale_gmsd_raise_exception_for_small_images():
+def test_multi_scale_gmsd_loss_raise_exception_for_small_images(device: str) -> None:
     target = torch.ones(3, 1, 32, 32)
     prediction = torch.zeros(3, 1, 32, 32)
     loss = MultiScaleGMSDLoss(scale_weights=[3., 4., 2., 1., 1.])
     with pytest.raises(ValueError):
-        loss(prediction, target)
+        loss(prediction.to(device), target.to(device))
+
+
+def test_multi_scale_loss_gmsd_modes(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
+    for reduction in ['mean', 'sum', 'none']:
+        MultiScaleGMSDLoss(reduction=reduction)(prediction.to(device), target.to(device))
+
+    for reduction in ['DEADBEEF', 'random']:
+        with pytest.raises(KeyError):
+            MultiScaleGMSDLoss(reduction=reduction)(prediction.to(device), target.to(device))
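A note on the rewritten tests above: they depend on a `device` fixture that is not defined in `tests/test_gmsd.py`, so it presumably lives in a shared `tests/conftest.py`. If that fixture does not exist yet, here is a minimal sketch of what it could look like; the parametrization below is an assumption for illustration, not part of this diff.

```python
# tests/conftest.py (hypothetical sketch, not part of this diff)
import pytest
import torch


@pytest.fixture(params=[
    'cpu',
    pytest.param('cuda', marks=pytest.mark.skipif(
        not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')),
])
def device(request) -> str:
    # Each test that takes a `device: str` argument runs once per available device.
    return request.param
```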