From 4fac82b3ec8dabaf467a049a586e369c2c7c2dca Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Fri, 17 Feb 2023 12:12:07 -0500
Subject: [PATCH] Fix group normalization (#477)

The reduction axes for group normalization were computed from the raw
`:channel_index` option. `Axon.Shape.group_norm_axes/2` drops the channel
index from `1..rank - 1` with `--`, so a negative index such as the default
`-1` never matched any entry and no axis was excluded from the reduction.

Resolve the channel index to a non-negative axis first (named axes are
looked up through `Nx.names/1`) and pass that axis to `group_norm_axes/2`.

Also adds a regression test checked against PyTorch output and bumps the
version to 0.5.1.
---
 CHANGELOG.md              |   6 ++
 lib/axon/layers.ex        |   9 ++-
 lib/axon/shape.ex         |   6 +-
 mix.exs                   |   2 +-
 test/axon/layers_test.exs | 128 ++++++++++++++++++++++++++++++++++++++
 5 files changed, 145 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7194c554..ce48e218 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,11 @@
 # Changelog
 
+## v0.5.1 (2023-02-17)
+
+### Bug Fixes
+
+* Fixed incorrect results from group normalization
+
 ## v0.5.0 (2022-02-16)
 
 ### Enhancements
diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex
index 54003662..860405a5 100644
--- a/lib/axon/layers.ex
+++ b/lib/axon/layers.ex
@@ -1207,17 +1207,18 @@ defmodule Axon.Layers do
   defn group_norm(input, gamma, beta, opts \\ []) do
     opts = keyword!(opts, [:num_groups, epsilon: 1.0e-5, channel_index: -1, mode: :inference])
 
+    channel_axis = normalize_group_norm_channel_axis(input, opts[:channel_index])
+
     group_shape = Axon.Shape.group_norm_shape(input, opts[:num_groups], opts[:channel_index])
     num_channels = Nx.axis_size(input, opts[:channel_index])
 
     parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index])
-
     gamma = Nx.reshape(gamma, parameter_shape)
     beta = Nx.reshape(beta, parameter_shape)
 
     x = Nx.reshape(input, group_shape)
 
-    axes = Axon.Shape.group_norm_axes(x, opts[:channel_index])
+    axes = Axon.Shape.group_norm_axes(x, channel_axis)
 
     {mean, var} = mean_and_variance(x, axes: axes)
     x = (x - mean) * Nx.rsqrt(var + opts[:epsilon])
@@ -1225,6 +1226,10 @@ defmodule Axon.Layers do
     x * gamma + beta
   end
 
+  deftransformp normalize_group_norm_channel_axis(input, channel_index) do
+    Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input))
+  end
+
   @doc ~S"""
   Functional implementation of instance normalization.
 
diff --git a/lib/axon/shape.ex b/lib/axon/shape.ex
index 7c5860de..34fa86d2 100644
--- a/lib/axon/shape.ex
+++ b/lib/axon/shape.ex
@@ -942,8 +942,8 @@ defmodule Axon.Shape do
   @doc """
   Calculates the reduction axes for group normalization.
   """
-  deftransform group_norm_axes(input, channel_index) do
-    Enum.to_list(1..(Nx.rank(input) - 1)) -- [channel_index]
+  deftransform group_norm_axes(x, channel_index) do
+    Enum.to_list(1..(Nx.rank(x) - 1)) -- [channel_index]
   end
 
   @doc """
@@ -951,7 +951,7 @@ defmodule Axon.Shape do
   """
   deftransform group_norm_shape(input, num_groups, channel_index) do
     shape = Nx.shape(input)
-    channel_index = Nx.Shape.normalize_axis(shape, channel_index, Nx.names(input))
+    channel_index = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input))
 
     channels = elem(shape, channel_index)
     group_size = div(channels, num_groups)
diff --git a/mix.exs b/mix.exs
index cd50ab2e..e91fcc50 100644
--- a/mix.exs
+++ b/mix.exs
@@ -2,7 +2,7 @@ defmodule Axon.MixProject do
   use Mix.Project
 
   @source_url "https://github.com/elixir-nx/axon"
-  @version "0.5.0"
+  @version "0.5.1"
 
   def project do
     [
diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs
index 47df2c95..6c353d6d 100644
--- a/test/axon/layers_test.exs
+++ b/test/axon/layers_test.exs
@@ -1226,4 +1226,132 @@ defmodule Axon.LayersTest do
       )
     end
   end
+
+  describe "group_norm" do
+    test "matches pytorch" do
+      input =
+        Nx.tensor([
+          [
+            0.8423,
+            1.9226,
+            -1.1295,
+            -1.3154,
+            1.2963,
+            -0.6821,
+            -0.0519,
+            0.6875,
+            -0.0313,
+            -0.3328,
+            -0.2821,
+            -2.3289,
+            -1.7641,
+            -1.3184,
+            -0.0890,
+            0.0625
+          ],
+          [
+            -1.0853,
+            0.8060,
+            -0.1397,
+            -0.2169,
+            0.9605,
+            0.3947,
+            0.4760,
+            0.8097,
+            0.0380,
+            -0.6314,
+            0.5761,
+            1.9309,
+            0.5038,
+            -0.1892,
+            1.8476,
+            0.0517
+          ]
+        ])
+
+      gamma =
+        Nx.tensor([
+          -0.3101,
+          -1.5896,
+          -1.4963,
+          0.1278,
+          -1.4580,
+          1.3832,
+          0.5709,
+          0.5531,
+          -0.0588,
+          1.0411,
+          1.3503,
+          -1.2166,
+          0.7133,
+          0.0694,
+          0.3150,
+          -0.1306
+        ])
+
+      beta =
+        Nx.tensor([
+          1.6585,
+          2.3515,
+          -1.3456,
+          0.2376,
+          -0.1333,
+          0.5068,
+          0.2441,
+          1.0382,
+          0.6879,
+          -0.5402,
+          -1.8304,
+          -0.8906,
+          -0.5329,
+          -0.3390,
+          -0.1877,
+          0.1405
+        ])
+
+      expected =
+        Nx.tensor([
+          [
+            1.4768,
+            -0.1375,
+            0.4536,
+            0.0623,
+            -1.5881,
+            -0.5951,
+            0.1157,
+            1.2847,
+            0.6378,
+            -0.0194,
+            -1.0751,
+            1.3407,
+            -1.3700,
+            -0.3844,
+            0.0597,
+            0.0149
+          ],
+          [
+            2.2986,
+            0.9877,
+            -0.4434,
+            0.1453,
+            -1.7321,
+            0.8146,
+            0.4430,
+            1.5159,
+            0.7202,
+            -1.9153,
+            -1.7368,
+            -2.8723,
+            -0.5429,
+            -0.3954,
+            0.2952,
+            0.2103
+          ]
+        ])
+
+      actual = Axon.Layers.group_norm(input, gamma, beta, num_groups: 2)
+
+      assert_all_close(expected, actual, atol: 1.0e-3)
+    end
+  end
 end