Skip to content

Commit

Permalink
🐛 Fix inaccurate LPIPS
Browse files Browse the repository at this point in the history
The inaccuracies came from the dropout layers.
Initializing LPIPS in evaluation mode corrected this behavior.
  • Loading branch information
francois-rozet committed Dec 10, 2020
1 parent 3227b26 commit 43ae9d7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion piqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
specific image quality assessement metric.
"""

__version__ = '1.0.6'
__version__ = '1.0.7'
8 changes: 6 additions & 2 deletions piqa/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class LPIPS(nn.Module):
be scaled w.r.t. ImageNet.
dropout: Whether dropout is used or not.
pretrained: Whether the official pretrained weights are used or not.
eval: Whether to initialize the object in evaluation mode or not.
reduction: Specifies the reduction to apply to the output:
`'none'` | `'mean'` | `'sum'`.
Expand All @@ -84,8 +85,7 @@ class LPIPS(nn.Module):
* Output: (N,) or (1,) depending on `reduction`
Note:
`LPIPS` is a *trainable* metric. To prevent the weights from updating,
use the `torch.no_grad()` context or freeze the weights.
`LPIPS` is a *trainable* metric.
Example:
>>> criterion = LPIPS().cuda()
Expand All @@ -102,6 +102,7 @@ def __init__(
scaling: bool = True,
dropout: bool = True,
pretrained: bool = True,
eval: bool = True,
reduction: str = 'mean',
):
r""""""
Expand Down Expand Up @@ -143,6 +144,9 @@ def __init__(
if pretrained:
self.lin.load_state_dict(get_weights(network=network))

if eval:
self.eval()

self.reduce = build_reduce(reduction)

def forward(
Expand Down

0 comments on commit 43ae9d7

Please sign in to comment.