Skip to content

Conversation

ysiraichi
Copy link
Collaborator

This PR refactors the roll operation implementation by improving its error message, and returning a status type value.

Key Changes:

  • Make tensor_methods::roll return StatusOr<XLATensorPtr>
  • Improve error messages and error handling
    • Create CheckRollShiftsRequired and CheckRollDimsAndShiftsAreCompatible functions

Example 1: empty shifts argument

a = torch.arange(8, device=device).view(2, 2, 2)
shifts = []

Before:

Traceback (most recent call last):
  File "examples/roll.py", line 24, in <module>
    torch.roll(a, shifts)
RuntimeError: Check failed: 0 < shifts.size() (0 vs. 0)`shifts` required (at torch_xla/csrc/tensor_methods.cpp:3057)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/roll.py", line 24, in <module>
    torch.roll(a, shifts)
RuntimeError: roll(): expected `shifts` to have at least 1 element.

Status Propagation Trace:
    From: roll at torch_xla/csrc/tensor_methods.cpp:3086
    From: roll at torch_xla/csrc/aten_xla_type.cpp:3322

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

Example 2: multiple shifts arguments with empty dims

a = torch.arange(8, device=device).view(2, 2, 2)
shifts = [1, 1]

Before:

tensor([[[7, 0],
         [1, 2]],

        [[3, 4],
         [5, 6]]], device='xla:0')

After:

Traceback (most recent call last):
  File "examples/roll.py", line 24, in <module>
    torch.roll(a, shifts)
RuntimeError: roll(): expected `shifts` [1, 1] (size=2) to have exactly 1 element when `dims` is empty.

Status Propagation Trace:
    From: roll at torch_xla/csrc/tensor_methods.cpp:3087
    From: roll at torch_xla/csrc/aten_xla_type.cpp:3322

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

Example 3:

a = torch.arange(8, device=device).view(2, 2, 2)
shifts = [1, 1]
dims = [0, 1, 2]

Before:

Traceback (most recent call last):
  File "examples/roll.py", line 24, in <module>
    torch.roll(a, shifts, dims)
RuntimeError: Check failed: shifts.size() == dims.size() (2 vs. 3)shifts and dimensions must align. shifts: 2, dims:3 (at torch_xla/csrc/tensor_methods.cpp:3059)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/roll.py", line 24, in <module>
    torch.roll(a, shifts, dims)
RuntimeError: roll(): expected `dims` [0, 1, 2] (size=3) to match the size of `shifts` [1, 1] (size=2).

Status Propagation Trace:
    From: roll at torch_xla/csrc/tensor_methods.cpp:3087
    From: roll at torch_xla/csrc/aten_xla_type.cpp:3322

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

@ysiraichi ysiraichi changed the title `roll: improve error handling and error messages. roll: improve error handling and error messages. Sep 6, 2025
@ysiraichi ysiraichi force-pushed the ysiraichi/better-error-roll branch 3 times, most recently from 22f2c49 to e5b4032 Compare September 12, 2025 19:04
@ysiraichi

This comment was marked as outdated.

@ysiraichi ysiraichi force-pushed the ysiraichi/better-error-roll branch from e445a05 to 28f9502 Compare September 15, 2025 12:26
@ysiraichi ysiraichi merged commit a66cfc3 into master Sep 16, 2025
41 of 42 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants