Skip to content

Commit

Permalink
Merge pull request #86 from tfjgeorge/pvec_pow
Browse files Browse the repository at this point in the history
pvector pow
  • Loading branch information
tfjgeorge authored Sep 30, 2024
2 parents 62db484 + 9fa382f commit fa05935
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
15 changes: 15 additions & 0 deletions nngeometry/object/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,21 @@ def __sub__(self, other):
),
)

def __pow__(self, exp):
if self.dict_repr is not None:
v_dict = dict()
for l_id, l in self.layer_collection.layers.items():
if l.bias is not None:
v_dict[l_id] = (
self.dict_repr[l_id][0] ** exp,
self.dict_repr[l_id][1] ** exp,
)
else:
v_dict[l_id] = (self.dict_repr[l_id][0] ** exp,)
return PVector(self.layer_collection, dict_repr=v_dict)
else:
return PVector(self.layer_collection, vector_repr=self.vector_repr**exp)

def dot(self, other):
"""
Computes the dot product between `self` and `other`
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch==2.0.1
torch>=2.0.1
torchvision>=0.9.1
requests>=2.24.0
22 changes: 22 additions & 0 deletions tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,28 @@ def test_sub():
)


def test_pow():
model = ConvNet()
layer_collection = LayerCollection.from_model(model)
r1 = random_pvector(layer_collection)
sqrt_r1 = r1**3
assert (
torch.norm(
sqrt_r1.get_flat_representation() - r1.get_flat_representation() ** 3
)
< 1e-5
)

r1 = random_pvector_dict(layer_collection)
sqrt_r1 = r1**3
assert (
torch.norm(
sqrt_r1.get_flat_representation() - r1.get_flat_representation() ** 3
)
< 1e-5
)


def test_clone():
eps = 1e-8
model = ConvNet()
Expand Down

0 comments on commit fa05935

Please sign in to comment.