🔥 Drop utils.tensor_norm in favor of torch.linalg.norm
francois-rozet committed Mar 11, 2021
1 parent 4dcb651 commit f744007
Showing 7 changed files with 18 additions and 137 deletions.
2 changes: 1 addition & 1 deletion piqa/__init__.py
@@ -5,7 +5,7 @@
 specific image quality assessment metric.
 """

-__version__ = '1.1.2'
+__version__ = '1.1.3'

 from .tv import TV
 from .psnr import PSNR
5 changes: 2 additions & 3 deletions piqa/gmsd.py
@@ -25,7 +25,6 @@
     prewitt_kernel,
     gradient_kernel,
     channel_conv,
-    tensor_norm,
 )


@@ -77,8 +76,8 @@ def gmsd(
     # Gradient magnitude
     pad = kernel.size(-1) // 2

-    gm_x = tensor_norm(channel_conv(x, kernel, padding=pad), dim=[1])
-    gm_y = tensor_norm(channel_conv(y, kernel, padding=pad), dim=[1])
+    gm_x = torch.linalg.norm(channel_conv(x, kernel, padding=pad), dim=1)
+    gm_y = torch.linalg.norm(channel_conv(y, kernel, padding=pad), dim=1)

     gm_xy = gm_x * gm_y
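As a quick sanity check on this substitution, here is a minimal sketch (shapes are illustrative, not taken from the commit) showing that torch.linalg.norm with an integer dim reproduces the L2 behaviour of the removed tensor_norm(..., dim=[1]):

import torch

# Stand-in for a batch of 2-channel gradient responses, (N, 2, H, W).
t = torch.randn(4, 2, 16, 16)

# Removed helper, 'L2' mode: square, sum over dim 1, square root.
old = (t ** 2).sum(dim=1).sqrt()

# Replacement: vector 2-norm along dim 1 (ord defaults to 2).
new = torch.linalg.norm(t, dim=1)

assert torch.allclose(old, new)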
5 changes: 2 additions & 3 deletions piqa/lpips.py
@@ -19,7 +19,6 @@
 import torch.hub as hub

 from piqa.utils import _jit, _assert_type, _reduce
-from piqa.utils.functional import normalize_tensor

 from typing import Dict, List

@@ -225,8 +224,8 @@ def forward(
         residuals = []

         for lin, fx, fy in zip(self.lins, self.net(input), self.net(target)):
-            fx = normalize_tensor(fx, dim=[1], norm='L2')
-            fy = normalize_tensor(fy, dim=[1], norm='L2')
+            fx = fx / torch.linalg.norm(fx, dim=1, keepdim=True)
+            fy = fy / torch.linalg.norm(fy, dim=1, keepdim=True)

             mse = ((fx - fy) ** 2).mean(dim=(-1, -2), keepdim=True)
             residuals.append(lin(mse).flatten())
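For the LPIPS change, a minimal sketch of the unit normalization along the channel dimension (shapes illustrative; note that the removed normalize_tensor also added an epsilon of 1e-8 to the denominator, worth keeping if a feature vector can be all zeros):

import torch

# Stand-in feature map, (N, C, H, W).
fx = torch.randn(2, 64, 8, 8)

# Unit-normalize the C-vector at each spatial location.
fx_hat = fx / torch.linalg.norm(fx, dim=1, keepdim=True)

# Every channel vector now has unit L2 norm.
assert torch.allclose(torch.linalg.norm(fx_hat, dim=1), torch.ones(2, 8, 8))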
9 changes: 4 additions & 5 deletions piqa/mdsi.py
@@ -19,7 +19,6 @@
     prewitt_kernel,
     gradient_kernel,
     channel_conv,
-    tensor_norm,
 )

import piqa.utils.complex as cx
@@ -77,11 +76,11 @@ def mdsi(
     # Gradient magnitude
     pad = kernel.size(-1) // 2

-    gm_x = tensor_norm(channel_conv(l_x, kernel, padding=pad), dim=[1])
-    gm_y = tensor_norm(channel_conv(l_y, kernel, padding=pad), dim=[1])
-    gm_avg = tensor_norm(
+    gm_x = torch.linalg.norm(channel_conv(l_x, kernel, padding=pad), dim=1)
+    gm_y = torch.linalg.norm(channel_conv(l_y, kernel, padding=pad), dim=1)
+    gm_avg = torch.linalg.norm(
         channel_conv((l_x + l_y) / 2., kernel, padding=pad),
-        dim=[1],
+        dim=1,
     )

     gm_x_sq, gm_y_sq, gm_avg_sq = gm_x ** 2, gm_y ** 2, gm_avg ** 2
4 changes: 3 additions & 1 deletion piqa/utils/__init__.py
@@ -85,10 +85,12 @@ def _assert_type(
     )


 @_jit
 def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
-    r"""Returns a reducing module.
+    r"""Returns the reduction of \(x\).
+
     Args:
+        x: A tensor, \((*,)\).
         reduction: Specifies the reduction type:
             `'none'` | `'mean'` | `'sum'`.
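The body of _reduce is not visible in this hunk; the following is a minimal reconstruction consistent with its docstring, not necessarily the repository's exact code:

import torch

def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
    # 'mean' and 'sum' collapse all dimensions; 'none' is the identity.
    if reduction == 'mean':
        return x.mean()
    elif reduction == 'sum':
        return x.sum()
    return x

assert _reduce(torch.ones(2, 3), 'sum').item() == 6.0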
2 changes: 1 addition & 1 deletion piqa/utils/complex.py
@@ -70,7 +70,7 @@ def mod(x: torch.Tensor, squared: bool = False) -> torch.Tensor:
         tensor([2.0000, 1.0000])
     """

-    x = (x ** 2).sum(dim=-1)
+    x = x.square().sum(dim=-1)

     if not squared:
         x = torch.sqrt(x)
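For context, piqa.utils.complex represents complex tensors with the real and imaginary parts stacked in the last dimension, so mod is the L2 norm of that axis. A small sketch (the doctest input is truncated above, so these values are my own):

import torch

# Two complex numbers as (re, im) pairs: 3 + 4i and 0 + 1i.
x = torch.tensor([[3., 4.], [0., 1.]])

mod = x.square().sum(dim=-1).sqrt()
print(mod)  # tensor([5., 1.])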
128 changes: 5 additions & 123 deletions piqa/utils/functional.py
@@ -82,10 +82,13 @@ def gaussian_kernel(
 ) -> torch.Tensor:
     r"""Returns the 1-dimensional Gaussian kernel of size \(K\).

-    $$ G(x) = \frac{1}{\sum_{y = 1}^{K} G(y)} \exp
+    $$ G(x) = \gamma \exp
         \left( -\frac{(x - \mu)^2}{2 \sigma^2} \right) $$

-    where \(x \in [1; K]\) is a position in the kernel
+    where \(\gamma\) is such that
+
+    $$ \sum_{x = 1}^{K} G(x) = 1 $$
+
     and \(\mu = \frac{1 + K}{2}\).

     Args:
@@ -263,124 +266,3 @@ def gradient_kernel(kernel: torch.Tensor) -> torch.Tensor:
     """

     return torch.stack([kernel, kernel.t()]).unsqueeze(1)
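To make the revised docstring concrete, a short sketch (size and sigma are illustrative) that builds the normalized 1-D Gaussian kernel and mimics gradient_kernel's stacking of a 2-D kernel with its transpose:

import torch

K, sigma = 5, 1.5
mu = (1 + K) / 2

# G(x) for x in [1, K], scaled so the kernel sums to 1.
x = torch.arange(1, K + 1, dtype=torch.float)
g = torch.exp(-(x - mu) ** 2 / (2 * sigma ** 2))
g = g / g.sum()
assert torch.isclose(g.sum(), torch.tensor(1.))

# gradient_kernel: stack a 2-D kernel with its transpose, (2, 1, H, W).
k = torch.rand(3, 3)
grad = torch.stack([k, k.t()]).unsqueeze(1)
assert grad.shape == (2, 1, 3, 3)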


-def tensor_norm(
-    x: torch.Tensor,
-    dim: List[int],  # Union[int, Tuple[int, ...]] = ()
-    keepdim: bool = False,
-    norm: str = 'L2',
-) -> torch.Tensor:
-    r"""Returns the norm of \(x\).
-
-    $$ L_1(x) = \left\| x \right\|_1 = \sum_i \left| x_i \right| $$
-
-    $$ L_2(x) = \left\| x \right\|_2 = \left( \sum_i x^2_i \right)^\frac{1}{2} $$
-
-    Args:
-        x: A tensor, \((*,)\).
-        dim: The dimension(s) along which to calculate the norm.
-        keepdim: Whether the output tensor has `dim` retained or not.
-        norm: Specifies the norm function to apply:
-            `'L1'` | `'L2'` | `'L2_squared'`.
-
-    Wikipedia:
-        https://en.wikipedia.org/wiki/Norm_(mathematics)
-
-    Example:
-        >>> x = torch.arange(9).float().view(3, 3)
-        >>> x
-        tensor([[0., 1., 2.],
-                [3., 4., 5.],
-                [6., 7., 8.]])
-        >>> tensor_norm(x, dim=0)
-        tensor([6.7082, 8.1240, 9.6437])
-    """
-
-    if norm == 'L1':
-        x = x.abs()
-    else:  # norm in ['L2', 'L2_squared']
-        x = x ** 2
-
-    x = x.sum(dim=dim, keepdim=keepdim)
-
-    if norm == 'L2':
-        x = x.sqrt()
-
-    return x
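For downstream users of the removed helper, each norm mode maps onto a torch built-in; the correspondence below is inferred from the deleted body above:

import torch

x = torch.arange(9).float().view(3, 3)

l1 = x.abs().sum(dim=0)           # 'L1'
l2 = torch.linalg.norm(x, dim=0)  # 'L2'
l2_sq = x.square().sum(dim=0)     # 'L2_squared'

print(l2)  # tensor([6.7082, 8.1240, 9.6437]), matching the doctest above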


-def normalize_tensor(
-    x: torch.Tensor,
-    dim: List[int],  # Union[int, Tuple[int, ...]] = ()
-    norm: str = 'L2',
-    epsilon: float = 1e-8,
-) -> torch.Tensor:
-    r"""Returns \(x\) normalized.
-
-    $$ \hat{x} = \frac{x}{\left\|x\right\|} $$
-
-    Args:
-        x: A tensor, \((*,)\).
-        dim: The dimension(s) along which to normalize.
-        norm: Specifies the norm function to use:
-            `'L1'` | `'L2'` | `'L2_squared'`.
-        epsilon: A numerical stability term.
-
-    Returns:
-        The normalized tensor, \((*,)\).
-
-    Example:
-        >>> x = torch.arange(9, dtype=torch.float).view(3, 3)
-        >>> x
-        tensor([[0., 1., 2.],
-                [3., 4., 5.],
-                [6., 7., 8.]])
-        >>> normalize_tensor(x, dim=0)
-        tensor([[0.0000, 0.1231, 0.2074],
-                [0.4472, 0.4924, 0.5185],
-                [0.8944, 0.8616, 0.8296]])
-    """
-
-    norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm)
-
-    return x / (norm + epsilon)


-def unravel_index(
-    indices: torch.LongTensor,
-    shape: List[int],
-) -> torch.LongTensor:
-    r"""Converts flat indices into unraveled coordinates in a target shape.
-
-    This is a `torch` implementation of `numpy.unravel_index`.
-
-    Args:
-        indices: A tensor of (flat) indices, \((*, N)\).
-        shape: The targeted shape, \((D,)\).
-
-    Returns:
-        The unraveled coordinates, \((*, N, D)\).
-
-    Example:
-        >>> unravel_index(torch.arange(9), shape=(3, 3))
-        tensor([[0, 0],
-                [0, 1],
-                [0, 2],
-                [1, 0],
-                [1, 1],
-                [1, 2],
-                [2, 0],
-                [2, 1],
-                [2, 2]])
-    """
-
-    coord = []
-
-    for dim in reversed(shape):
-        coord.append(indices % dim)
-        indices = indices // dim
-
-    coord = torch.stack(coord[::-1], dim=-1)
-
-    return coord
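A quick cross-check of the removed unravel_index against its NumPy namesake (a sketch assuming NumPy is available; the loop mirrors the deleted body):

import numpy as np
import torch

def unravel_index(indices, shape):
    # Peel off the fastest-varying dimension first, then restore order.
    coord = []
    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim
    return torch.stack(coord[::-1], dim=-1)

ours = unravel_index(torch.arange(12), shape=(3, 4))
ref = np.stack(np.unravel_index(np.arange(12), (3, 4)), axis=-1)
assert (ours.numpy() == ref).all()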
