Tensor.abs() gives incorrect results on Complex64 when using MPS #125135

Closed
stevenryoung opened this issue Apr 29, 2024 · 6 comments
Labels
module: complex Related to complex number support in PyTorch module: correctness (silent) issue that returns an incorrect result silently module: mps Related to Apple Metal Performance Shaders framework triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@stevenryoung

stevenryoung commented Apr 29, 2024

🐛 Describe the bug

The abs() function gives incorrect results for Complex64 tensors on an MPS device. Judging by the example below, it appears to be some kind of indexing issue.

import torch

with torch.device("cpu"):
    print(torch.ones((2,), dtype=torch.complex64).abs())
    

with torch.device("mps"):
    print(torch.ones((2,), dtype=torch.complex64).abs())
    print(torch.tensor([1.0 + 0.0j, 0.0 + 10.0j, 100.0 + 0.0j, 1000.0 + 0.0j]).abs())

Output:

tensor([1., 1.])
tensor([1., 0.], device='mps:0')
tensor([ 1.,  0., 10.,  0.], device='mps:0')
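A minimal sketch for checking whether a given build is affected: run abs() on the MPS device and compare it with the CPU result. The helper name `mps_abs_matches_cpu` is hypothetical; it only relies on `Tensor.to`, `torch.allclose` and `torch.backends.mps.is_available`.

```python
import torch

def mps_abs_matches_cpu(device: str = "mps") -> bool:
    """Compare complex64 abs() on `device` with the CPU reference (hypothetical helper)."""
    # Same inputs as the report above; complex Python literals default to complex64.
    x = torch.tensor([1.0 + 0.0j, 0.0 + 10.0j, 100.0 + 0.0j, 1000.0 + 0.0j])
    expected = x.abs()                 # CPU result: float32 magnitudes
    actual = x.to(device).abs().cpu()  # result on the device under test, copied back
    return actual.dtype == expected.dtype and torch.allclose(actual, expected)

if torch.backends.mps.is_available():
    print("MPS complex abs correct:", mps_abs_matches_cpu("mps"))
```

Passing `"cpu"` compares the CPU against itself and should always return True, which makes the helper easy to sanity-check on machines without MPS.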

Versions

Collecting environment information...
PyTorch version: 2.4.0.dev20240428
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: version 3.28.3
Libc version: N/A

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0.dev20240428
[pip3] torchaudio==2.2.0.dev20240428
[pip3] torchvision==0.19.0.dev20240428
[conda] numpy 1.26.4 py311h7125741_0 conda-forge
[conda] pytorch 2.4.0.dev20240428 py3.11_0 pytorch-nightly
[conda] torchaudio 2.2.0.dev20240428 py311_cpu pytorch-nightly
[conda] torchvision 0.19.0.dev20240428 py311_cpu pytorch-nightly

cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @amjames @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

@stevenryoung
Author

stevenryoung commented Apr 29, 2024

For any user needing a quick (probably unsafe) fix, the following worked for my use case:

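import torch

def patch_abs(t):
    # Only reroute the affected case: complex64 tensors on the MPS device.
    if t.device.type == "mps" and t.dtype == torch.complex64:
        # |z| = sqrt(re^2 + im^2). The 1e-12 term keeps the gradient finite at z = 0,
        # but zero inputs come back as ~1e-6 instead of 0.
        return torch.sqrt(torch.pow(torch.real(t), 2) + torch.pow(torch.imag(t), 2) + 1e-12)
    return torch.abs(t)  # torch.abs itself is not patched, so this does not recurse

# Monkeypatch the method so existing t.abs() calls pick up the workaround.
torch.Tensor.abs = patch_abs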

with torch.device("cpu"):
    print(torch.ones((2,), dtype=torch.complex64).abs())
    
with torch.device("mps"):
    print(torch.ones((2,), dtype=torch.complex64).abs())
    print(torch.tensor([1.0 + 0.0j, 0.0 + 10.0j, 100.0 + 0.0j, 1000.0 + 0.0j]).abs())

Output:

tensor([1., 1.])
tensor([1., 1.], device='mps:0')
tensor([   1.,   10.,  100., 1000.], device='mps:0')
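A less invasive variant of the workaround, sketched under the assumption that `torch.view_as_real` and basic float ops work on your MPS build: compute the magnitude from the real/imaginary pairs in a standalone helper instead of replacing `torch.Tensor.abs` globally. `complex_abs` is a hypothetical name, and unlike the patch above it adds no epsilon, so the gradient at exactly z = 0 may be NaN.

```python
import torch

def complex_abs(t: torch.Tensor) -> torch.Tensor:
    """Magnitude of a complex tensor computed from its real/imaginary parts (hypothetical helper)."""
    if not t.is_complex():
        return t.abs()
    # view_as_real returns a float view with a trailing dim of size 2: (..., [real, imag]).
    parts = torch.view_as_real(t)
    # sqrt(re^2 + im^2); no epsilon, so the gradient at exactly 0 can be NaN.
    return parts.square().sum(dim=-1).sqrt()

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.tensor([1.0 + 0.0j, 0.0 + 10.0j, 100.0 + 0.0j, 1000.0 + 0.0j], device=device)
print(complex_abs(x))
```

Calling the helper explicitly keeps the rest of the program on the stock `abs()`, so it can simply be dropped once a fixed PyTorch build is installed.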

@gambiTarun

Hi @bdhirsh!
First time contributor here. Could I work on this bug?

@tringwald tringwald added module: complex Related to complex number support in PyTorch module: mps Related to Apple Metal Performance Shaders framework labels Apr 29, 2024
@cpuhrsch cpuhrsch added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 30, 2024
@malfet malfet self-assigned this May 7, 2024
@malfet malfet added the module: correctness (silent) issue that returns an incorrect result silently label May 7, 2024
@malfet
Contributor

malfet commented May 7, 2024

Grabbing this for myself. I wish there were proper documentation for https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/3564540-absolutewithtensor?language=objc

@malfet malfet added this to the 2.3.1 milestone May 7, 2024
@malfet
Contributor

malfet commented May 7, 2024

@gambiTarun sorry, missed your comment. How familiar are you with ObjC/MPS framework?

@gambiTarun

Hi @malfet, no worries! I'm not familiar with ObjC/MPS, but I'm willing to learn. If you could point me in the right direction for resolving this bug, or suggest some resources, I'd appreciate it!

@jhavukainen
Collaborator

Hi @malfet! I wanted to add some context, since I had a brief discussion with some of the MPSGraph people about this issue and that particular op. My current hypothesis for what happens:

  • If the MPSGraph abs op is passed a tensor of data type Complex64, it returns a tensor of the same type. However, if I print the dtype of the output of the CPU abs() (and of the MPS abs, for that matter) using the snippet above, it is float32. So the complex-to-real conversion is probably not handled explicitly in the MPS abs implementation, and we end up interpreting the complex result as a float32 tensor. This would make complex 1+0i look like the float pair [1, 0].
  • If that's the case, we should check for a complex data type in the abs op and explicitly take the real part of the complex output of the graph op absoluteWithTensor. The following API should do the trick:

-(MPSGraphTensor *) realPartOfTensor:(MPSGraphTensor *) tensor name:(NSString * _Nullable) name
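The hypothesis can be illustrated on the CPU. The sketch below is a simulation, not the actual MPS code path, and its function names are hypothetical. It assumes the graph op returns a complex64 tensor holding |z| in the real part and 0 in the imaginary part, and that PyTorch then reads the first n float32 values of that buffer as the result. Under those assumptions it reproduces both MPS outputs from the report, and taking the real part (as `realPartOfTensor:` would) gives the correct magnitudes.

```python
import torch

def _complex_typed_abs(x: torch.Tensor) -> torch.Tensor:
    # Assumed graph behavior: abs keeps the complex dtype, i.e. |z| + 0j.
    return x.abs().to(torch.complex64)

def simulate_buggy_abs(x: torch.Tensor) -> torch.Tensor:
    # Reinterpret the complex64 buffer as float32 pairs (re, im, re, im, ...)
    # and keep only the first n values, as if it were a float32 tensor of length n.
    flat = torch.view_as_real(_complex_typed_abs(x)).flatten()
    return flat[: x.numel()]

def simulate_fixed_abs(x: torch.Tensor) -> torch.Tensor:
    # Taking the real part of the complex-typed result recovers the magnitudes.
    return _complex_typed_abs(x).real

ones = torch.ones((2,), dtype=torch.complex64)
mixed = torch.tensor([1.0 + 0.0j, 0.0 + 10.0j, 100.0 + 0.0j, 1000.0 + 0.0j])
print(simulate_buggy_abs(ones))   # tensor([1., 0.])
print(simulate_buggy_abs(mixed))  # tensor([ 1.,  0., 10.,  0.])
print(simulate_fixed_abs(mixed))  # tensor([   1.,   10.,  100., 1000.])
```

Note that the second buggy output is [1, 0, 10, 0] rather than [1, 0, 0, 10]: the magnitudes were already computed correctly and only the read-back as float32 is wrong, which is consistent with the realPartOfTensor: fix.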

Unfortunately I'm drowning in other work this week, so I'm more than happy to let anyone take a shot at this if they'd like.

pytorchbot pushed a commit that referenced this issue May 13, 2024
Fix `abs` for complex types by calling `realPartOfTensor:` when the input type is complex on Sonoma, falling back to the `at::view_as_real` trick on Ventura.

Split `unary_op` template into `unary_op` and `unary_op_noresize`, which skips resize and empty checks

Marked `abs`, `isclose` and `nn.functional.softsign` OpInfo tests as supported by complex types

Fixes #125135

Pull Request resolved: #125662
Approved by: https://github.com/kulinseth

(cherry picked from commit 0fd1fc1)
huydhn pushed a commit that referenced this issue May 13, 2024
[MPS] Fix `abs` for complex types (#125662)

Co-authored-by: Nikita Shulga <[email protected]>
6 participants