diff --git a/photosynthesis_metrics/ssim.py b/photosynthesis_metrics/ssim.py index d186bf7f..c1fc3966 100644 --- a/photosynthesis_metrics/ssim.py +++ b/photosynthesis_metrics/ssim.py @@ -9,7 +9,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn._reduction as _Reduction import torch.nn.functional as f from torch.nn.modules.loss import _Loss @@ -96,15 +95,6 @@ class SSIMLoss(_Loss): kernel_sigma: Standard deviation for Gaussian kernel. k1: Coefficient related to c1 in the above equation. k2: Coefficient related to c2 in the above equation. - size_average: Deprecated (see :attr:`reduction`). By default, - the losses are averaged over each loss element in the batch. Note that for - some losses, there are multiple elements per sample. If the field :attr:`size_average` - is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` - reduce: Deprecated (see :attr:`reduction`). By default, the - losses are averaged or summed over observations for each minibatch depending - on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per - batch element instead and ignores :attr:`size_average`. Default: ``True`` reduction: Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of @@ -138,16 +128,10 @@ class SSIMLoss(_Loss): __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction'] def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03, - size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None: super().__init__() # Generic loss parameters. - self.size_average = size_average - self.reduce = reduce - if size_average is not None or reduce is not None: - reduction = _Reduction.legacy_get_string(size_average, reduce) - self.reduction = reduction # Loss-specific parameters. @@ -306,15 +290,6 @@ class MultiScaleSSIMLoss(_Loss): scale_weights: Weights for different scales. If None, default weights from the paper [1] will be used. Default weights: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). - size_average: Deprecated (see :attr:`reduction`). By default, - the losses are averaged over each loss element in the batch. Note that for - some losses, there are multiple elements per sample. If the field :attr:`size_average` - is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` - reduce: Deprecated (see :attr:`reduction`). By default, the - losses are averaged or summed over observations for each minibatch depending - on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per - batch element instead and ignores :attr:`size_average`. Default: ``True`` reduction: Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of @@ -355,16 +330,10 @@ class MultiScaleSSIMLoss(_Loss): def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03, scale_weights: Optional[Union[Tuple[float], List[float], torch.Tensor]] = None, - size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None: super().__init__() # Generic loss parameters. - self.size_average = size_average - self.reduce = reduce - if size_average is not None or reduce is not None: - reduction = _Reduction.legacy_get_string(size_average, reduce) - self.reduction = reduction # Loss-specific parameters. diff --git a/photosynthesis_metrics/tv.py b/photosynthesis_metrics/tv.py index dd130b65..7ef23ebd 100644 --- a/photosynthesis_metrics/tv.py +++ b/photosynthesis_metrics/tv.py @@ -23,7 +23,7 @@ def _adjust_tensor_dimensions(x: torch.Tensor): def _validate_input(x: torch.Tensor) -> None: """Validates input tensor""" assert isinstance(x, torch.Tensor), f'Input must be a torch.Tensor, got {type(x)}.' - assert 1 < x.dim() < 5, f'Input image must be 4D tensor, got image of shape {x.size()}.' + assert 1 < x.dim() < 5, f'Input image must be 2D, 3D or 4D tensor, got image of shape {x.size()}.' def total_variation(x: torch.Tensor, size_average: bool = True, reduction_type: str = 'l2') -> torch.Tensor: @@ -105,10 +105,9 @@ class TVLoss(_Loss): https://remi.flamary.com/demos/proxtv.html """ - def __init__(self, size_average: bool = True, reduction_type: str = 'l2', reduction: str = 'mean'): + def __init__(self, reduction_type: str = 'l2', reduction: str = 'mean'): super().__init__() - self.size_average = size_average self.reduction_type = reduction_type self.reduction = reduction @@ -128,12 +127,14 @@ def forward(self, prediction: torch.Tensor) -> torch.Tensor: def compute_metric(self, prediction: torch.Tensor) -> torch.Tensor: score = total_variation( prediction, - size_average=self.size_average, + size_average=False, reduction_type=self.reduction_type ) if self.reduction == 'mean': - score = torch.mean(score) + score = torch.mean(score, dim=0) elif self.reduction == 'sum': - score = torch.sum(score) + score = torch.sum(score, dim=0) + elif self.reduction != 'none': + raise ValueError(f'Expected "none"|"mean"|"sum" reduction, got {self.reduction}') return score diff --git a/photosynthesis_metrics/vif.py b/photosynthesis_metrics/vif.py index ca17ef6b..3ec7ce6c 100644 --- a/photosynthesis_metrics/vif.py +++ b/photosynthesis_metrics/vif.py @@ -8,6 +8,7 @@ import torch from torch.nn.modules.loss import _Loss import torch.nn.functional as F +from typing import Union from photosynthesis_metrics.utils import _adjust_dimensions, _validate_input @@ -30,7 +31,7 @@ def _gaussian_kernel2d(kernel_size: int = 5, sigma: float = 2.0) -> torch.Tensor def vif_p(prediction: torch.Tensor, target: torch.Tensor, - sigma_n_sq: float = 2.0, data_range: int = 1.0) -> torch.Tensor: + sigma_n_sq: float = 2.0, data_range: Union[int, float] = 1.0) -> torch.Tensor: r"""Compute Visiual Information Fidelity in **pixel** domain for a batch of images. This metric isn't symmetric, so make sure to place arguments in correct order. @@ -107,7 +108,7 @@ class VIFLoss(_Loss): value `1 - clip(VIF, min=0, max=1)` is returned. """ - def __init__(self, sigma_n_sq: float = 2.0, data_range=1.0): + def __init__(self, sigma_n_sq: float = 2.0, data_range: Union[int, float] = 1.0): r""" Args: sigma_n_sq: HVS model parameter (variance of the visual noise). @@ -134,5 +135,5 @@ def compute_metric(self, prediction: torch.Tensor, target: torch.Tensor) -> torc score = vif_p(prediction, target, sigma_n_sq=self.sigma_n_sq, data_range=self.data_range) # Make sure value to be in [0, 1] range and convert to loss - loss = 1 - torch.clamp(torch.mean(score), 0, 1) + loss = 1 - torch.clamp(torch.mean(score, dim=0), 0, 1) return loss