@@ -244,15 +244,26 @@ defmodule Axon.LossScaleTest do
244244
245245 non_finite = Nx . tensor ( [ :infinity , :infinity , :infinity ] )
246246
247- # TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26
248- # is fixed
249- for i <- 0 .. 62 , reduce: state do
247+ for i <- 0 .. 99 , reduce: state do
250248 new_state ->
251249 { _ , % { loss_scale: loss_scale , counter: counter } = new_state } =
252250 adjust_fn . ( non_finite , new_state )
253251
254- expected_new_scale = Nx . max ( 1 , Nx . divide ( init_scale , Nx . pow ( factor , i + 1 ) ) )
252+ # We want to check if init_scale / factor ** (i + 1) is greater than 1.
253+ # If we rely on `i` directly, we run into integer overflow issues.
254+ # Instead, we accumulate the divisor on the reduce.
255+
256+ scale_divisor = 2 ** ( i + 1 )
257+
258+ expected_new_scale =
259+ if scale_divisor >= 2 ** 32 do
260+ Nx . tensor ( 1 )
261+ else
262+ Nx . max ( 1 , Nx . divide ( init_scale , scale_divisor ) )
263+ end
264+
255265 assert_equal ( counter , Nx . tensor ( 0 ) )
266+
256267 assert_all_close ( loss_scale , expected_new_scale )
257268
258269 new_state
@@ -277,15 +288,19 @@ defmodule Axon.LossScaleTest do
277288
278289 non_finite = Nx . tensor ( [ :infinity , :infinity , :infinity ] )
279290
280- # TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26
281- # is fixed
282- for i <- 0 .. 62 , reduce: state do
291+ for i <- 0 .. 99 , reduce: state do
283292 new_state ->
284293 { _ , % { loss_scale: loss_scale , counter: counter } = new_state } =
285294 adjust_fn . ( non_finite , new_state )
286295
296+ scale_divisor = 2 ** ( i + 1 )
297+
287298 expected_new_scale =
288- Nx . max ( min_loss_scale , Nx . divide ( init_scale , Nx . pow ( factor , i + 1 ) ) )
299+ if scale_divisor >= 2 ** 32 do
300+ Nx . tensor ( min_loss_scale )
301+ else
302+ Nx . max ( min_loss_scale , Nx . divide ( init_scale , scale_divisor ) )
303+ end
289304
290305 assert_equal ( counter , Nx . tensor ( 0 ) )
291306 assert_all_close ( loss_scale , expected_new_scale )
0 commit comments