Skip to content

Fix Nx.linspace crash with n=1, NaN as well#1710

Open
blasphemetheus wants to merge 1 commit intoelixir-nx:mainfrom
blasphemetheus:fork/fix/linspace-n1
Open

Fix Nx.linspace crash with n=1, NaN as well#1710
blasphemetheus wants to merge 1 commit intoelixir-nx:mainfrom
blasphemetheus:fork/fix/linspace-n1

Conversation

@blasphemetheus
Copy link
Contributor

@blasphemetheus blasphemetheus commented Mar 20, 2026

linspace generates evenly-spaced numbers between a start and stop value. Linear Spacing.

  Nx.linspace(0, 10, n: 5)                                                               
  #=> [0.0, 2.5, 5.0, 7.5, 10.0] 

When n=1 and endpoint=true (default), the divisor is n-1=0, causing a divide-by-zero. The fix is to special-case n=1 to return start value directly. There's also a divide 0 by 0 possible which produces NaN currently.

I'm not sure precisely why one might want to call this function with n=1, but I think it should return sensible results
In the case of the current ArithmeticError: (for any positive int j)

linspace(0, j, n: 1)
# > [0]

in case of the current NaN: (for any positive int k)

linspace(k, k, n: 1)
# > [k]

running LinspaceN1Space test file on main

     Nx.LinspaceN1Test [test/nx/linspace_n1_test.exs]                                    
       * test linspace n=1 returns start value [L#4]                                     
       * test linspace n=1 returns start value (4.6ms) [L#4]                             
                                                                                         
       1) test linspace n=1 returns start value (Nx.LinspaceN1Test)                      
          test/nx/linspace_n1_test.exs:4                                                 
          ** (ArithmeticError) bad argument in arithmetic expression                     
          code: result = Nx.linspace(0, 10, n: 1)                                        
          stacktrace:                                                                    
            (complex 0.6.0) lib/complex.ex:579: Complex.divide/2                         
            (nx 0.11.0) lib/nx/binary_backend.ex:2461:                                   
     Nx.BinaryBackend."-binary_to_binary/4-lbc$^7/2-13-"/5                               
            (nx 0.11.0) lib/nx/binary_backend.ex:681:                                    
     Nx.BinaryBackend.element_wise_bin_op/4                                              
            (nx 0.11.0) lib/nx.ex:16846: Nx.linspace/3                                   
            test/nx/linspace_n1_test.exs:5: (test)                                       
                                                                                         
       * test linspace n=1 with same start/stop [L#10]                                   
       * test linspace n=1 with same start/stop (1.3ms) [L#10]                           
                                                                                         
       2) test linspace n=1 with same start/stop (Nx.LinspaceN1Test)                     
          test/nx/linspace_n1_test.exs:10                                                
          Assertion with == failed                                                       
          code:  assert Nx.to_flat_list(result) == [5.0]                                 
          left:  [:nan]                                                                  
          right: [5.0]                                                                   
          stacktrace:                                                                    
            test/nx/linspace_n1_test.exs:12: (test)                                      
                                                                                         
       * test linspace n=2 still works [L#15]                                            
       * test linspace n=2 still works (0.04ms) [L#15]                                   
       * test linspace n=5 still works [L#20]                                            
       * test linspace n=5 still works (0.01ms) [L#20]                                   
                                                                                         
     Finished in 0.01 seconds (0.01s async, 0.00s sync)                                  
     4 tests, 2 failures 

These are edge cases, but the current result in presumably unintended behavior.

originally part of the closed fuzz test edge case PR #1707

result = Nx.linspace(0, 1, n: 5)
assert Nx.shape(result) == {5}
end
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the other, let's add those tests directly to test/nx_test.exs. There may be already an existing describe block we could use!

When n=1 and endpoint=true (default), the divisor is n-1=0, causing
a divide-by-zero. Special-case n=1 to return start value directly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
iota
|> multiply(step)
|> add(start)
|> as_type(opts[:type])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we extract the else into a defp?

Comment on lines +3076 to +3083
result = Nx.linspace(0, 10, n: 1)
assert Nx.shape(result) == {1}
assert Nx.to_flat_list(result) == [0.0]
end

test "n=1 with same start/stop" do
result = Nx.linspace(5, 5, n: 1)
assert Nx.to_flat_list(result) == [5.0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use assert_equal here

Copy link
Contributor

@polvalente polvalente left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few stylistic comments!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants