Skip to content

Commit

Permalink
bug(tv): reduction was handled to return proper values (#66)
Browse files Browse the repository at this point in the history
* bug(tv): reduction was handled to return proper values

* ref(ssim): refactoring to eliminate deprecated dependamcies.
  • Loading branch information
denproc authored May 28, 2020
1 parent b04f151 commit f9c7203
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 40 deletions.
31 changes: 0 additions & 31 deletions photosynthesis_metrics/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn._reduction as _Reduction
import torch.nn.functional as f
from torch.nn.modules.loss import _Loss

Expand Down Expand Up @@ -96,15 +95,6 @@ class SSIMLoss(_Loss):
kernel_sigma: Standard deviation for Gaussian kernel.
k1: Coefficient related to c1 in the above equation.
k2: Coefficient related to c2 in the above equation.
size_average: Deprecated (see :attr:`reduction`). By default,
the losses are averaged over each loss element in the batch. Note that for
some losses, there are multiple elements per sample. If the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch. Ignored
when reduce is ``False``. Default: ``True``
reduce: Deprecated (see :attr:`reduction`). By default, the
losses are averaged or summed over observations for each minibatch depending
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
batch element instead and ignores :attr:`size_average`. Default: ``True``
reduction: Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
Expand Down Expand Up @@ -138,16 +128,10 @@ class SSIMLoss(_Loss):
__constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']

def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
size_average: Optional[bool] = None, reduce: Optional[bool] = None,
reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None:
super().__init__()

# Generic loss parameters.
self.size_average = size_average
self.reduce = reduce
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)

self.reduction = reduction

# Loss-specific parameters.
Expand Down Expand Up @@ -306,15 +290,6 @@ class MultiScaleSSIMLoss(_Loss):
scale_weights: Weights for different scales.
If None, default weights from the paper [1] will be used.
Default weights: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
size_average: Deprecated (see :attr:`reduction`). By default,
the losses are averaged over each loss element in the batch. Note that for
some losses, there are multiple elements per sample. If the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch. Ignored
when reduce is ``False``. Default: ``True``
reduce: Deprecated (see :attr:`reduction`). By default, the
losses are averaged or summed over observations for each minibatch depending
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
batch element instead and ignores :attr:`size_average`. Default: ``True``
reduction: Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
Expand Down Expand Up @@ -355,16 +330,10 @@ class MultiScaleSSIMLoss(_Loss):

def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
scale_weights: Optional[Union[Tuple[float], List[float], torch.Tensor]] = None,
size_average: Optional[bool] = None, reduce: Optional[bool] = None,
reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None:
super().__init__()

# Generic loss parameters.
self.size_average = size_average
self.reduce = reduce
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)

self.reduction = reduction

# Loss-specific parameters.
Expand Down
13 changes: 7 additions & 6 deletions photosynthesis_metrics/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _adjust_tensor_dimensions(x: torch.Tensor):
def _validate_input(x: torch.Tensor) -> None:
"""Validates input tensor"""
assert isinstance(x, torch.Tensor), f'Input must be a torch.Tensor, got {type(x)}.'
assert 1 < x.dim() < 5, f'Input image must be 4D tensor, got image of shape {x.size()}.'
assert 1 < x.dim() < 5, f'Input image must be 2D, 3D or 4D tensor, got image of shape {x.size()}.'


def total_variation(x: torch.Tensor, size_average: bool = True, reduction_type: str = 'l2') -> torch.Tensor:
Expand Down Expand Up @@ -105,10 +105,9 @@ class TVLoss(_Loss):
https://remi.flamary.com/demos/proxtv.html
"""

def __init__(self, size_average: bool = True, reduction_type: str = 'l2', reduction: str = 'mean'):
def __init__(self, reduction_type: str = 'l2', reduction: str = 'mean'):
super().__init__()

self.size_average = size_average
self.reduction_type = reduction_type
self.reduction = reduction

Expand All @@ -128,12 +127,14 @@ def forward(self, prediction: torch.Tensor) -> torch.Tensor:
def compute_metric(self, prediction: torch.Tensor) -> torch.Tensor:
score = total_variation(
prediction,
size_average=self.size_average,
size_average=False,
reduction_type=self.reduction_type
)

if self.reduction == 'mean':
score = torch.mean(score)
score = torch.mean(score, dim=0)
elif self.reduction == 'sum':
score = torch.sum(score)
score = torch.sum(score, dim=0)
elif self.reduction != 'none':
raise ValueError(f'Expected "none"|"mean"|"sum" reduction, got {self.reduction}')
return score
7 changes: 4 additions & 3 deletions photosynthesis_metrics/vif.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from typing import Union

from photosynthesis_metrics.utils import _adjust_dimensions, _validate_input

Expand All @@ -30,7 +31,7 @@ def _gaussian_kernel2d(kernel_size: int = 5, sigma: float = 2.0) -> torch.Tensor


def vif_p(prediction: torch.Tensor, target: torch.Tensor,
sigma_n_sq: float = 2.0, data_range: int = 1.0) -> torch.Tensor:
sigma_n_sq: float = 2.0, data_range: Union[int, float] = 1.0) -> torch.Tensor:
r"""Compute Visiual Information Fidelity in **pixel** domain for a batch of images.
This metric isn't symmetric, so make sure to place arguments in correct order.
Expand Down Expand Up @@ -107,7 +108,7 @@ class VIFLoss(_Loss):
value `1 - clip(VIF, min=0, max=1)` is returned.
"""

def __init__(self, sigma_n_sq: float = 2.0, data_range=1.0):
def __init__(self, sigma_n_sq: float = 2.0, data_range: Union[int, float] = 1.0):
r"""
Args:
sigma_n_sq: HVS model parameter (variance of the visual noise).
Expand All @@ -134,5 +135,5 @@ def compute_metric(self, prediction: torch.Tensor, target: torch.Tensor) -> torc

score = vif_p(prediction, target, sigma_n_sq=self.sigma_n_sq, data_range=self.data_range)
# Make sure value to be in [0, 1] range and convert to loss
loss = 1 - torch.clamp(torch.mean(score), 0, 1)
loss = 1 - torch.clamp(torch.mean(score, dim=0), 0, 1)
return loss

0 comments on commit f9c7203

Please sign in to comment.