@@ -1069,3 +1069,202 @@ def retrieve_outputs(self):
1069
1069
for obs in self ._observers ['cv_metrics' ]:
1070
1070
metrics [obs .name ] = obs .retrieve_metrics ()
1071
1071
self .metrics = metrics
1072
+
1073
+ class POGM (SetUp ):
1074
+ r"""Proximal Optimised Gradient Method
1075
+
1076
+ This class implements algorithm 3 from [K2018]_
1077
+
1078
+ Parameters
1079
+ ----------
1080
+ u : np.ndarray
1081
+ Initial guess for the u variable
1082
+ x : np.ndarray
1083
+ Initial guess for the x variable (primal)
1084
+ y : np.ndarray
1085
+ Initial guess for the y variable
1086
+ z : np.ndarray
1087
+ Initial guess for the z variable
1088
+ grad : class
1089
+ Gradient operator class
1090
+ prox : class
1091
+ Proximity operator class
1092
+ cost : class or str, optional
1093
+ Cost function class (default is 'auto'); Use 'auto' to automatically
1094
+ generate a costObj instance
1095
+ linear : class instance, optional
1096
+ Linear operator class (default is None)
1097
+ beta_param : float, optional
1098
+ Initial value of the beta parameter (default is 1.0). This corresponds
1099
+ to (1 / L) in [K2018]_
1100
+ sigma_bar : float, optional
1101
+ Value of the shrinking parameter sigma bar (default is 1.0)
1102
+ auto_iterate : bool, optional
1103
+ Option to automatically begin iterations upon initialisation (default
1104
+ is 'True')
1105
+ """
1106
+ def __init__ (
1107
+ self ,
1108
+ u ,
1109
+ x ,
1110
+ y ,
1111
+ z ,
1112
+ grad ,
1113
+ prox ,
1114
+ cost = 'auto' ,
1115
+ linear = None ,
1116
+ beta_param = 1.0 ,
1117
+ sigma_bar = 1.0 ,
1118
+ auto_iterate = True ,
1119
+ metric_call_period = 5 ,
1120
+ metrics = {},
1121
+ ):
1122
+ # Set default algorithm properties
1123
+ super (POGM , self ).__init__ (
1124
+ metric_call_period = metric_call_period ,
1125
+ metrics = metrics ,
1126
+ linear = linear ,
1127
+ )
1128
+
1129
+ # set the initial variable values
1130
+ (self ._check_input_data (data ) for data in (u , x , y , z ))
1131
+ self ._u_old = np .copy (u )
1132
+ self ._x_old = np .copy (x )
1133
+ self ._y_old = np .copy (y )
1134
+ self ._z = np .copy (z )
1135
+
1136
+ # Set the algorithm operators
1137
+ (self ._check_operator (operator ) for operator in (grad , prox , cost ))
1138
+ self ._grad = grad
1139
+ self ._prox = prox
1140
+ self ._linear = linear
1141
+ if cost == 'auto' :
1142
+ self ._cost_func = costObj ([self ._grad , self ._prox ])
1143
+ else :
1144
+ self ._cost_func = cost
1145
+
1146
+ # Set the algorithm parameters
1147
+ (self ._check_param (param ) for param in (beta_param , sigma_bar ))
1148
+ if not (0 <= sigma_bar <= 1 ):
1149
+ raise ValueError ('The sigma bar parameter needs to be in [0, 1]' )
1150
+ self ._beta = beta_param
1151
+ self ._sigma_bar = sigma_bar
1152
+ self ._xi = self ._sigma = self ._t_old = 1.0
1153
+ self ._grad .get_grad (self ._x_old )
1154
+ self ._g_old = self ._grad .grad
1155
+
1156
+ # Automatically run the algorithm
1157
+ if auto_iterate :
1158
+ self .iterate ()
1159
+
1160
+ def _update (self ):
1161
+ r"""Update
1162
+
1163
+ This method updates the current reconstruction
1164
+
1165
+ Notes
1166
+ -----
1167
+ Implements algorithm 3 from [K2018]_
1168
+
1169
+ """
1170
+ # Step 4 from alg. 3
1171
+ self ._grad .get_grad (self ._x_old )
1172
+ self ._u_new = self ._x_old - self ._beta * self ._grad .grad
1173
+
1174
+ # Step 5 from alg. 3
1175
+ self ._t_new = 0.5 * (1 + np .sqrt (1 + 4 * self ._t_old ** 2 ))
1176
+
1177
+ # Step 6 from alg. 3
1178
+ t_shifted_ratio = (self ._t_old - 1 ) / self ._t_new
1179
+ sigma_t_ratio = self ._sigma * self ._t_old / self ._t_new
1180
+ beta_xi_t_shifted_ratio = t_shifted_ratio * self ._beta / self ._xi
1181
+ self ._z = - beta_xi_t_shifted_ratio * (self ._x_old - self ._z )
1182
+ self ._z += self ._u_new
1183
+ self ._z += t_shifted_ratio * (self ._u_new - self ._u_old )
1184
+ self ._z += sigma_t_ratio * (self ._u_new - self ._x_old )
1185
+
1186
+ # Step 7 from alg. 3
1187
+ self ._xi = self ._beta * (1 + t_shifted_ratio + sigma_t_ratio )
1188
+
1189
+ # Step 8 from alg. 3
1190
+ self ._x_new = self ._prox .op (self ._z , extra_factor = self ._xi )
1191
+
1192
+ # Restarting and gamma-Decreasing
1193
+ # Step 9 from alg. 3
1194
+ self ._g_new = self ._grad .grad - (self ._x_new - self ._z ) / self ._xi
1195
+
1196
+ # Step 10 from alg 3.
1197
+ self ._y_new = self ._x_old - self ._beta * self ._g_new
1198
+
1199
+ # Step 11 from alg. 3
1200
+ restart_crit = np .vdot (- self ._g_new , self ._y_new - self ._y_old ) < 0
1201
+ if restart_crit :
1202
+ self ._t_new = 1
1203
+ self ._sigma = 1
1204
+
1205
+ # Step 13 from alg. 3
1206
+ elif np .vdot (self ._g_new , self ._g_old ) < 0 :
1207
+ self ._sigma *= self ._sigma_bar
1208
+
1209
+ # updating variables
1210
+ self ._t_old = self ._t_new
1211
+ np .copyto (self ._u_old , self ._u_new )
1212
+ np .copyto (self ._x_old , self ._x_new )
1213
+ np .copyto (self ._g_old , self ._g_new )
1214
+ np .copyto (self ._y_old , self ._y_new )
1215
+
1216
+ # Test cost function for convergence.
1217
+ if self ._cost_func :
1218
+ self .converge = self .any_convergence_flag () or \
1219
+ self ._cost_func .get_cost (self ._x_new )
1220
+
1221
+
1222
+ def iterate (self , max_iter = 150 ):
1223
+ r"""Iterate
1224
+
1225
+ This method calls update until either convergence criteria is met or
1226
+ the maximum number of iterations is reached
1227
+
1228
+ Parameters
1229
+ ----------
1230
+ max_iter : int, optional
1231
+ Maximum number of iterations (default is ``150``)
1232
+
1233
+ """
1234
+
1235
+ self ._run_alg (max_iter )
1236
+
1237
+ # retrieve metrics results
1238
+ self .retrieve_outputs ()
1239
+ # rename outputs as attributes
1240
+ self .x_final = self ._x_new
1241
+
1242
+ def get_notify_observers_kwargs (self ):
1243
+ """ Return the mapping between the metrics call and the iterated
1244
+ variables.
1245
+
1246
+ Return
1247
+ ----------
1248
+ notify_observers_kwargs: dict,
1249
+ the mapping between the iterated variables.
1250
+ """
1251
+ return {
1252
+ 'u_new' : self ._u_new ,
1253
+ 'x_new' : self ._x_new ,
1254
+ 'y_new' : self ._y_new ,
1255
+ 'z_new' : self ._z ,
1256
+ 'xi' : self ._xi ,
1257
+ 'sigma' : self ._sigma ,
1258
+ 't' : self ._t_new ,
1259
+ 'idx' : self .idx ,
1260
+ }
1261
+
1262
+ def retrieve_outputs (self ):
1263
+ """ Declare the outputs of the algorithms as attributes: x_final,
1264
+ y_final, metrics.
1265
+ """
1266
+
1267
+ metrics = {}
1268
+ for obs in self ._observers ['cv_metrics' ]:
1269
+ metrics [obs .name ] = obs .retrieve_metrics ()
1270
+ self .metrics = metrics
0 commit comments