Skip to content

Commit

Permalink
adds pow to KFAC and EKFAC
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Sep 30, 2024
1 parent fa05935 commit 4485ad5
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
29 changes: 28 additions & 1 deletion nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,27 @@ def inverse(self, regul=1e-8, use_pi=True):
inv_data[layer_id] = (inv_a, inv_g)
return PMatKFAC(generator=self.generator, data=inv_data)

def pow(self, pow, regul=1e-8, use_pi=True):
pow_data = dict()
for layer_id, layer in self.generator.layer_collection.layers.items():
a, g = self.data[layer_id]
if use_pi:
pi = (torch.trace(a) / torch.trace(g) * g.size(0) / a.size(0)) ** 0.5
else:
pi = 1
pow_a = torch.linalg.matrix_power(
a + pi * regul**0.5 * torch.eye(a.size(0), device=a.device), pow
)
pow_g = torch.linalg.matrix_power(
g + regul**0.5 / pi * torch.eye(g.size(0), device=g.device), pow
)
pow_data[layer_id] = (pow_a, pow_g)
return PMatKFAC(generator=self.generator, data=pow_data)

def __pow__(self, pow):
return self.pow(pow)


def solve(self, vs, regul=1e-8, use_pi=True):
vs_dict = vs.get_dict_representation()
out_dict = dict()
Expand Down Expand Up @@ -729,10 +750,16 @@ def get_diag(self, v):
raise NotImplementedError

def inverse(self, regul=1e-8):
return self.pow(-1, regul=regul)

def pow(self, pow, regul=1e-8):
evecs, diags = self.data
inv_diags = {i: 1.0 / (d + regul) for i, d in diags.items()}
inv_diags = {i: (d + regul) ** pow for i, d in diags.items()}
return PMatEKFAC(generator=self.generator, data=(evecs, inv_diags))

def __pow__(self, pow):
return self.pow(pow)

def solve(self, vs, regul=1e-8):
vs_dict = vs.get_dict_representation()
out_dict = dict()
Expand Down
7 changes: 7 additions & 0 deletions tests/test_jacobian_ekfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ def test_pspace_ekfac_vs_direct():
mv_ekfac = M_ekfac.mv(v)
check_tensors(mv_direct, mv_ekfac.get_flat_representation())

# Test pow
M_pow = M_ekfac**2
check_tensors(
M_pow.get_dense_tensor(),
torch.mm(M_ekfac.get_dense_tensor(), M_ekfac.get_dense_tensor()),
)

# Test inverse
regul = 1e-5
M_inv = M_ekfac.inverse(regul=regul)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_jacobian_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ def test_jacobian_kfac():
mnorm_direct = torch.dot(mv_direct, random_v.get_flat_representation())
check_ratio(mnorm_direct, mnorm_kfac)

# Test pow
M_pow = M_kfac**2
check_tensors(
M_pow.get_dense_tensor(),
torch.mm(M_kfac.get_dense_tensor(), M_kfac.get_dense_tensor()),
)

# Test inverse
# We start from a mv vector since it kills its components projected to
# the small eigenvalues of KFAC
Expand Down

0 comments on commit 4485ad5

Please sign in to comment.