Skip to content

Commit 1577e17

Browse files
committed
fix implicit methods
1 parent c87132d commit 1577e17

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

gbml/imaml.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,15 @@ def __init__(self, args):
1414
super().__init__(args)
1515
self._init_net()
1616
self._init_opt()
17-
self.lamb = 2.0
18-
self.n_cg = 3
17+
self.lamb = 100.0
18+
self.n_cg = 1
1919
return None
2020

2121
@torch.enable_grad()
2222
def inner_loop(self, fmodel, diffopt, train_input, train_target):
2323

2424
train_logit = fmodel(train_input)
2525
inner_loss = F.cross_entropy(train_logit, train_target)
26-
inner_loss += (self.lamb/2.) * ((torch.nn.utils.parameters_to_vector(self.network.parameters())-torch.nn.utils.parameters_to_vector(self.network.parameters()).detach())**2).sum()
2726
diffopt.step(inner_loss)
2827

2928
return None
@@ -41,8 +40,6 @@ def cg(self, in_grad, outer_grad, params):
4140
beta = (r_new @ r_new)/(r @ r)
4241
p = r_new + beta * p
4342
r = r_new.clone().detach()
44-
# print(alpha, beta ,r @ r);input()
45-
# print('end')
4643
return self.vec_to_grad(x)
4744

4845
def vec_to_grad(self, vec):

gbml/neumann.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ def __init__(self, args):
1414
super().__init__(args)
1515
self._init_net()
1616
self._init_opt()
17-
self.n_series = 3
17+
self.lamb = 100.0
18+
self.n_series = 1
1819
return None
1920

2021
@torch.enable_grad()
@@ -48,7 +49,7 @@ def vec_to_grad(self, vec):
4849
def hv_prod(self, in_grad, x, params):
4950
hv = torch.autograd.grad(in_grad, params, retain_graph=True, grad_outputs=x)
5051
hv = torch.nn.utils.parameters_to_vector(hv)
51-
hv = (-1.*self.args.inner_lr) * hv # scale for regularization
52+
hv = (-1./self.lamb) * hv # scaling for convergence
5253
return hv.detach()
5354

5455
def outer_loop(self, batch, is_train):

0 commit comments

Comments (0)