Skip to content

Fix window_scatter_max/min crash on f64 tensors#1711

Open
blasphemetheus wants to merge 1 commit intoelixir-nx:mainfrom
blasphemetheus:fork/fix/window-scatter-f64
Open

Fix window_scatter_max/min crash on f64 tensors#1711
blasphemetheus wants to merge 1 commit intoelixir-nx:mainfrom
blasphemetheus:fork/fix/window-scatter-f64

Conversation

@blasphemetheus
Copy link
Contributor

window_scatter_max and window_scatter_min crash with f64 tensors.

  t = Nx.iota({6}, type: :f64)                                                          
  s = Nx.iota({3}, type: :f64)                                                          
  init = Nx.tensor(0.0, type: :f64)                                                     
  Nx.window_scatter_max(t, s, init, {2}, strides: [2], padding: :valid)                 
  # ** (ArgumentError) unexpected size for tensor data, expected 384 bits got: 288 bits

This does not affect f32. The scatter result in select_and_scatter wasn't cast to the output type before to_binary, causing a binary size mismatch for f64 (8-byte) tensors.

So the fix is: value |> Nx.as_type(output_type) |> to_binary()

running window_scatter_f64_test.exs on main

     Nx.WindowScatterF64Test [test/nx/window_scatter_f64_test.exs]                      
       * test window_scatter_max still works with f32 [L#22]                            
       * test window_scatter_max still works with f32 (3.1ms) [L#22]                    
       * test window_scatter_max works with f64 [L#4]                                   
       * test window_scatter_max works with f64 (1.5ms) [L#4]                           
                                                                                        
       1) test window_scatter_max works with f64 (Nx.WindowScatterF64Test)              
          test/nx/window_scatter_f64_test.exs:4                                         
          ** (ArgumentError) unexpected size for tensor data, expected 384 bits got: 288
      bits                                                                              
          code: result = Nx.window_scatter_max(t, s, init, {2}, strides: [2], padding:  
     :valid)                                                                            
          stacktrace:                                                                   
            (nx 0.11.0) lib/nx/binary_backend.ex:118: Nx.BinaryBackend.from_binary/2    
            (nx 0.11.0) lib/nx.ex:7486: Nx.window_scatter_max/5                         
            test/nx/window_scatter_f64_test.exs:8: (test)                               
                                                                                        
       * test window_scatter_min works with f64 [L#13]                                  
       * test window_scatter_min works with f64 (0.08ms) [L#13]                         
                                                                                        
       2) test window_scatter_min works with f64 (Nx.WindowScatterF64Test)              
          test/nx/window_scatter_f64_test.exs:13                                        
          ** (ArgumentError) unexpected size for tensor data, expected 384 bits got: 288
      bits                                                                              
          code: result = Nx.window_scatter_min(t, s, init, {2}, strides: [2], padding:  
     :valid)                                                                            
          stacktrace:                                                                   
            (nx 0.11.0) lib/nx/binary_backend.ex:118: Nx.BinaryBackend.from_binary/2    
            (nx 0.11.0) lib/nx.ex:7648: Nx.window_scatter_min/5                         
            test/nx/window_scatter_f64_test.exs:17: (test)                              
                                                                                        
                                                                                        
     Finished in 0.01 seconds (0.01s async, 0.00s sync)                                 
     3 tests, 2 failures 

originally part of the closed fuzz test edge case PR #1707

The scatter result was not cast to the output type before to_binary,
causing a binary size mismatch for f64 (8-byte) tensors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@blasphemetheus blasphemetheus force-pushed the fork/fix/window-scatter-f64 branch from b6956c0 to 3cba2e1 Compare March 20, 2026 09:55
Comment on lines +7433 to +7442

It also works with f64 tensors:

iex> t = Nx.iota({6}, type: :f64)
iex> s = Nx.iota({3}, type: :f64)
iex> Nx.window_scatter_max(t, s, 0.0, {2}, strides: [2], padding: :valid)
#Nx.Tensor<
f64[6]
[0.0, 0.0, 0.0, 1.0, 0.0, 2.0]
>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a regular test if we're only fixing a binary backend bug

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just saw we have the tests, so let's remove this one

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants