torch.nan_to_num doesn't work with -inf/inf #8674

Open
Akshat-Tripathi opened this issue Feb 4, 2025 · 1 comment
Labels: bug (Something isn't working), pytorch api (XLA behavior doesn't match PyTorch eager frontend)


🐛 Bug

Hi, I was working on some PyTorch code designed to run on multiple backends, including TPU via XLA. This code calls torch.nan_to_num() with infinite replacement values, which works on PyTorch's CPU and GPU backends but fails on TPU.

To Reproduce

import torch
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()
cpu_device = torch.device("cpu")

neg_inf = float("-inf")
pos_inf = float("inf")

xla_tensor = torch.zeros(3, 3, device=xla_device)
cpu_tensor = torch.zeros(3, 3, device=cpu_device)

try:
    torch.nan_to_num(
        xla_tensor,
        nan=neg_inf,
        posinf=pos_inf,
        neginf=neg_inf,
    )
except RuntimeError as e:
    print(f"Failed: {e}")

torch.nan_to_num(
    cpu_tensor,
    nan=neg_inf,
    posinf=pos_inf,
    neginf=neg_inf
)
print("Passed")

Steps to reproduce the behavior:

  1. Run the above code snippet.

This is the stack trace I get with the xla backend.

[rank0]: RuntimeError: torch_xla/csrc/aten_xla_type.cpp:2346 : Check failed: min_max.min.toDouble() <= replacement.toDouble() && replacement.toDouble() <= min_max.max.toDouble()
[rank0]: *** Begin stack trace ***
[rank0]:        tsl::CurrentStackTrace()
[rank0]:        torch_xla::XLANativeFunctions::nan_to_num(at::Tensor const&, std::optional<double>, std::optional<double>, std::optional<double>)
[rank0]:        at::_ops::nan_to_num::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<double>, std::optional<double>, std::optional<double>)
[rank0]:        torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args const&, pybind11::kwargs const&, std::optional<c10::DispatchKey>)
[rank0]:        torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args const&, pybind11::kwargs const&, bool, std::optional<c10::DispatchKey>)
[rank0]:        ... (repeated Python interpreter frames elided) ...
[rank0]: *** End stack trace ***
[rank0]: Type BFloat16 replacement value -inf must be in the range [-3.40282e+38, 3.40282e+38].
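The check that fails is the range validation in aten_xla_type.cpp: the replacement value must lie within the representable range of the tensor's dtype, and -inf fails the lower-bound comparison even though CPU/CUDA accept it. A minimal pure-Python sketch of that check (the bound 3.40282e+38 is taken from the error message above; the function name `in_dtype_range` is hypothetical, not from the XLA source):

```python
# Representable range reported in the error message for BFloat16.
BF16_MAX = 3.40282e+38
BF16_MIN = -BF16_MAX

def in_dtype_range(replacement: float) -> bool:
    # Mirrors the XLA-side check: min <= replacement <= max.
    # An infinite replacement fails one of the two comparisons.
    return BF16_MIN <= replacement <= BF16_MAX

print(in_dtype_range(float("-inf")))  # False -> the op raises
print(in_dtype_range(BF16_MIN))       # True  -> a finite sentinel passes
```

This is why substituting finite dtype bounds for the infinities (as in the workaround below the stack trace) avoids the error.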

Expected behavior

torch.nan_to_num should behave the same on the XLA backend as on the CPU and CUDA backends, accepting infinite replacement values.

Environment

  • Reproducible on XLA backend: TPU
  • torch_xla version: 2.6.0.dev20241126

Additional context

As a workaround I'm able to use finite dtype min/max values (e.g. via torch.finfo) in place of the infinities, but it's still not ideal.
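That workaround can be expressed as a small helper that substitutes finite dtype bounds for infinite replacement values before calling torch.nan_to_num. A sketch in plain Python (the helper name `finite_replacements` is illustrative; the hard-coded float32 bounds stand in for torch.finfo(tensor.dtype).min/.max when torch is available):

```python
import math

def finite_replacements(nan, posinf, neginf, dtype_min, dtype_max):
    """Clamp infinite replacement values into the dtype's finite range."""
    def clamp(value, fallback):
        return fallback if value is not None and math.isinf(value) else value
    return clamp(nan, dtype_min), clamp(posinf, dtype_max), clamp(neginf, dtype_min)

# float32 bounds, a stand-in for torch.finfo(torch.float32).min / .max.
F32_MAX = 3.4028235e+38
nan, posinf, neginf = finite_replacements(
    float("-inf"), float("inf"), float("-inf"), -F32_MAX, F32_MAX)
print(nan, posinf, neginf)  # -3.4028235e+38 3.4028235e+38 -3.4028235e+38
```

The clamped values then pass the XLA range check while behaving like infinities for any value nan_to_num can actually produce in that dtype.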

ysiraichi (Collaborator) commented Feb 5, 2025

Thank you for reporting this bug. I was able to reproduce it on: 225c65b
I will look into it.

ysiraichi added the bug and pytorch api labels and self-assigned this issue on Feb 5, 2025