Skip to content

Fix Nx.gather error message for scalar indices#1713

Merged
polvalente merged 1 commit intoelixir-nx:mainfrom
blasphemetheus:fork/fix/gather-scalar-error
Mar 20, 2026
Merged

Fix Nx.gather error message for scalar indices#1713
polvalente merged 1 commit intoelixir-nx:mainfrom
blasphemetheus:fork/fix/gather-scalar-error

Conversation

@blasphemetheus
Copy link
Contributor

Nx.gather(tensor, scalar) gives an unhelpful Erlang error because indexed_axes tries to access elem({}, -1) before the shape validation in Nx.Shape.gather can fire. Moved the scalar check earlier to return a more helpful Nx error.

So if you try:

Nx.gather(Nx.iota({3}), Nx.tensor(0))

You'll get a 1st argument out of range error (Argument Error)

There is validation in Nx.Shape.gather (line 1641-1643 in nx/lib/nx/shape.ex). It just isn't reached because indexed_axes(tensor, indices, opts) runs first, which does elem(indices.shape, tuple_size(indices.shape) -1) -> ends up with elem({}, -1). So Erlang crashes before there's a graceful error.

The fix is to move the same check the gather would do, just earlier, before indexed_axes.

Running gather_scalar_error_test.exs on main

     Nx.GatherScalarErrorTest                                  
     [test/nx/gather_scalar_error_test.exs]                    
       * test gather with valid indices still works [L#10]     
       * test gather with valid indices still works (3.4ms)    
     [L#10]                                                    
       * test gather raises correct error on scalar indices    
     [L#4]                                                     
       * test gather raises correct error on scalar indices    
     (4.5ms) [L#4]                                             
                                                               
       1) test gather raises correct error on scalar indices   
     (Nx.GatherScalarErrorTest)                                
          test/nx/gather_scalar_error_test.exs:4               
          Wrong message for ArgumentError                      
          expected:                                            
            ~r/expected indices rank to be at least 1/         
          actual:                                              
            "errors were found at the given arguments:\n\n  *  
     1st argument: out of range\n"                             
          code: assert_raise ArgumentError, ~r/expected indices
      rank to be at least 1/, fn ->                            
          stacktrace:                                          
            test/nx/gather_scalar_error_test.exs:5: (test)     
                                                     

originally part of the closed fuzz test edge case PR #1707

Nx.gather(tensor, scalar) gave an unhelpful Erlang error because
indexed_axes tried to access elem({}, -1) before the shape validation
in Nx.Shape.gather could fire. Moved the scalar check earlier.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@blasphemetheus blasphemetheus force-pushed the fork/fix/gather-scalar-error branch from b60d75a to 1d319b4 Compare March 20, 2026 10:11
@polvalente polvalente merged commit 4c741d3 into elixir-nx:main Mar 20, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants