Can't use gaussian_blur if sigma is a tensor on GPU #8401
Comments
Hi @pmeier, is this a good first issue? Would it be suitable for a beginner?
Hi @Xact-sniper, I think a possible fix is to pass a `torch.device` to the function call in question. Could you please send a reproducible code snippet? @pmeier @NicolasHug, any suggestions on this?
🐛 Describe the bug
This is admittedly an unconventional use, but I'm using `gaussian_blur` in my model to blur attention maps, and I want sigma to be a parameter.
It would work, except for this function:
`vision/torchvision/transforms/_functional_tensor.py`, line 725 (at commit 06ad737)
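The `x` tensor there is created by `torch.linspace` without a `device` argument, so it lands on the default device (normally the CPU) and is never moved to the device that `sigma` is on.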
I believe it is like this in all torchvision versions.
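A torch-only reduction of the problem may serve as the reproducible snippet requested in the comments. This is a sketch, not torchvision's source: `kernel1d_as_reported` is a hypothetical helper that mirrors the `torch.linspace` line quoted in the proposed fix, and the Gaussian/normalization steps after it are assumed. The mismatch only surfaces when a CUDA device is available; on a CPU-only machine the call succeeds.

```python
import torch


def kernel1d_as_reported(kernel_size: int, sigma) -> torch.Tensor:
    """Hypothetical mirror of the reported code path (not torchvision's source)."""
    ksize_half = (kernel_size - 1) * 0.5
    # No device= argument: x is allocated on the default device (normally CPU).
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    # Assumed Gaussian + normalization steps; this is where x meets sigma.
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


device = "cuda" if torch.cuda.is_available() else "cpu"
# A 1-element tensor stands in for a learnable sigma living on the model's device.
sigma = torch.tensor([2.0], device=device, requires_grad=True)

try:
    kernel = kernel1d_as_reported(5, sigma)
    print(f"works on {device}: {kernel.tolist()}")
except RuntimeError as err:
    # On CUDA, x (CPU) and sigma (GPU) are on different devices.
    print(f"fails on {device}: {err}")
```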
WORKS:
DOES NOT:
I don't know the convention here (for example, whether `device` should be passed in), but I believe the simplest fix would be to change:
`x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)`
to:
`x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)`
Actually, that won't work when `sigma` is just a float, so there would need to be a check for whether it is a float or a float tensor.
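The float-or-tensor check suggested above could look roughly like the sketch below. It is plain PyTorch rather than a patch to torchvision: `gaussian_kernel1d` and `separable_gaussian_blur` are hypothetical names, and the blur (reflect padding plus depthwise `conv2d`, rows then columns) is an assumed stand-in for `gaussian_blur`, not torchvision's implementation. Passing `device=` straight to `torch.linspace`, as suggested in the comments, avoids allocating on the CPU and then copying.

```python
from typing import Union

import torch
import torch.nn.functional as F
from torch import Tensor


def gaussian_kernel1d(kernel_size: int, sigma: Union[float, Tensor]) -> Tensor:
    """1-D Gaussian kernel whose sample grid follows sigma's device when sigma is a tensor."""
    ksize_half = (kernel_size - 1) * 0.5
    # Float sigma: use the default device. Tensor sigma: build x where sigma lives.
    device = sigma.device if isinstance(sigma, Tensor) else None
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


def separable_gaussian_blur(img: Tensor, kernel_size: int, sigma: Union[float, Tensor]) -> Tensor:
    """Blur an (N, C, H, W) tensor; gradients flow back to sigma if it requires grad."""
    if kernel_size <= 0 or kernel_size % 2 == 0:
        raise ValueError("kernel_size must be a positive odd integer")
    channels = img.shape[1]
    k = gaussian_kernel1d(kernel_size, sigma).to(device=img.device, dtype=img.dtype)
    pad = kernel_size // 2
    # Reflect-pad so the output keeps the input's spatial size.
    out = F.pad(img, (pad, pad, pad, pad), mode="reflect")
    # Depthwise convolution: one copy of the kernel per channel.
    k_h = k.view(1, 1, 1, kernel_size).expand(channels, 1, 1, kernel_size)
    k_v = k.view(1, 1, kernel_size, 1).expand(channels, 1, kernel_size, 1)
    out = F.conv2d(out, k_h, groups=channels)  # blur along width
    return F.conv2d(out, k_v, groups=channels)  # blur along height


# Usage: a learnable sigma on whatever device the attention maps live on.
device = "cuda" if torch.cuda.is_available() else "cpu"
attn = torch.rand(2, 4, 16, 16, device=device)
sigma = torch.nn.Parameter(torch.tensor(1.5, device=device))
blurred = separable_gaussian_blur(attn, kernel_size=5, sigma=sigma)
blurred.pow(2).sum().backward()
print(blurred.shape, sigma.grad)
```

Because `sigma` only enters through `torch.exp`, it stays differentiable, and the same function accepts a plain float when no learning is needed.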
Versions
[pip3] efficientunet-pytorch==0.0.6
[pip3] ema-pytorch==0.4.5
[pip3] flake8==6.0.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.3
[pip3] numpydoc==1.4.0
[pip3] pytorch-msssim==1.0.0
[pip3] siren-pytorch==0.1.7
[pip3] torch==2.2.2+cu118
[pip3] torch-cluster==1.6.0+pt113cu116
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.0+pt113cu116
[pip3] torch-sparse==0.6.16+pt113cu116
[pip3] torch-spline-conv==1.2.1+pt113cu116
[pip3] torch-tools==0.1.5
[pip3] torchaudio==2.2.2+cu118
[pip3] torchbearer==0.5.3
[pip3] torchmeta==1.8.0
[pip3] torchvision==0.17.2+cu118
[pip3] uformer-pytorch==0.0.8
[pip3] vit-pytorch==1.5.0
[conda] blas 1.0 mkl
[conda] efficientunet-pytorch 0.0.6 pypi_0 pypi
[conda] ema-pytorch 0.4.5 pypi_0 pypi
[conda] mkl 2021.4.0 haa95532_640
[conda] mkl-service 2.4.0 py39h2bbff1b_0
[conda] mkl_fft 1.3.1 py39h277e83a_0
[conda] mkl_random 1.2.2 py39hf11a4ad_0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] numpydoc 1.4.0 py39haa95532_0
[conda] pytorch-cuda 11.6 h867d48c_1 pytorch
[conda] pytorch-msssim 1.0.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] siren-pytorch 0.1.7 pypi_0 pypi
[conda] torch 1.13.0 pypi_0 pypi
[conda] torch-cluster 1.6.0+pt113cu116 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-scatter 2.1.0+pt113cu116 pypi_0 pypi
[conda] torch-sparse 0.6.16+pt113cu116 pypi_0 pypi
[conda] torch-spline-conv 1.2.1+pt113cu116 pypi_0 pypi
[conda] torch-tools 0.1.5 pypi_0 pypi
[conda] torchaudio 0.9.1 pypi_0 pypi
[conda] torchbearer 0.5.3 pypi_0 pypi
[conda] torchmeta 1.8.0 pypi_0 pypi
[conda] torchvision 0.17.2+cu118 pypi_0 pypi
[conda] uformer-pytorch 0.0.8 pypi_0 pypi
[conda] vit-pytorch 1.5.0 pypi_0 pypi