Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

embedding work in progress #81

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions nngeometry/generator/jacobian/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,21 @@ def get_jacobian(self, examples):
self.start = 0
for d in loader:
inputs = d[0]
inputs.requires_grad = True
differentiate_wrt = []
if inputs.dtype in [
torch.float16,
torch.float32,
torch.float64,
]:
inputs.requires_grad = True
differentiate_wrt.append(inputs)
bs = inputs.size(0)
output = self.function(*d).view(bs, self.n_output).sum(dim=0)
for self.i_output in range(self.n_output):
retain_graph = self.i_output < self.n_output - 1
torch.autograd.grad(
output[self.i_output],
[inputs],
differentiate_wrt,
retain_graph=retain_graph,
only_inputs=True,
)
Expand Down
23 changes: 22 additions & 1 deletion nngeometry/layercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class LayerCollection:
"Affine1d",
"ConvTranspose2d",
"Conv1d",
"LayerNorm"
"LayerNorm",
"Embedding",
]

def __init__(self, layers=None):
Expand Down Expand Up @@ -151,6 +152,10 @@ def _module_to_layer(mod):
return LayerNormLayer(
normalized_shape=mod.normalized_shape, bias=(mod.bias is not None)
)
elif mod_class == "Embedding":
return EmbeddingLayer(
embedding_dim=mod.embedding_dim, num_embeddings=mod.num_embeddings
)

def numel(self):
"""
Expand Down Expand Up @@ -292,6 +297,22 @@ def __eq__(self, other):
)


class EmbeddingLayer(AbstractLayer):
def __init__(self, num_embeddings, embedding_dim):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.weight = Parameter(num_embeddings, embedding_dim)

def numel(self):
return self.weight.numel()

def __eq__(self, other):
return (
self.num_embeddings == other.num_embeddings
and self.embedding_dim == other.embedding_dim
)


class BatchNorm1dLayer(AbstractLayer):
def __init__(self, num_features):
self.num_features = num_features
Expand Down
29 changes: 28 additions & 1 deletion tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn as nn
import torch.nn.functional as tF
from torch.nn.modules.conv import ConvTranspose2d
from torch.utils.data import DataLoader, Subset
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms

from nngeometry.layercollection import LayerCollection
Expand Down Expand Up @@ -184,6 +184,33 @@ def output_fn(input, target):
return (train_loader, layer_collection, net.parameters(), net, output_fn, 2)


class EmbeddingNet(nn.Module):
def __init__(self):
super(EmbeddingNet, self).__init__()
self.embedding_layer = nn.Embedding(10, 3)

def forward(self, x):
output = self.embedding_layer(x)
print(output.size())
return output.sum(axis=1)


def get_embedding_task():
train_set = TensorDataset(
torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]), torch.LongTensor([2, 0])
)
train_loader = DataLoader(dataset=train_set, batch_size=2, shuffle=False)
net = EmbeddingNet()
to_device_model(net)
net.eval()

def output_fn(input, target):
return net(input)

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)


class LinearConvNet(nn.Module):
def __init__(self):
super(LinearConvNet, self).__init__()
Expand Down
2 changes: 2 additions & 0 deletions tests/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
get_conv_gn_task,
get_conv_skip_task,
get_conv_task,
get_embedding_task,
get_fullyconnect_affine_task,
get_fullyconnect_cosine_task,
get_fullyconnect_onlylast_task,
Expand Down Expand Up @@ -35,6 +36,7 @@
from nngeometry.object.vector import PVector, random_fvector, random_pvector

linear_tasks = [
get_embedding_task,
get_linear_fc_task,
get_linear_conv_task,
get_batchnorm_fc_linear_task,
Expand Down
Loading