
Commit

layer norm work in progress
tfjgeorge committed Jan 29, 2024
1 parent 5c61d0f commit 8f85cd0
Showing 6 changed files with 134 additions and 4 deletions.
38 changes: 35 additions & 3 deletions nngeometry/generator/jacobian/grads.py
@@ -13,6 +13,7 @@
    WeightNorm1dLayer,
    WeightNorm2dLayer,
    Conv1dLayer,
    LayerNormLayer,
)

from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward
@@ -269,6 +270,18 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
        buffer[:, w_numel:].add_(gy.sum(dim=(2, 3)))


class LayerNormJacobianFactory(JacobianFactory):
    @classmethod
    def flat_grad(cls, buffer, mod, layer, x, gy):
        w_numel = layer.weight.numel()
        # normalized input (before the affine transform), recomputed from x
        x_normalized = F.layer_norm(
            x, normalized_shape=mod.normalized_shape, eps=mod.eps
        )
        # per-sample gradients: d/dweight = gy * x_normalized, d/dbias = gy
        buffer[:, :w_numel].add_(gy * x_normalized)
        if layer.bias is not None:
            buffer[:, w_numel:].add_(gy)
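
A quick standalone sanity check for the factory above (illustrative only, not part of this commit): for nn.LayerNorm, the per-sample gradient of the loss with respect to the weight is gy * x_hat and with respect to the bias is gy, which is exactly what flat_grad accumulates. Comparing against autograd, sample by sample:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
bs, d = 4, 6
ln = nn.LayerNorm(d)
x, gy = torch.randn(bs, d), torch.randn(bs, d)

out = ln(x)
# normalized input before the affine transform, as in the factory above
x_hat = F.layer_norm(x, ln.normalized_shape, eps=ln.eps)

for i in range(bs):
    # per-sample gradients from autograd, one sample at a time
    gw, gb = torch.autograd.grad(
        out[i], (ln.weight, ln.bias), grad_outputs=gy[i], retain_graph=True
    )
    assert torch.allclose(gw, gy[i] * x_hat[i], atol=1e-5)
    assert torch.allclose(gb, gy[i], atol=1e-5)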


class GroupNormJacobianFactory(JacobianFactory):
    @classmethod
    def flat_grad(cls, buffer, mod, layer, x, gy):
@@ -279,19 +292,37 @@ def flat_grad(cls, buffer, mod, layer, x, gy):


class WeightNorm1dJacobianFactory(JacobianFactory):
    @classmethod
    def flat_grad(cls, buffer, mod, layer, x, gy):
        bs = x.size(0)
        # per-sample gradient w.r.t. the effective weight w = v / sqrt(||v||^2 + eps)
        gw_prime = torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, *mod.weight.size())
        norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps
        # previous closed-form computation, removed by this commit:
        #   gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2), x.unsqueeze(1))
        #   wn2_out = F.linear(x, mod.weight / norm2**1.5)
        #   gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0)

        # chain rule through the normalization, expressed in terms of gw_prime
        gw = gw_prime / torch.sqrt(norm2).unsqueeze(0)
        gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=2, keepdim=True) * (
            mod.weight * norm2 ** (-1.5)
        ).unsqueeze(0)

        buffer.add_(gw.view(bs, -1))
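
For reference, a sketch of the chain rule that both WeightNorm factories implement, read off the code above (the commit itself does not spell it out): the effective weight of output unit i is assumed to be v_i / sqrt(||v_i||^2 + eps), and gw_prime holds the per-sample gradient with respect to that effective weight.

% Sketch inferred from the code; n_i = ||v_i||^2 + eps, with mod.eps inside the
% square root.
\[
w_i = \frac{v_i}{\sqrt{n_i}}, \qquad
\frac{\partial \mathcal{L}}{\partial v_i}
  = \frac{g'_i}{\sqrt{n_i}}
  - \frac{\langle g'_i,\, v_i\rangle}{n_i^{3/2}}\, v_i,
\]
% where g'_i denotes row i of gw_prime.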


class WeightNorm2dJacobianFactory(JacobianFactory):
    @classmethod
    def flat_grad(cls, buffer, mod, layer, x, gy):
        bs = x.size(0)
        # per-sample gradient w.r.t. the effective convolution weight
        gw_prime = conv2d_backward(mod, x, gy).view(bs, *mod.weight.size())
        norm2 = (mod.weight**2).sum(dim=(1, 2, 3), keepdim=True) + mod.eps

        # same chain rule as the 1d case, with the norm taken per output channel
        gw = gw_prime / torch.sqrt(norm2).unsqueeze(0)
        gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2, 3, 4), keepdim=True) * (
            mod.weight * norm2 ** (-1.5)
        ).unsqueeze(0)

        buffer.add_(gw.view(bs, -1))

    @classmethod
    def flat_grad_(cls, buffer, mod, layer, x, gy):
        bs = x.size(0)
        out_dim = mod.weight.size(0)
        norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps
@@ -426,4 +457,5 @@ def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g):
    WeightNorm2dLayer: WeightNorm2dJacobianFactory,
    Cosine1dLayer: Cosine1dJacobianFactory,
    Affine1dLayer: Affine1dJacobianFactory,
    LayerNormLayer: LayerNormJacobianFactory,
}
26 changes: 26 additions & 0 deletions nngeometry/layercollection.py
@@ -25,6 +25,7 @@ class LayerCollection:
        "Affine1d",
        "ConvTranspose2d",
        "Conv1d",
        "LayerNorm",
    ]

    def __init__(self, layers=None):
@@ -146,6 +147,10 @@ def _module_to_layer(mod):
            return Affine1dLayer(
                num_features=mod.num_features, bias=(mod.bias is not None)
            )
        elif mod_class == "LayerNorm":
            return LayerNormLayer(
                normalized_shape=mod.normalized_shape, bias=(mod.bias is not None)
            )

    def numel(self):
        """
@@ -313,6 +318,24 @@ def __eq__(self, other):
        return self.num_features == other.num_features


class LayerNormLayer(AbstractLayer):
    def __init__(self, normalized_shape, bias=True):
        self.weight = Parameter(*normalized_shape)
        if bias:
            self.bias = Parameter(*normalized_shape)
        else:
            self.bias = None

    def numel(self):
        if self.bias is not None:
            return self.weight.numel() + self.bias.numel()
        else:
            return self.weight.numel()

    def __eq__(self, other):
        return self.weight == other.weight and self.bias == other.bias
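
A hypothetical usage sketch (not part of this commit) of how the new layer is expected to surface once _module_to_layer maps nn.LayerNorm modules; the parameter count relies on the numel method above.

import torch.nn as nn
from nngeometry.layercollection import LayerCollection

model = nn.Sequential(nn.Linear(10, 5), nn.LayerNorm((5,)))
lc = LayerCollection.from_model(model)
# Linear contributes 10 * 5 + 5 = 55 parameters, LayerNorm 5 (weight) + 5 (bias),
# so the collection should count 65 parameters in total.
print(lc.numel())  # expected: 65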


class GroupNormLayer(AbstractLayer):
    def __init__(self, num_groups, num_channels):
        self.num_channels = num_channels
@@ -406,3 +429,6 @@ def __init__(self, *size):

    def numel(self):
        return reduce(operator.mul, self.size, 1)

    def __eq__(self, other):
        return self.size == other.size

10 changes: 9 additions & 1 deletion tests/test_jacobian.py
@@ -18,6 +18,7 @@
    get_conv1d_task,
)
from utils import check_ratio, check_tensors
from test_tasks.layernorm import get_layernorm_task

from nngeometry.generator import Jacobian
from nngeometry.object.fspace import FMatDense
@@ -41,6 +42,7 @@
]

nonlinear_tasks = [
    get_layernorm_task,
    get_conv1d_task,
    get_small_conv_transpose_task,
    get_conv_task,
@@ -104,6 +106,7 @@ def test_jacobian_pushforward_dense_linear():

def test_jacobian_pushforward_dense_nonlinear():
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(
            layer_collection=lc, model=model, function=function, n_output=n_output
@@ -123,8 +126,13 @@ def test_jacobian_pushforward_dense_nonlinear():
        check_tensors(
            output_after - output_before,
            doutput_lin.get_flat_representation().t(),
            eps=5e-3,
            # WIP: only_print_diff reports the deviation instead of asserting it
            only_print_diff=True,
        )


def test_jacobian_pushforward_implicit():
10 changes: 10 additions & 0 deletions tests/test_tasks/datasets.py
@@ -0,0 +1,10 @@
from torchvision import datasets, transforms
default_datapath = "tmp"

def get_mnist():
    return datasets.MNIST(
        root=default_datapath,
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
21 changes: 21 additions & 0 deletions tests/test_tasks/device.py
@@ -0,0 +1,21 @@
import torch

if torch.cuda.is_available():
    device = "cuda"

    def to_device(tensor):
        return tensor.to(device)

    def to_device_model(model):
        model.to("cuda")

else:
    device = "cpu"

    # on cpu we need to use double as otherwise ill-conditioning in sums
    # causes numerical instability
    def to_device(tensor):
        return tensor.double()

    def to_device_model(model):
        model.double()
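
Illustrative usage of these helpers (import path assumed to mirror the one used in tests/test_jacobian.py, with the tests directory on the path): the point of the branching is that a model and its inputs end up in matching precision, float64 on CPU and float32 on CUDA.

import torch
import torch.nn as nn

from test_tasks.device import to_device, to_device_model

m = nn.Linear(3, 2)
to_device_model(m)
x = to_device(torch.randn(5, 3))
print(m(x).dtype)  # torch.float64 on CPU, torch.float32 on CUDA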
33 changes: 33 additions & 0 deletions tests/test_tasks/layernorm.py
@@ -0,0 +1,33 @@
import torch.nn as nn
from .datasets import get_mnist
from .device import to_device_model, to_device
from torch.utils.data import DataLoader, Subset
from nngeometry.layercollection import LayerCollection

class LayerNormNet(nn.Module):
    def __init__(self, out_size):
        super(LayerNormNet, self).__init__()

        self.linear1 = nn.Linear(18 * 18, out_size)
        self.layer_norm1 = nn.LayerNorm((out_size,))

        self.net = nn.Sequential(self.linear1, self.layer_norm1)

    def forward(self, x):
        # crop MNIST digits to 18x18 and flatten before the linear layer
        x = x[:, :, 5:-5, 5:-5].contiguous()
        x = x.view(x.size(0), -1)
        return self.net(x)

def get_layernorm_task(normalization="none"):
    train_set = get_mnist()
    train_set = Subset(train_set, range(70))
    train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
    net = LayerNormNet(out_size=3)
    to_device_model(net)
    net.eval()

    def output_fn(input, target):
        return net(to_device(input))

    layer_collection = LayerCollection.from_model(net)
    return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)
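
A hypothetical quick check of how this task is consumed, mirroring the construction already used in tests/test_jacobian.py:

from nngeometry.generator import Jacobian
from test_tasks.layernorm import get_layernorm_task

loader, lc, parameters, model, function, n_output = get_layernorm_task()
generator = Jacobian(
    layer_collection=lc, model=model, function=function, n_output=n_output
)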
