From 3db33e33fb3bf55bcf054bfe995e80cbf5e006ef Mon Sep 17 00:00:00 2001 From: Denis Prokopenko <22414094+denproc@users.noreply.github.com> Date: Tue, 28 Jul 2020 09:52:14 +0300 Subject: [PATCH] Bugfix: relax the requirements (#154) * bug(requirements): made requirements less strict, added warning. * bug(requirements): Highlight it in the description and readme. * bug(requirements): Minor. * Release commit --- README.md | 4 +++- conda.recipe/meta.yaml | 8 ++++---- piq/__init__.py | 2 +- piq/brisque.py | 15 +++++++++++++++ requirements.txt | 4 ++-- 5 files changed, 25 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index f357c3ed..781366c8 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,9 @@ prediction = torch.rand(3, 3, 256, 256) brisque_index: torch.Tensor = brisque(prediction, data_range=1.) ``` -In order to use BRISQUE as a loss function, use corresponding PyTorch module: +In order to use BRISQUE as a loss function, use corresponding PyTorch module. + +Note: the back propagation is not available using `torch==1.5.0`. Update the environment with latest `torch` and `torchvision`. ```python import torch from piq import BRISQUELoss diff --git a/conda.recipe/meta.yaml b/conda.recipe/meta.yaml index 3ccc72f9..67c77b1d 100644 --- a/conda.recipe/meta.yaml +++ b/conda.recipe/meta.yaml @@ -10,14 +10,14 @@ source: requirements: build: - python - - pytorch>=1.2.0,!=1.5.0 - - torchvision>=0.4.0,!=0.6.0 + - pytorch>=1.2.0 + - torchvision>=0.4.0 - scipy==1.3.3 - gudhi>=3.2 run: - python - - pytorch>=1.2.0,!=1.5.0 - - torchvision>=0.4.0,!=0.6.0 + - pytorch>=1.2.0 + - torchvision>=0.4.0 - scipy==1.3.3 - gudhi>=3.2 diff --git a/piq/__init__.py b/piq/__init__.py index 24a3eea6..49d549f0 100644 --- a/piq/__init__.py +++ b/piq/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.0" +__version__ = "0.5.1" from .ssim import ssim, multi_scale_ssim, SSIMLoss, MultiScaleSSIMLoss from .msid import MSID diff --git a/piq/brisque.py b/piq/brisque.py index e38d43c1..7ea3a37c 100644 --- a/piq/brisque.py +++ b/piq/brisque.py @@ -8,6 +8,7 @@ https://github.com/bukalapak/pybrisque """ from typing import Union, Tuple +import warnings import torch from torch.nn.modules.loss import _Loss from torch.utils.model_zoo import load_url @@ -33,10 +34,20 @@ def brisque(x: torch.Tensor, Returns: Value of BRISQUE index. + Note: + The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation. + Update the torch and torchvision to the latest versions. + References: .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain", https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf """ + if '1.5.0' in torch.__version__: + warnings.warn(f'BRISQUE does not support back propagation due to bug in torch={torch.__version__}.' + f'Update torch to the latest version to access full functionality of the BRIQSUE.' + f'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and' + f'https://github.com/pytorch/pytorch/issues/38869.') + _validate_input(input_tensors=x, allow_5d=False, kernel_size=kernel_size) x = _adjust_dimensions(input_tensors=x) @@ -91,6 +102,10 @@ class BRISQUELoss(_Loss): >>> output = loss(prediction) >>> output.backward() + Note: + The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation. + Update the torch and torchvision to the latest versions. + References: .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain", https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf diff --git a/requirements.txt b/requirements.txt index b840880d..d715a994 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch>=1.2.0,!=1.5.0 -torchvision>=0.4.0,!=0.6.0 +torch>=1.2.0 +torchvision>=0.4.0 scipy==1.3.3 gudhi>=3.2