Skip to content

Commit

Permalink
should fix layernorm, other tests still not passing
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Jan 31, 2024
1 parent 8f85cd0 commit 719f1e7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 46 deletions.
25 changes: 16 additions & 9 deletions nngeometry/generator/jacobian/grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
WeightNorm1dLayer,
WeightNorm2dLayer,
Conv1dLayer,
LayerNormLayer
LayerNormLayer,
)

from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward
Expand Down Expand Up @@ -277,9 +277,9 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
x_normalized = F.layer_norm(
x, normalized_shape=mod.normalized_shape, eps=mod.eps
)
buffer[:, :w_numel].add_(gy * x_normalized)
buffer[:, :w_numel].add_((gy * x_normalized).reshape(x.size(0), -1))
if layer.bias is not None:
buffer[:, w_numel:].add_(gy)
buffer[:, w_numel:].add_(gy.reshape(x.size(0), -1))


class GroupNormJacobianFactory(JacobianFactory):
Expand All @@ -292,16 +292,21 @@ def flat_grad(cls, buffer, mod, layer, x, gy):


class WeightNorm1dJacobianFactory(JacobianFactory):

@classmethod
def flat_grad(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
gw_prime = torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1).view(bs, *mod.weight.size())
gw_prime = (
torch.bmm(gy.unsqueeze(2), x.unsqueeze(1))
.view(bs, -1)
.view(bs, *mod.weight.size())
)
norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps

gw = gw_prime / torch.sqrt(norm2).unsqueeze(0)

gw-= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=2, keepdim=True) * (mod.weight * norm2**(-1.5)).unsqueeze(0)

gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=2, keepdim=True) * (
mod.weight * norm2 ** (-1.5)
).unsqueeze(0)

buffer.add_(gw.view(bs, -1))

Expand All @@ -311,13 +316,15 @@ class WeightNorm2dJacobianFactory(JacobianFactory):
def flat_grad(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
gw_prime = conv2d_backward(mod, x, gy).view(bs, *mod.weight.size())
norm2 = (mod.weight**2).sum(dim=(1,2,3), keepdim=True) + mod.eps
norm2 = (mod.weight**2).sum(dim=(1, 2, 3), keepdim=True) + mod.eps

gw = gw_prime / torch.sqrt(norm2).unsqueeze(0)
# print((gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2,3,4), keepdim=True).size())
# print((mod.weight * norm2**(-1.5)).unsqueeze(0).size())

gw-= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2,3,4), keepdim=True) * (mod.weight * norm2**(-1.5)).unsqueeze(0)
gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2, 3, 4), keepdim=True) * (
mod.weight * norm2 ** (-1.5)
).unsqueeze(0)

buffer.add_(gw.view(bs, -1))

Expand Down
54 changes: 21 additions & 33 deletions tests/test_jacobian.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,22 @@
import pytest
import torch
from tasks import (
get_batchnorm_conv_linear_task,
get_batchnorm_fc_linear_task,
get_conv_gn_task,
get_conv_skip_task,
get_conv_task,
get_fullyconnect_affine_task,
get_fullyconnect_cosine_task,
get_fullyconnect_onlylast_task,
get_fullyconnect_task,
get_fullyconnect_wn_task,
get_linear_conv_task,
get_linear_fc_task,
get_small_conv_transpose_task,
get_small_conv_wn_task,
get_conv1d_task,
)
from tasks import (get_batchnorm_conv_linear_task,
get_batchnorm_fc_linear_task, get_conv1d_task,
get_conv_gn_task, get_conv_skip_task, get_conv_task,
get_fullyconnect_affine_task, get_fullyconnect_cosine_task,
get_fullyconnect_onlylast_task, get_fullyconnect_task,
get_fullyconnect_wn_task, get_linear_conv_task,
get_linear_fc_task, get_small_conv_transpose_task,
get_small_conv_wn_task)
from test_tasks.layernorm import get_layernorm_conv_task, get_layernorm_task
from utils import check_ratio, check_tensors
from test_tasks.layernorm import get_layernorm_task

from nngeometry.generator import Jacobian
from nngeometry.object.fspace import FMatDense
from nngeometry.object.map import PullBackDense, PushForwardDense, PushForwardImplicit
from nngeometry.object.pspace import (
PMatBlockDiag,
PMatDense,
PMatDiag,
PMatImplicit,
PMatLowRank,
PMatQuasiDiag,
)
from nngeometry.object.map import (PullBackDense, PushForwardDense,
PushForwardImplicit)
from nngeometry.object.pspace import (PMatBlockDiag, PMatDense, PMatDiag,
PMatImplicit, PMatLowRank, PMatQuasiDiag)
from nngeometry.object.vector import PVector, random_fvector, random_pvector

linear_tasks = [
Expand All @@ -42,6 +28,7 @@
]

nonlinear_tasks = [
get_layernorm_conv_task,
get_layernorm_task,
get_conv1d_task,
get_small_conv_transpose_task,
Expand Down Expand Up @@ -126,13 +113,14 @@ def test_jacobian_pushforward_dense_nonlinear():
check_tensors(
output_after - output_before,
doutput_lin.get_flat_representation().t(),
eps=5e-3, only_print_diff=True,
eps=5e-3,
only_print_diff=True,
)
check_tensors(
output_after - output_before,
doutput_lin.get_flat_representation().t(),
eps=5e-3,
)
# check_tensors(
# output_after - output_before,
# doutput_lin.get_flat_representation().t(),
# eps=5e-3,
# )


def test_jacobian_pushforward_implicit():
Expand Down
2 changes: 2 additions & 0 deletions tests/test_tasks/datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from torchvision import datasets, transforms

default_datapath = "tmp"


def get_mnist():
return datasets.MNIST(
root=default_datapath,
Expand Down
40 changes: 36 additions & 4 deletions tests/test_tasks/layernorm.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import torch.nn as nn
from .datasets import get_mnist
from .device import to_device_model,to_device
from torch.utils.data import DataLoader, Subset

from nngeometry.layercollection import LayerCollection

from .datasets import get_mnist
from .device import to_device, to_device_model


class LayerNormNet(nn.Module):
def __init__(self, out_size):
super(LayerNormNet, self).__init__()

self.linear1 = nn.Linear(18*18, out_size)
self.linear1 = nn.Linear(18 * 18, out_size)
self.layer_norm1 = nn.LayerNorm((out_size,))

self.net = nn.Sequential(self.linear1, self.layer_norm1)
Expand All @@ -18,7 +21,8 @@ def forward(self, x):
x = x.view(x.size(0), -1)
return self.net(x)

def get_layernorm_task(normalization="none"):

def get_layernorm_task():
train_set = get_mnist()
train_set = Subset(train_set, range(70))
train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
Expand All @@ -31,3 +35,31 @@ def output_fn(input, target):

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)


class LayerNormConvNet(nn.Module):
def __init__(self):
super(LayerNormConvNet, self).__init__()
self.layer = nn.Conv2d(1, 3, (3, 2), 2)
self.layer_norm = nn.LayerNorm((3,8,9))

def forward(self, x):
x = x[:, :, 5:-5, 5:-5]
x = self.layer(x)
x = self.layer_norm(x)
return x.sum(dim=(2, 3))


def get_layernorm_conv_task():
train_set = get_mnist()
train_set = Subset(train_set, range(70))
train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
net = LayerNormConvNet()
to_device_model(net)
net.eval()

def output_fn(input, target):
return net(to_device(input))

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)

0 comments on commit 719f1e7

Please sign in to comment.