Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions lib/axon/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,7 @@ defmodule Axon.Shared do

# Normalizes `input` using the given `mean` and `variance`, then applies
# the learned scale (`gamma`) and shift (`bias`) parameters:
#
#     gamma * rsqrt(variance + epsilon) * (input - mean) + bias
#
# ## Options
#
#   * `:epsilon` - numerical-stability term added to `variance` before the
#     reciprocal square root. Defaults to `1.0e-6`.
defn normalize(input, mean, variance, gamma, bias, opts \\ []) do
  [epsilon: epsilon] = keyword!(opts, epsilon: 1.0e-6)

  # Intentionally no clipping of small variances: frameworks such as
  # PyTorch compute rsqrt(variance + epsilon) directly, and the previous
  # `Nx.select`-based clamp produced mismatched results whenever
  # variance < epsilon. The old clamped binding was dead code after the
  # rebind below and has been removed.
  scale = gamma * Nx.rsqrt(variance + epsilon)
  scale * (input - mean) + bias
end

Expand Down
23 changes: 23 additions & 0 deletions test/axon/layers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1722,4 +1722,27 @@ defmodule Axon.LayersTest do
assert_all_close(expected, actual, atol: 1.0e-3)
end
end

describe "batch_norm" do
  test "matches pytorch when variance < epsilon" do
    # Regression test: with variance below epsilon, the normalization
    # must still compute rsqrt(variance + epsilon) directly (no clamping),
    # matching PyTorch's batch norm in eval mode.
    epsilon = 0.001

    input = Nx.tensor([[[[-0.002805]]]])
    gamma = Nx.tensor([1.0])
    beta = Nx.tensor([-0.144881])
    running_mean = Nx.tensor([-0.008561])
    running_var = Nx.tensor([0.000412])

    # Reference value produced by PyTorch for the same inputs.
    expected = Nx.tensor([0.0083])

    actual =
      Axon.Layers.batch_norm(input, gamma, beta, running_mean, running_var,
        mode: :inference,
        epsilon: epsilon
      )

    assert_all_close(expected, actual, atol: 1.0e-3)
  end
end
end
Loading