Skip to content

Commit 78fb704

Browse files
authored
fix: loss scale assertions (#597)
1 parent 88241ac commit 78fb704

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

test/axon/loss_scale_test.exs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,15 +244,26 @@ defmodule Axon.LossScaleTest do
244244

245245
non_finite = Nx.tensor([:infinity, :infinity, :infinity])
246246

247-
# TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26
248-
# is fixed
249-
for i <- 0..62, reduce: state do
247+
for i <- 0..99, reduce: state do
250248
new_state ->
251249
{_, %{loss_scale: loss_scale, counter: counter} = new_state} =
252250
adjust_fn.(non_finite, new_state)
253251

254-
expected_new_scale = Nx.max(1, Nx.divide(init_scale, Nx.pow(factor, i + 1)))
252+
# We want to check if init_scale / factor ** (i + 1) is greater than 1.
253+
# If we rely on `i` directly, we run into integer overflow issues.
254+
# Instead, we compute the divisor up front and short-circuit to the lower bound (1) once it reaches 2 ** 32.
255+
256+
scale_divisor = 2 ** (i + 1)
257+
258+
expected_new_scale =
259+
if scale_divisor >= 2 ** 32 do
260+
Nx.tensor(1)
261+
else
262+
Nx.max(1, Nx.divide(init_scale, scale_divisor))
263+
end
264+
255265
assert_equal(counter, Nx.tensor(0))
266+
256267
assert_all_close(loss_scale, expected_new_scale)
257268

258269
new_state
@@ -277,15 +288,19 @@ defmodule Axon.LossScaleTest do
277288

278289
non_finite = Nx.tensor([:infinity, :infinity, :infinity])
279290

280-
# TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26
281-
# is fixed
282-
for i <- 0..62, reduce: state do
291+
for i <- 0..99, reduce: state do
283292
new_state ->
284293
{_, %{loss_scale: loss_scale, counter: counter} = new_state} =
285294
adjust_fn.(non_finite, new_state)
286295

296+
scale_divisor = 2 ** (i + 1)
297+
287298
expected_new_scale =
288-
Nx.max(min_loss_scale, Nx.divide(init_scale, Nx.pow(factor, i + 1)))
299+
if scale_divisor >= 2 ** 32 do
300+
Nx.tensor(min_loss_scale)
301+
else
302+
Nx.max(min_loss_scale, Nx.divide(init_scale, scale_divisor))
303+
end
289304

290305
assert_equal(counter, Nx.tensor(0))
291306
assert_all_close(loss_scale, expected_new_scale)

0 commit comments

Comments
 (0)