Skip to content

Commit 4fac82b

Browse files
authored
Fix group normalization (#477)
1 parent 884107c commit 4fac82b

File tree

5 files changed

+145
-6
lines changed

5 files changed

+145
-6
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
## v0.5.1 (2023-02-17)
4+
5+
### Bug Fixes
6+
7+
* Fixed incorrect results from group normalization
8+
39
## v0.5.0 (2023-02-16)
410

511
### Enhancements

lib/axon/layers.ex

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,24 +1207,29 @@ defmodule Axon.Layers do
12071207
defn group_norm(input, gamma, beta, opts \\ []) do
12081208
opts = keyword!(opts, [:num_groups, epsilon: 1.0e-5, channel_index: -1, mode: :inference])
12091209

1210+
channel_axis = normalize_group_norm_channel_axis(input, opts[:channel_index])
1211+
12101212
group_shape = Axon.Shape.group_norm_shape(input, opts[:num_groups], opts[:channel_index])
12111213
num_channels = Nx.axis_size(input, opts[:channel_index])
12121214

12131215
parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index])
1214-
12151216
gamma = Nx.reshape(gamma, parameter_shape)
12161217
beta = Nx.reshape(beta, parameter_shape)
12171218

12181219
x = Nx.reshape(input, group_shape)
12191220

1220-
axes = Axon.Shape.group_norm_axes(x, opts[:channel_index])
1221+
axes = Axon.Shape.group_norm_axes(x, channel_axis)
12211222

12221223
{mean, var} = mean_and_variance(x, axes: axes)
12231224
x = (x - mean) * Nx.rsqrt(var + opts[:epsilon])
12241225
x = Nx.reshape(x, input)
12251226
x * gamma + beta
12261227
end
12271228

1229+
deftransformp normalize_group_norm_channel_axis(input, channel_index) do
1230+
Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.shape(input))
1231+
end
1232+
12281233
@doc ~S"""
12291234
Functional implementation of instance normalization.
12301235

lib/axon/shape.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -942,16 +942,16 @@ defmodule Axon.Shape do
942942
@doc """
943943
Calculates the reduction axes for group normalization.
944944
"""
945-
deftransform group_norm_axes(input, channel_index) do
946-
Enum.to_list(1..(Nx.rank(input) - 1)) -- [channel_index]
945+
deftransform group_norm_axes(x, channel_index) do
946+
Enum.to_list(1..(Nx.rank(x) - 1)) -- [channel_index]
947947
end
948948

949949
@doc """
950950
Calculates the reshape for group normalization.
951951
"""
952952
deftransform group_norm_shape(input, num_groups, channel_index) do
953953
shape = Nx.shape(input)
954-
channel_index = Nx.Shape.normalize_axis(shape, channel_index, Nx.names(input))
954+
channel_index = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input))
955955

956956
channels = elem(shape, channel_index)
957957
group_size = div(channels, num_groups)

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ defmodule Axon.MixProject do
22
use Mix.Project
33

44
@source_url "https://github.com/elixir-nx/axon"
5-
@version "0.5.0"
5+
@version "0.5.1"
66

77
def project do
88
[

test/axon/layers_test.exs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,4 +1226,132 @@ defmodule Axon.LayersTest do
12261226
)
12271227
end
12281228
end
1229+
1230+
describe "group_norm" do
1231+
test "matches pytorch" do
1232+
a =
1233+
Nx.tensor([
1234+
[
1235+
0.8423,
1236+
1.9226,
1237+
-1.1295,
1238+
-1.3154,
1239+
1.2963,
1240+
-0.6821,
1241+
-0.0519,
1242+
0.6875,
1243+
-0.0313,
1244+
-0.3328,
1245+
-0.2821,
1246+
-2.3289,
1247+
-1.7641,
1248+
-1.3184,
1249+
-0.0890,
1250+
0.0625
1251+
],
1252+
[
1253+
-1.0853,
1254+
0.8060,
1255+
-0.1397,
1256+
-0.2169,
1257+
0.9605,
1258+
0.3947,
1259+
0.4760,
1260+
0.8097,
1261+
0.0380,
1262+
-0.6314,
1263+
0.5761,
1264+
1.9309,
1265+
0.5038,
1266+
-0.1892,
1267+
1.8476,
1268+
0.0517
1269+
]
1270+
])
1271+
1272+
b =
1273+
Nx.tensor([
1274+
-0.3101,
1275+
-1.5896,
1276+
-1.4963,
1277+
0.1278,
1278+
-1.4580,
1279+
1.3832,
1280+
0.5709,
1281+
0.5531,
1282+
-0.0588,
1283+
1.0411,
1284+
1.3503,
1285+
-1.2166,
1286+
0.7133,
1287+
0.0694,
1288+
0.3150,
1289+
-0.1306
1290+
])
1291+
1292+
c =
1293+
Nx.tensor([
1294+
1.6585,
1295+
2.3515,
1296+
-1.3456,
1297+
0.2376,
1298+
-0.1333,
1299+
0.5068,
1300+
0.2441,
1301+
1.0382,
1302+
0.6879,
1303+
-0.5402,
1304+
-1.8304,
1305+
-0.8906,
1306+
-0.5329,
1307+
-0.3390,
1308+
-0.1877,
1309+
0.1405
1310+
])
1311+
1312+
expected =
1313+
Nx.tensor([
1314+
[
1315+
1.4768,
1316+
-0.1375,
1317+
0.4536,
1318+
0.0623,
1319+
-1.5881,
1320+
-0.5951,
1321+
0.1157,
1322+
1.2847,
1323+
0.6378,
1324+
-0.0194,
1325+
-1.0751,
1326+
1.3407,
1327+
-1.3700,
1328+
-0.3844,
1329+
0.0597,
1330+
0.0149
1331+
],
1332+
[
1333+
2.2986,
1334+
0.9877,
1335+
-0.4434,
1336+
0.1453,
1337+
-1.7321,
1338+
0.8146,
1339+
0.4430,
1340+
1.5159,
1341+
0.7202,
1342+
-1.9153,
1343+
-1.7368,
1344+
-2.8723,
1345+
-0.5429,
1346+
-0.3954,
1347+
0.2952,
1348+
0.2103
1349+
]
1350+
])
1351+
1352+
actual = Axon.Layers.group_norm(a, b, c, num_groups: 2)
1353+
1354+
assert_all_close(expected, actual, atol: 1.0e-3)
1355+
end
1356+
end
12291357
end

0 commit comments

Comments
 (0)