Fix Nx.slice crash on scalar tensor#1712
Fix Nx.slice crash on scalar tensor#1712blasphemetheus wants to merge 6 commits intoelixir-nx:mainfrom
Conversation
bin_slice/7 called hd([]) on empty strides list for rank-0 tensors. Added scalar guard clause that returns data unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
nx/test/nx/scalar_slice_test.exs
Outdated
| result = Nx.slice(t, [], []) | ||
| assert_in_delta Nx.to_number(result), 3.14, 1.0e-10 | ||
| end | ||
| end |
There was a problem hiding this comment.
Instead of a separate module, let's add a function to nx_test.exs. I would also add a doctest that says:
Sliding a one-dimensional tensor is a no-op:
iex> Nx.slice(42, [], [])
WDYT?
There was a problem hiding this comment.
sounds good! I'll go through and apply that idea to the other PRs too, just adding em at the end of nx_test.exs
polvalente
left a comment
There was a problem hiding this comment.
I'm not sure we should support slicing scalars.
Furthermore, if we're adding support for this, we should also check how EXLA and Torchx behave.
I think the correct PR here is to fail on scalar slicing
|
Slicing a scalar could always be a no-op if you pass no dimensions (and raise if you pass any) so there is nothing to be implemented on the other backends because there is no operation. |
|
FWIW: import numpy as np
x = np.array(5)
x[()] # returns 5But the current implementation requires indeed each backend to implement it (but the doctest should test them aleady) |
If you run
Nx.sliceon a scalar tensor [tensor with no dimensions, just a single number wrapped in a tensor struct] (aka rank 0), it throws an errorSlicing a scalar is a valid no-op.
Nx.Shape.slicedoes it fine. But hereBinaryBackend.bin_slice/7callshd(strides)on the empty strides list[], which crashes.The fix is to add a scalar guard clause that matches when all lists are empty and returns the data unchanged.
running scalar_slice_test.exs on main
originally part of the closed fuzz test edge case PR #1707