diff --git a/nngeometry/object/pspace.py b/nngeometry/object/pspace.py
index b374684..aa5c347 100644
--- a/nngeometry/object/pspace.py
+++ b/nngeometry/object/pspace.py
@@ -444,6 +444,37 @@ def inverse(self, regul=1e-8, use_pi=True):
             inv_data[layer_id] = (inv_a, inv_g)
         return PMatKFAC(generator=self.generator, data=inv_data)
 
+    def pow(self, pow, regul=1e-8, use_pi=True):
+        """
+        Raise each regularized Kronecker factor to an integer power.
+
+        :param pow: integer exponent passed to torch.linalg.matrix_power
+        :param regul: damping; sqrt(regul), rescaled by pi, is added to the
+            diagonal of each factor before it is exponentiated
+        :param use_pi: if True, split the damping between the two factors
+            using the square root of the ratio of their mean diagonal entries
+        :return: a new PMatKFAC built from the exponentiated factors
+        """
+        pow_data = dict()
+        for layer_id, layer in self.generator.layer_collection.layers.items():
+            a, g = self.data[layer_id]
+            if use_pi:
+                pi = (torch.trace(a) / torch.trace(g) * g.size(0) / a.size(0)) ** 0.5
+            else:
+                pi = 1
+            pow_a = torch.linalg.matrix_power(
+                a + pi * regul**0.5 * torch.eye(a.size(0), device=a.device), pow
+            )
+            pow_g = torch.linalg.matrix_power(
+                g + regul**0.5 / pi * torch.eye(g.size(0), device=g.device), pow
+            )
+            pow_data[layer_id] = (pow_a, pow_g)
+        return PMatKFAC(generator=self.generator, data=pow_data)
+
+    def __pow__(self, pow):
+        """Support ``M ** pow`` using the default regul and use_pi of pow()."""
+        return self.pow(pow)
+
     def solve(self, vs, regul=1e-8, use_pi=True):
         vs_dict = vs.get_dict_representation()
         out_dict = dict()
@@ -729,10 +760,25 @@ def get_diag(self, v):
         raise NotImplementedError
 
     def inverse(self, regul=1e-8):
+        return self.pow(-1, regul=regul)
+
+    def pow(self, pow, regul=1e-8):
+        """
+        Raise the matrix to a power by exponentiating the regularized
+        diagonal stored alongside the eigenvectors.
+
+        :param pow: exponent applied elementwise to (diag + regul)
+        :param regul: damping added to the diagonal before exponentiation
+        :return: a new PMatEKFAC reusing the same eigenvectors
+        """
         evecs, diags = self.data
-        inv_diags = {i: 1.0 / (d + regul) for i, d in diags.items()}
-        return PMatEKFAC(generator=self.generator, data=(evecs, inv_diags))
+        pow_diags = {i: (d + regul) ** pow for i, d in diags.items()}
+        return PMatEKFAC(generator=self.generator, data=(evecs, pow_diags))
 
+    def __pow__(self, pow):
+        """Support ``M ** pow`` using the default regul of pow()."""
+        return self.pow(pow)
+
     def solve(self, vs, regul=1e-8):
         vs_dict = vs.get_dict_representation()
         out_dict = dict()
diff --git a/tests/test_jacobian_ekfac.py b/tests/test_jacobian_ekfac.py
index 79be84d..3ffc572 100644
--- a/tests/test_jacobian_ekfac.py
+++ b/tests/test_jacobian_ekfac.py
@@ -84,6 +84,13 @@ def test_pspace_ekfac_vs_direct():
             mv_ekfac = M_ekfac.mv(v)
             check_tensors(mv_direct, mv_ekfac.get_flat_representation())
 
+            # Test pow
+            M_pow = M_ekfac**2
+            check_tensors(
+                M_pow.get_dense_tensor(),
+                torch.mm(M_ekfac.get_dense_tensor(), M_ekfac.get_dense_tensor()),
+            )
+
             # Test inverse
             regul = 1e-5
             M_inv = M_ekfac.inverse(regul=regul)
diff --git a/tests/test_jacobian_kfac.py b/tests/test_jacobian_kfac.py
index f27178c..a42ed23 100644
--- a/tests/test_jacobian_kfac.py
+++ b/tests/test_jacobian_kfac.py
@@ -202,6 +202,13 @@ def test_jacobian_kfac():
         mnorm_direct = torch.dot(mv_direct, random_v.get_flat_representation())
         check_ratio(mnorm_direct, mnorm_kfac)
 
+        # Test pow
+        M_pow = M_kfac**2
+        check_tensors(
+            M_pow.get_dense_tensor(),
+            torch.mm(M_kfac.get_dense_tensor(), M_kfac.get_dense_tensor()),
+        )
+
         # Test inverse
         # We start from a mv vector since it kills its components projected to
         # the small eigenvalues of KFAC