Commit 45fc19d

Merge pull request #41 from zaccharieramzi/pogm
Pogm
2 parents 8c85136 + d4ec351 commit 45fc19d

3 files changed (+217, -0 lines)

modopt/opt/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@
 
 References
 ----------
+.. [K2018] Kim, D., & Fessler, J. A. (2018).
+    Adaptive restart of the optimized gradient method for convex optimization.
+    Journal of Optimization Theory and Applications, 178(1), 240-263.
+    [https://link.springer.com/content/pdf/10.1007%2Fs10957-018-1287-4.pdf]
 .. [L2018] Liang, Jingwei, and Carola-Bibiane Schönlieb.
     Improving FISTA: Faster, Smarter and Greedier.
     arXiv preprint arXiv:1811.01430 (2018).

modopt/opt/algorithms.py

Lines changed: 199 additions & 0 deletions
@@ -1069,3 +1069,202 @@ def retrieve_outputs(self):
         for obs in self._observers['cv_metrics']:
             metrics[obs.name] = obs.retrieve_metrics()
         self.metrics = metrics
+
+
+class POGM(SetUp):
+    r"""Proximal Optimised Gradient Method
+
+    This class implements algorithm 3 from [K2018]_
+
+    Parameters
+    ----------
+    u : np.ndarray
+        Initial guess for the u variable
+    x : np.ndarray
+        Initial guess for the x variable (primal)
+    y : np.ndarray
+        Initial guess for the y variable
+    z : np.ndarray
+        Initial guess for the z variable
+    grad : class
+        Gradient operator class
+    prox : class
+        Proximity operator class
+    cost : class or str, optional
+        Cost function class (default is 'auto'); use 'auto' to automatically
+        generate a costObj instance
+    linear : class instance, optional
+        Linear operator class (default is None)
+    beta_param : float, optional
+        Initial value of the beta parameter (default is 1.0). This
+        corresponds to (1 / L) in [K2018]_
+    sigma_bar : float, optional
+        Value of the shrinking parameter sigma bar (default is 1.0)
+    auto_iterate : bool, optional
+        Option to automatically begin iterations upon initialisation
+        (default is ``True``)
+
+    """
+
+    def __init__(
+        self,
+        u,
+        x,
+        y,
+        z,
+        grad,
+        prox,
+        cost='auto',
+        linear=None,
+        beta_param=1.0,
+        sigma_bar=1.0,
+        auto_iterate=True,
+        metric_call_period=5,
+        metrics={},
+    ):
+        # Set default algorithm properties
+        super(POGM, self).__init__(
+            metric_call_period=metric_call_period,
+            metrics=metrics,
+            linear=linear,
+        )
+
+        # Set the initial variable values
+        for data in (u, x, y, z):
+            self._check_input_data(data)
+        self._u_old = np.copy(u)
+        self._x_old = np.copy(x)
+        self._y_old = np.copy(y)
+        self._z = np.copy(z)
+
+        # Set the algorithm operators
+        for operator in (grad, prox, cost):
+            self._check_operator(operator)
+        self._grad = grad
+        self._prox = prox
+        self._linear = linear
+        if cost == 'auto':
+            self._cost_func = costObj([self._grad, self._prox])
+        else:
+            self._cost_func = cost
+
+        # Set the algorithm parameters
+        for param in (beta_param, sigma_bar):
+            self._check_param(param)
+        if not (0 <= sigma_bar <= 1):
+            raise ValueError('The sigma bar parameter needs to be in [0, 1]')
+        self._beta = beta_param
+        self._sigma_bar = sigma_bar
+        self._xi = self._sigma = self._t_old = 1.0
+        self._grad.get_grad(self._x_old)
+        self._g_old = self._grad.grad
+
+        # Automatically run the algorithm
+        if auto_iterate:
+            self.iterate()
+
+    def _update(self):
+        r"""Update
+
+        This method updates the current reconstruction
+
+        Notes
+        -----
+        Implements algorithm 3 from [K2018]_
+
+        """
+        # Step 4 from alg. 3
+        self._grad.get_grad(self._x_old)
+        self._u_new = self._x_old - self._beta * self._grad.grad
+
+        # Step 5 from alg. 3
+        self._t_new = 0.5 * (1 + np.sqrt(1 + 4 * self._t_old**2))
+
+        # Step 6 from alg. 3
+        t_shifted_ratio = (self._t_old - 1) / self._t_new
+        sigma_t_ratio = self._sigma * self._t_old / self._t_new
+        beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi
+        self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z)
+        self._z += self._u_new
+        self._z += t_shifted_ratio * (self._u_new - self._u_old)
+        self._z += sigma_t_ratio * (self._u_new - self._x_old)
+
+        # Step 7 from alg. 3
+        self._xi = self._beta * (1 + t_shifted_ratio + sigma_t_ratio)
+
+        # Step 8 from alg. 3
+        self._x_new = self._prox.op(self._z, extra_factor=self._xi)
+
+        # Restarting and gamma-decreasing
+        # Step 9 from alg. 3
+        self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
+
+        # Step 10 from alg. 3
+        self._y_new = self._x_old - self._beta * self._g_new
+
+        # Step 11 from alg. 3
+        restart_crit = np.vdot(- self._g_new, self._y_new - self._y_old) < 0
+        if restart_crit:
+            self._t_new = 1
+            self._sigma = 1
+
+        # Step 13 from alg. 3
+        elif np.vdot(self._g_new, self._g_old) < 0:
+            self._sigma *= self._sigma_bar
+
+        # Update the iteration variables
+        self._t_old = self._t_new
+        np.copyto(self._u_old, self._u_new)
+        np.copyto(self._x_old, self._x_new)
+        np.copyto(self._g_old, self._g_new)
+        np.copyto(self._y_old, self._y_new)
+
+        # Test cost function for convergence
+        if self._cost_func:
+            self.converge = self.any_convergence_flag() or \
+                self._cost_func.get_cost(self._x_new)
+
+    def iterate(self, max_iter=150):
+        r"""Iterate
+
+        This method calls update until either the convergence criterion is
+        met or the maximum number of iterations is reached
+
+        Parameters
+        ----------
+        max_iter : int, optional
+            Maximum number of iterations (default is ``150``)
+
+        """
+        self._run_alg(max_iter)
+
+        # retrieve metrics results
+        self.retrieve_outputs()
+        # rename outputs as attributes
+        self.x_final = self._x_new
+
+    def get_notify_observers_kwargs(self):
+        """Return the mapping between the metrics call and the iterated
+        variables.
+
+        Returns
+        -------
+        notify_observers_kwargs : dict
+            The mapping between the iterated variables
+
+        """
+        return {
+            'u_new': self._u_new,
+            'x_new': self._x_new,
+            'y_new': self._y_new,
+            'z_new': self._z,
+            'xi': self._xi,
+            'sigma': self._sigma,
+            't': self._t_new,
+            'idx': self.idx,
+        }
+
+    def retrieve_outputs(self):
+        """Declare the outputs of the algorithm as attributes: x_final,
+        metrics.
+        """
+        metrics = {}
+        for obs in self._observers['cv_metrics']:
+            metrics[obs.name] = obs.retrieve_metrics()
+        self.metrics = metrics
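
For orientation, the `_update` method above is a transcription of the following recursion (algorithm 3 of [K2018]_, written out from the code, with beta = 1/L, sigma the adaptive shrinking factor, and the proximity operator of the regulariser g applied with the extra scaling xi):

    u_k  = x_{k-1} - beta * grad f(x_{k-1})                                  (step 4)
    t_k  = (1 + sqrt(1 + 4 t_{k-1}^2)) / 2                                   (step 5)
    z_k  = u_k + ((t_{k-1} - 1) / t_k) (u_k - u_{k-1})
               + (sigma t_{k-1} / t_k) (u_k - x_{k-1})
               - ((t_{k-1} - 1) / t_k) (beta / xi_{k-1}) (x_{k-1} - z_{k-1})  (step 6)
    xi_k = beta (1 + (t_{k-1} - 1) / t_k + sigma t_{k-1} / t_k)              (step 7)
    x_k  = prox_{xi_k g}(z_k)                                                (step 8)
    g_k  = grad f(x_{k-1}) - (x_k - z_k) / xi_k                              (step 9)
    y_k  = x_{k-1} - beta g_k                                                (step 10)

followed by the restart test: if <-g_k, y_k - y_{k-1}> < 0, both t and sigma are reset to 1; otherwise, if <g_k, g_{k-1}> < 0, sigma is shrunk by the factor sigma_bar.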

modopt/tests/test_opt.py

Lines changed: 14 additions & 0 deletions
@@ -97,6 +97,14 @@ def setUp(self):
             prox_dual=prox_dual_inst,
             linear=dummy(),
             cost=cost_inst, auto_iterate=False)
+        self.pogm1 = algorithms.POGM(
+            u=self.data1,
+            x=self.data1,
+            y=self.data1,
+            z=self.data1,
+            grad=grad_inst,
+            prox=prox_inst,
+        )
         self.dummy = dummy()
         self.dummy.cost = lambda x: x
         self.setup._check_operator(self.dummy.cost)
@@ -172,6 +180,12 @@ def test_condat(self):
         npt.assert_almost_equal(self.condat2.x_final, self.data1,
                                 err_msg='Incorrect Condat result.')
 
+    def test_pogm(self):
+        npt.assert_almost_equal(
+            self.pogm1.x_final,
+            self.data1,
+            err_msg='Incorrect POGM result.',
+        )
 
 
 class CostTestCase(TestCase):
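
The test above drives POGM end to end with the dummy gradient and proximity instances from setUp. For a standalone sketch of how the class added in this commit might be called, assuming hand-rolled operators that expose the interface POGM expects (get_grad/grad on the gradient side, op(..., extra_factor=...) on the proximity side); MyGrad and MyProx below are illustrative stand-ins, not modopt classes, and modopt may warn that they do not inherit from its operator parent classes:

import numpy as np

from modopt.opt.algorithms import POGM


class MyGrad(object):
    """Toy gradient operator for f(x) = 0.5 * ||x - obs||^2."""

    def __init__(self, obs):
        self.obs = obs
        self.grad = np.zeros_like(obs)

    def get_grad(self, x):
        # Gradient of the quadratic data-fidelity term.
        self.grad = x - self.obs

    def cost(self, *args, **kwargs):
        return 0.5 * np.linalg.norm(args[0] - self.obs) ** 2


class MyProx(object):
    """Toy proximity operator: soft thresholding (l1 penalty)."""

    def __init__(self, weight=0.1):
        self.weight = weight

    def op(self, x, extra_factor=1.0):
        # POGM rescales the threshold by the extra_factor (xi) it passes in.
        thresh = self.weight * extra_factor
        return np.sign(x) * np.maximum(np.abs(x) - thresh, 0.0)

    def cost(self, *args, **kwargs):
        return self.weight * np.sum(np.abs(args[0]))


obs = np.random.randn(64)
x0 = np.zeros(64)

opt = POGM(
    u=x0, x=x0, y=x0, z=x0,
    grad=MyGrad(obs),
    prox=MyProx(weight=0.1),
    cost=None,           # skip the automatic costObj; run for max_iter iterations
    beta_param=1.0,      # 1 / L; L = 1 for this quadratic data-fidelity term
    auto_iterate=False,
)
opt.iterate(max_iter=100)
print(opt.x_final)       # soft-thresholded estimate of obs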
