@@ -1207,24 +1207,29 @@ defmodule Axon.Layers do
1207
1207
defn group_norm ( input , gamma , beta , opts \\ [ ] ) do
1208
1208
opts = keyword! ( opts , [ :num_groups , epsilon: 1.0e-5 , channel_index: - 1 , mode: :inference ] )
1209
1209
1210
+ channel_axis = normalize_group_norm_channel_axis ( input , opts [ :channel_index ] )
1211
+
1210
1212
group_shape = Axon.Shape . group_norm_shape ( input , opts [ :num_groups ] , opts [ :channel_index ] )
1211
1213
num_channels = Nx . axis_size ( input , opts [ :channel_index ] )
1212
1214
1213
1215
parameter_shape = norm_parameter_reshape ( input , num_channels , opts [ :channel_index ] )
1214
-
1215
1216
gamma = Nx . reshape ( gamma , parameter_shape )
1216
1217
beta = Nx . reshape ( beta , parameter_shape )
1217
1218
1218
1219
x = Nx . reshape ( input , group_shape )
1219
1220
1220
- axes = Axon.Shape . group_norm_axes ( x , opts [ :channel_index ] )
1221
+ axes = Axon.Shape . group_norm_axes ( x , channel_axis )
1221
1222
1222
1223
{ mean , var } = mean_and_variance ( x , axes: axes )
1223
1224
x = ( x - mean ) * Nx . rsqrt ( var + opts [ :epsilon ] )
1224
1225
x = Nx . reshape ( x , input )
1225
1226
x * gamma + beta
1226
1227
end
1227
1228
1229
+ deftransformp normalize_group_norm_channel_axis ( input , channel_index ) do
1230
+ Nx.Shape . normalize_axis ( Nx . shape ( input ) , channel_index , Nx . shape ( input ) )
1231
+ end
1232
+
1228
1233
@ doc ~S"""
1229
1234
Functional implementation of instance normalization.
1230
1235
0 commit comments