Skip to content

Commit

Permalink
Fix group normalization (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Feb 17, 2023
1 parent 884107c commit 4fac82b
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 6 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## v0.5.1 (2023-02-17)

### Bug Fixes

* Fixed incorrect results from group normalization

## v0.5.0 (2023-02-16)

### Enhancements
Expand Down
9 changes: 7 additions & 2 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1207,24 +1207,29 @@ defmodule Axon.Layers do
defn group_norm(input, gamma, beta, opts \\ []) do
  opts = keyword!(opts, [:num_groups, epsilon: 1.0e-5, channel_index: -1, mode: :inference])

  # Resolve a possibly-negative channel index (default -1) to a
  # non-negative axis once, so the reduction axes computed below are
  # correct regardless of how the caller specified the channel axis.
  channel_axis = normalize_group_norm_channel_axis(input, opts[:channel_index])

  group_shape = Axon.Shape.group_norm_shape(input, opts[:num_groups], opts[:channel_index])
  num_channels = Nx.axis_size(input, opts[:channel_index])

  # gamma/beta come in flat (per-channel); reshape them so they
  # broadcast against the (ungrouped) input layout.
  parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index])

  gamma = Nx.reshape(gamma, parameter_shape)
  beta = Nx.reshape(beta, parameter_shape)

  # Split the channel axis into {num_groups, group_size} so statistics
  # are computed per group rather than per channel.
  x = Nx.reshape(input, group_shape)

  # NOTE: must use the normalized `channel_axis` here — passing the raw
  # (possibly negative) opts[:channel_index] yields wrong reduction axes.
  axes = Axon.Shape.group_norm_axes(x, channel_axis)

  {mean, var} = mean_and_variance(x, axes: axes)
  x = (x - mean) * Nx.rsqrt(var + opts[:epsilon])
  # Reshape back to the input's shape (Nx.reshape/2 accepts a tensor as
  # a shape template), then apply the learned scale and shift.
  x = Nx.reshape(x, input)
  x * gamma + beta
end

# Resolves `channel_index` (which may be negative, e.g. -1) into a
# non-negative axis for `input`. `Nx.Shape.normalize_axis/3` takes the
# tensor's *names* as its third argument — passing the shape there (as the
# original did) is inconsistent with the call in `Axon.Shape.group_norm_shape/3`.
deftransformp normalize_group_norm_channel_axis(input, channel_index) do
  Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input))
end

@doc ~S"""
Functional implementation of instance normalization.
Expand Down
6 changes: 3 additions & 3 deletions lib/axon/shape.ex
Original file line number Diff line number Diff line change
Expand Up @@ -942,16 +942,16 @@ defmodule Axon.Shape do
@doc """
Calculates the reduction axes for group normalization.

Reduces over every axis of the grouped tensor `x` except the batch
axis (0) and `channel_index` (assumed already normalized to a
non-negative axis by the caller).
"""
deftransform group_norm_axes(x, channel_index) do
  Enum.to_list(1..(Nx.rank(x) - 1)) -- [channel_index]
end

@doc """
Calculates the reshape for group normalization.
"""
deftransform group_norm_shape(input, num_groups, channel_index) do
shape = Nx.shape(input)
channel_index = Nx.Shape.normalize_axis(shape, channel_index, Nx.names(input))
channel_index = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input))

channels = elem(shape, channel_index)
group_size = div(channels, num_groups)
Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defmodule Axon.MixProject do
use Mix.Project

@source_url "https://github.com/elixir-nx/axon"
@version "0.5.0"
@version "0.5.1"

def project do
[
Expand Down
128 changes: 128 additions & 0 deletions test/axon/layers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1226,4 +1226,132 @@ defmodule Axon.LayersTest do
)
end
end

# Regression test for group normalization (issue #477): the expected
# values were generated with PyTorch's nn.GroupNorm as a reference
# implementation — TODO confirm the exact PyTorch invocation used.
describe "group_norm" do
test "matches pytorch" do
# Input activations: 2 rows of 16 channels, normalized in 2 groups of 8.
a =
Nx.tensor([
[
0.8423,
1.9226,
-1.1295,
-1.3154,
1.2963,
-0.6821,
-0.0519,
0.6875,
-0.0313,
-0.3328,
-0.2821,
-2.3289,
-1.7641,
-1.3184,
-0.0890,
0.0625
],
[
-1.0853,
0.8060,
-0.1397,
-0.2169,
0.9605,
0.3947,
0.4760,
0.8097,
0.0380,
-0.6314,
0.5761,
1.9309,
0.5038,
-0.1892,
1.8476,
0.0517
]
])

# Per-channel scale (gamma), 16 entries.
b =
Nx.tensor([
-0.3101,
-1.5896,
-1.4963,
0.1278,
-1.4580,
1.3832,
0.5709,
0.5531,
-0.0588,
1.0411,
1.3503,
-1.2166,
0.7133,
0.0694,
0.3150,
-0.1306
])

# Per-channel shift (beta), 16 entries.
c =
Nx.tensor([
1.6585,
2.3515,
-1.3456,
0.2376,
-0.1333,
0.5068,
0.2441,
1.0382,
0.6879,
-0.5402,
-1.8304,
-0.8906,
-0.5329,
-0.3390,
-0.1877,
0.1405
])

# Reference output from PyTorch for the same inputs and num_groups=2.
expected =
Nx.tensor([
[
1.4768,
-0.1375,
0.4536,
0.0623,
-1.5881,
-0.5951,
0.1157,
1.2847,
0.6378,
-0.0194,
-1.0751,
1.3407,
-1.3700,
-0.3844,
0.0597,
0.0149
],
[
2.2986,
0.9877,
-0.4434,
0.1453,
-1.7321,
0.8146,
0.4430,
1.5159,
0.7202,
-1.9153,
-1.7368,
-2.8723,
-0.5429,
-0.3954,
0.2952,
0.2103
]
])

actual = Axon.Layers.group_norm(a, b, c, num_groups: 2)

# Loose tolerance: reference values are printed to 4 decimal places.
assert_all_close(expected, actual, atol: 1.0e-3)
end
end
end

0 comments on commit 4fac82b

Please sign in to comment.