Merge pull request #76 from tfjgeorge/conv1d
Conv1d
tfjgeorge authored Nov 25, 2023
2 parents 5658a74 + 418beb9 commit 7e0e88f
Showing 15 changed files with 311 additions and 71 deletions.
104 changes: 97 additions & 7 deletions nngeometry/generator/jacobian/grads.py
@@ -1,13 +1,21 @@
import torch
import torch.nn.functional as F

-from nngeometry.layercollection import (Affine1dLayer, BatchNorm1dLayer,
-                                        BatchNorm2dLayer, Conv2dLayer,
-                                        ConvTranspose2dLayer, Cosine1dLayer,
-                                        GroupNormLayer, LinearLayer,
-                                        WeightNorm1dLayer, WeightNorm2dLayer)
-
-from .grads_conv import conv2d_backward, convtranspose2d_backward
+from nngeometry.layercollection import (
+    Affine1dLayer,
+    BatchNorm1dLayer,
+    BatchNorm2dLayer,
+    Conv2dLayer,
+    ConvTranspose2dLayer,
+    Cosine1dLayer,
+    GroupNormLayer,
+    LinearLayer,
+    WeightNorm1dLayer,
+    WeightNorm2dLayer,
+    Conv1dLayer,
+)
+
+from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward


class JacobianFactory:
@@ -325,8 +333,90 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
buffer[:, w_numel:].add_(gy)


class Conv1dJacobianFactory(JacobianFactory):
@classmethod
def flat_grad(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
w_numel = layer.weight.numel()
indiv_gw = conv1d_backward(mod, x, gy)
buffer[:, :w_numel].add_(indiv_gw.view(bs, -1))
if layer.bias is not None:
buffer[:, w_numel:].add_(gy.sum(dim=2))

@classmethod
def Jv(cls, buffer, mod, layer, x, gy, v, v_bias):
bs = x.size(0)
gy2 = F.conv1d(
x, v, stride=mod.stride, padding=mod.padding, dilation=mod.dilation
)
buffer.add_((gy * gy2).view(bs, -1).sum(dim=1))
if layer.bias is not None:
buffer.add_(torch.mv(gy.sum(dim=2), v_bias))

@classmethod
def kfac_xx(cls, buffer, mod, layer, x, gy):
ks = (1, mod.weight.size(2))
# A_tilda in KFC
A_tilda = F.unfold(
x.unsqueeze(2),
kernel_size=ks,
stride=(1, mod.stride[0]),
padding=(0, mod.padding[0]),
dilation=(1, mod.dilation[0]),
)
# A_tilda is bs * #locations x #parameters
A_tilda = A_tilda.permute(0, 2, 1).contiguous().view(-1, A_tilda.size(1))
if layer.bias is not None:
A_tilda = torch.cat([A_tilda, torch.ones_like(A_tilda[:, :1])], dim=1)
# Omega_hat in KFC
buffer.add_(torch.mm(A_tilda.t(), A_tilda))

@classmethod
def kfac_gg(cls, buffer, mod, layer, x, gy):
spatial_locations = gy.size(2)
os = gy.size(1)
# DS_tilda in KFC
DS_tilda = gy.permute(0, 2, 1).contiguous().view(-1, os)
buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations)

@classmethod
def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g):
ks = (1, mod.weight.size(2))
gy_s = gy.size()
bs = gy_s[0]
# project x to kfe
x_unfold = F.unfold(
x.unsqueeze(2),
kernel_size=ks,
stride=(1, mod.stride[0]),
padding=(0, mod.padding[0]),
dilation=(1, mod.dilation[0]),
)
x_unfold_s = x_unfold.size()
x_unfold = (
x_unfold.view(bs, x_unfold_s[1], -1)
.permute(0, 2, 1)
.contiguous()
.view(-1, x_unfold_s[1])
)
if mod.bias is not None:
x_unfold = torch.cat([x_unfold, torch.ones_like(x_unfold[:, :1])], dim=1)
x_kfe = torch.mm(x_unfold, evecs_a)

# project gy to kfe
gy = gy.view(bs, gy_s[1], -1).permute(0, 2, 1).contiguous()
gy_kfe = torch.mm(gy.view(-1, gy_s[1]), evecs_g)
gy_kfe = gy_kfe.view(bs, -1, gy_s[1]).permute(0, 2, 1).contiguous()

indiv_gw = torch.bmm(
gy_kfe.view(bs, gy_s[1], -1), x_kfe.view(bs, -1, x_kfe.size(1))
)
buffer.add_((indiv_gw**2).sum(dim=0).view(-1))


FactoryMap = {
LinearLayer: LinearJacobianFactory,
Conv1dLayer: Conv1dJacobianFactory,
Conv2dLayer: Conv2dJacobianFactory,
ConvTranspose2dLayer: ConvTranspose2dJacobianFactory,
BatchNorm1dLayer: BatchNorm1dJacobianFactory,
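Throughout the new Conv1d factory above, a 1d convolution is treated as a 2d convolution over a height-1 image: the input is lifted with x.unsqueeze(2) and F.unfold is called with kernel_size=(1, k). A minimal standalone sanity sketch of that equivalence (not part of the diff; shapes are illustrative):

import torch
import torch.nn.functional as F

bs, c_in, length, k = 4, 3, 10, 2
x = torch.randn(bs, c_in, length)

# 2d unfold over a height-1 "image": bs x (c_in * k) x (length - k + 1)
cols = F.unfold(x.unsqueeze(2), kernel_size=(1, k))

# direct 1d sliding windows: bs x c_in x (length - k + 1) x k
windows = x.unfold(2, k, 1)

# identical contents, only the layout differs
assert torch.allclose(cols.view(bs, c_in, k, -1), windows.permute(0, 1, 3, 2))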
37 changes: 30 additions & 7 deletions nngeometry/generator/jacobian/grads_conv.py
Original file line number Diff line number Diff line change
@@ -83,11 +83,6 @@ def conv_backward(
return weight_bgrad


-def conv1d_backward(*args, **kwargs):
-    """Computes per-example gradients for nn.Conv1d layers."""
-    return conv_backward(*args, nd=1, **kwargs)


def conv2d_backward_using_conv(mod, x, gy):
"""Computes per-example gradients for nn.Conv2d layers."""
return conv_backward(
@@ -119,7 +114,29 @@ def conv2d_backward_using_unfold(mod, x, gy):


def conv2d_backward(*args, **kwargs):
-    return _conv_grad_impl.get_impl()(*args, **kwargs)
+    return _conv_grad_impl.get_impl2d()(*args, **kwargs)


def conv1d_backward_using_unfold(mod, x, gy):
"""Computes per-example gradients for nn.Conv1d layers."""
ks = (1, mod.weight.size(2))
gy_s = gy.size()
bs = gy_s[0]
x_unfold = F.unfold(
x.unsqueeze(2),
kernel_size=ks,
stride=(1, mod.stride[0]),
padding=(0, mod.padding[0]),
dilation=(1, mod.dilation[0]),
)
x_unfold_s = x_unfold.size()
return torch.bmm(
gy.view(bs, gy_s[1], -1), x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1)
)


def conv1d_backward(*args, **kwargs):
return _conv_grad_impl.get_impl1d()(*args, **kwargs)


class ConvGradImplManager:
@@ -129,12 +146,18 @@ def __init__(self):
def use_unfold(self, choice=True):
self._use_unfold = choice

-    def get_impl(self):
+    def get_impl2d(self):
if self._use_unfold:
return conv2d_backward_using_unfold
else:
return conv2d_backward_using_conv

def get_impl1d(self):
if self._use_unfold:
return conv1d_backward_using_unfold
else:
raise NotImplementedError()


_conv_grad_impl = ConvGradImplManager()

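A quick way to check the new unfold path is to compare its per-example gradients against a plain autograd loop. The sketch below is hedged: it assumes this branch is installed and that ConvGradImplManager defaults to the unfold implementation (get_impl1d raises NotImplementedError otherwise); the toy shapes are made up.

import torch
import torch.nn as nn
from nngeometry.generator.jacobian.grads_conv import conv1d_backward

torch.manual_seed(0)
mod = nn.Conv1d(3, 5, kernel_size=2)
x = torch.randn(4, 3, 10)   # bs x in_channels x length
gy = torch.randn(4, 5, 9)   # bs x out_channels x out_length

# reference: one autograd call per example
ref = torch.stack([
    torch.autograd.grad(mod(x[i:i + 1]), mod.weight, grad_outputs=gy[i:i + 1])[0]
    for i in range(x.size(0))
])

indiv_gw = conv1d_backward(mod, x, gy)  # bs x out_channels x (in_channels * kernel_size)
assert torch.allclose(indiv_gw.view(4, -1), ref.view(4, -1), atol=1e-5)

Note that, unlike the 2d case, no conv-based fallback is provided for 1d yet, so switching the manager away from unfold with use_unfold(False) makes conv1d_backward raise.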
15 changes: 12 additions & 3 deletions nngeometry/generator/jacobian/jacobian.py
Original file line number Diff line number Diff line change
@@ -41,7 +41,10 @@ def __init__(
self.centering = centering

if function is None:
-            function = lambda *x: model(x[0])
+
+            def function(*x):
+                return model(x[0])

self.function = function

if layer_collection is None:
@@ -250,6 +253,9 @@ def get_kfac_blocks(self, examples):
elif layer_class == "Conv2dLayer":
sG = layer.out_channels
sA = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1]
elif layer_class == "Conv1dLayer":
sG = layer.out_channels
sA = layer.in_channels * layer.kernel_size[0]
if layer.bias is not None:
sA += 1
self._blocks[layer_id] = (
@@ -459,6 +465,9 @@ def get_kfe_diag(self, kfe, examples):
elif layer_class == "Conv2dLayer":
sG = layer.out_channels
sA = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1]
elif layer_class == "Conv1dLayer":
sG = layer.out_channels
sA = layer.in_channels * layer.kernel_size[0]
if layer.bias is not None:
sA += 1
self._diags[layer_id] = torch.zeros((sG * sA), device=device, dtype=dtype)
@@ -761,7 +770,7 @@ def _hook_compute_kfac_blocks(self, mod, gy):
layer_id = self.m_to_l[mod]
layer = self.layer_collection[layer_id]
block = self._blocks[layer_id]
-        if mod_class in ["Linear", "Conv2d"]:
+        if mod_class in ["Linear", "Conv2d", "Conv1d"]:
FactoryMap[layer.__class__].kfac_gg(block[1], mod, layer, x, gy)
if self.i_output == 0:
# do this only once if n_output > 1
@@ -775,7 +784,7 @@ def _hook_compute_kfe_diag(self, mod, gy):
layer = self.layer_collection[layer_id]
x = self.xs[mod]
evecs_a, evecs_g = self._kfe[layer_id]
-        if mod_class in ["Linear", "Conv2d"]:
+        if mod_class in ["Linear", "Conv2d", "Conv1d"]:
FactoryMap[layer.__class__].kfe_diag(
self._diags[layer_id], mod, layer, x, gy, evecs_a, evecs_g
)
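In both hunks above, the new Conv1dLayer branch mirrors the Conv2d sizing convention: sG is the number of output channels and sA counts one entry per weight in a receptive field, plus one for the bias when present. A worked size check with illustrative numbers:

import torch.nn as nn

m = nn.Conv1d(in_channels=3, out_channels=5, kernel_size=2)
sG = m.out_channels                        # 5: output-side factor is 5 x 5
sA = m.in_channels * m.kernel_size[0] + 1  # 7: input-side factor is 7 x 7 (bias column included)

# the Kronecker block covers every parameter of the layer exactly once
assert sG * sA == m.weight.numel() + m.bias.numel()  # 35 == 30 + 5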
33 changes: 33 additions & 0 deletions nngeometry/layercollection.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ class LayerCollection:
"Cosine1d",
"Affine1d",
"ConvTranspose2d",
"Conv1d",
]

def __init__(self, layers=None):
@@ -112,6 +113,13 @@ def _module_to_layer(mod):
kernel_size=mod.kernel_size,
bias=(mod.bias is not None),
)
elif mod_class == "Conv1d":
return Conv1dLayer(
in_channels=mod.in_channels,
out_channels=mod.out_channels,
kernel_size=mod.kernel_size,
bias=(mod.bias is not None),
)
elif mod_class == "BatchNorm1d":
return BatchNorm1dLayer(num_features=mod.num_features)
elif mod_class == "BatchNorm2d":
@@ -231,6 +239,31 @@ def __eq__(self, other):
)


class Conv1dLayer(AbstractLayer):
def __init__(self, in_channels, out_channels, kernel_size, bias=True):
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.weight = Parameter(out_channels, in_channels, kernel_size[0])
if bias:
self.bias = Parameter(out_channels)
else:
self.bias = None

def numel(self):
if self.bias is not None:
return self.weight.numel() + self.bias.numel()
else:
return self.weight.numel()

def __eq__(self, other):
return (
self.in_channels == other.in_channels
and self.out_channels == other.out_channels
and self.kernel_size == other.kernel_size
)


class LinearLayer(AbstractLayer):
def __init__(self, in_features, out_features, bias=True):
self.in_features = in_features
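With the entries above, Conv1d modules are now picked up when a model is scanned. A small usage sketch (the two-layer toy model is invented for illustration; from_model and numel are the existing LayerCollection API):

import torch.nn as nn
from nngeometry.layercollection import LayerCollection

model = nn.Sequential(
    nn.Conv1d(3, 5, kernel_size=2),  # 5*3*2 weights + 5 biases = 35 parameters
    nn.ReLU(),
    nn.Conv1d(5, 2, kernel_size=3),  # 2*5*3 weights + 2 biases = 32 parameters
)
lc = LayerCollection.from_model(model)
assert lc.numel() == 67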
12 changes: 10 additions & 2 deletions nngeometry/object/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from .fspace import FMatDense
from .map import PullBackDense, PushForwardDense, PushForwardImplicit
-from .pspace import (PMatBlockDiag, PMatDense, PMatDiag, PMatEKFAC,
-                     PMatImplicit, PMatKFAC, PMatLowRank, PMatQuasiDiag)
+from .pspace import (
+    PMatBlockDiag,
+    PMatDense,
+    PMatDiag,
+    PMatEKFAC,
+    PMatImplicit,
+    PMatKFAC,
+    PMatLowRank,
+    PMatQuasiDiag,
+)
from .vector import FVector, PVector

__all__ = [
16 changes: 8 additions & 8 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
@@ -487,7 +487,7 @@ def get_dense_tensor(self, split_weight_bias=True):
a, g = self.data[layer_id]
start = self.generator.layer_collection.p_pos[layer_id]
sAG = a.size(0) * g.size(0)
-            if split_weight_bias:
+            if split_weight_bias and layer.bias:
reconstruct = torch.cat(
[
torch.cat(
@@ -517,7 +517,7 @@ def get_diag(self, split_weight_bias=True):
for layer_id, layer in self.generator.layer_collection.layers.items():
a, g = self.data[layer_id]
diag_of_block = torch.diag(g).view(-1, 1) * torch.diag(a).view(1, -1)
-            if split_weight_bias:
+            if split_weight_bias and layer.bias:
diags.append(diag_of_block[:, :-1].contiguous().view(-1))
diags.append(diag_of_block[:, -1:].view(-1))
else:
@@ -535,10 +535,10 @@ def mv(self, vs):
v = torch.cat([v, vs_dict[layer_id][1].unsqueeze(1)], dim=1)
a, g = self.data[layer_id]
mv = torch.mm(torch.mm(g, v), a)
-            if layer.bias is None:
-                mv_tuple = (mv.view(*sw),)
-            else:
+            if layer.bias:
                mv_tuple = (mv[:, :-1].contiguous().view(*sw), mv[:, -1].contiguous())
+            else:
+                mv_tuple = (mv.view(*sw),)
out_dict[layer_id] = mv_tuple
return PVector(layer_collection=vs.layer_collection, dict_repr=out_dict)

@@ -602,7 +602,7 @@ class PMatEKFAC(PMatAbstract):
"""
EKFAC representation from
*George, Laurent et al., Fast Approximate Natural Gradient Descent
-    in a Kronecker-factored Eigenbasis, NIPS 2018*
+    in a Kronecker-factored Eigenbasis, NeurIPS 2018*
"""

@@ -659,9 +659,9 @@ def get_KFE(self, split_weight_bias=True):
"""
evecs, _ = self.data
KFE = dict()
-        for layer_id, _ in self.generator.layer_collection.layers.items():
+        for layer_id, layer in self.generator.layer_collection.layers.items():
            evecs_a, evecs_g = evecs[layer_id]
-            if split_weight_bias:
+            if split_weight_bias and layer.bias:
kronecker(evecs_g, evecs_a[:-1, :])
kronecker(evecs_g, evecs_a[-1:, :].contiguous())
KFE[layer_id] = torch.cat(
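The guards changed in this file all rely on the same invariant: layer.bias is either a Parameter descriptor or None, so plain truthiness replaces the old is-None test here and the tuple-length test in vector.py below. As a reminder of what mv() computes once the bias column is appended, a toy sketch of the factored product (plain tensors, no library calls; sizes match the Conv1d example above):

import torch

sG, sA = 5, 7                  # output/input Kronecker factor sizes
g = torch.randn(sG, sG)        # gradient-side factor
a = torch.randn(sA, sA)        # activation-side factor, last row/col = bias terms
v_w = torch.randn(sG, sA - 1)  # weight part of the vector, one row per output channel
v_b = torch.randn(sG)          # bias part

v = torch.cat([v_w, v_b.unsqueeze(1)], dim=1)  # append bias column, as mv() does
mv = g @ v @ a                                  # Kronecker-factored product, full matrix never formed
w_part, b_part = mv[:, :-1], mv[:, -1]          # split back into weight block and bias column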
2 changes: 1 addition & 1 deletion nngeometry/object/vector.py
Original file line number Diff line number Diff line change
@@ -194,7 +194,7 @@ def _dict_to_flat(self):
parts = []
for layer_id, layer in self.layer_collection.layers.items():
parts.append(self.dict_repr[layer_id][0].view(-1))
-            if len(self.dict_repr[layer_id]) > 1:
+            if layer.bias:
parts.append(self.dict_repr[layer_id][1].view(-1))
return torch.cat(parts)

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -11,5 +11,5 @@
'nngeometry.generator',
'nngeometry.generator.jacobian',
'nngeometry.object'],
-      install_requires=['torch>=2.0.0'],
+      install_requires=['torch>=2.0.0','torchvision>=0.9.1'],
zip_safe=False)
