This module contains class implementations of various optimisation algorithms.

-:Author: Samuel Farrens <[email protected]>, Zaccharie Ramzi <[email protected]>
+:Author: Samuel Farrens <[email protected]>,
+    Zaccharie Ramzi <[email protected]>

NOTES
-----
@@ -260,58 +261,43 @@ class FISTA(object):
        None,  # no restarting
    ]

-    def __init__(
-        self,
-        restart_strategy=None,
-        min_beta=None,
-        s_greedy=None,
-        xi_restart=None,
-        a_cd=None,
-        p_lazy=1,
-        q_lazy=1,
-        r_lazy=4,
-    ):
+    def __init__(self, restart_strategy=None, min_beta=None, s_greedy=None,
+                 xi_restart=None, a_cd=None, p_lazy=1, q_lazy=1, r_lazy=4):
+
        if isinstance(a_cd, type(None)):
            self.mode = 'regular'
            self.p_lazy = p_lazy
            self.q_lazy = q_lazy
            self.r_lazy = r_lazy
+
        elif a_cd > 2:
            self.mode = 'CD'
            self.a_cd = a_cd
            self._n = 0
+
        else:
-            raise ValueError(
-                "a_cd must either be None (for regular mode) or a number > 2",
-            )
+            raise ValueError('a_cd must either be None (for regular mode) or '
+                             ' a number > 2')
+
        if restart_strategy in self.__class__.__restarting_strategies__:
-            self._check_restart_params(
-                restart_strategy,
-                min_beta,
-                s_greedy,
-                xi_restart,
-            )
+            self._check_restart_params(restart_strategy, min_beta, s_greedy,
+                                       xi_restart)
            self.restart_strategy = restart_strategy
            self.min_beta = min_beta
            self.s_greedy = s_greedy
            self.xi_restart = xi_restart
+
        else:
-            raise ValueError(
-                "Restarting strategy must be one of %s." %
-                ", ".join(self.__class__.__restarting_strategies__)
-            )
+            raise ValueError('Restarting strategy must be one of {}.'.format(
+                ', '.join(
+                    self.__class__.__restarting_strategies__)))
        self._t_now = 1.0
        self._t_prev = 1.0
        self._delta_0 = None
        self._safeguard = False

-    def _check_restart_params(
-        self,
-        restart_strategy,
-        min_beta,
-        s_greedy,
-        xi_restart,
-    ):
+    def _check_restart_params(self, restart_strategy, min_beta, s_greedy,
+                              xi_restart):
        r"""Check restarting parameters

        This method checks that the restarting parameters are set and satisfy
@@ -346,23 +332,24 @@ def _check_restart_params(
        When a parameter that should be set isn't or doesn't verify the
        correct assumptions.
        """
+
        if restart_strategy is None:
            return True
+
        if self.mode != 'regular':
-            raise ValueError(
-                "Restarting strategies can only be used with regular mode."
-            )
-        greedy_params_check = (
-            min_beta is None or s_greedy is None or s_greedy <= 1
-        )
+            raise ValueError('Restarting strategies can only be used with '
+                             ' regular mode.')
+
+        greedy_params_check = (min_beta is None or s_greedy is None or
+                               s_greedy <= 1)
+
        if restart_strategy == 'greedy' and greedy_params_check:
-            raise ValueError(
-                "You need a min_beta and an s_greedy > 1 for greedy restart."
-            )
+            raise ValueError('You need a min_beta and an s_greedy > 1 for '
+                             ' greedy restart.')
+
        if xi_restart is None or xi_restart >= 1:
-            raise ValueError(
-                "You need a xi_restart < 1 for restart."
-            )
+            raise ValueError('You need a xi_restart < 1 for restart.')
+
        return True

    def is_restart(self, z_old, x_new, x_old):
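
For context, the checks above pin down how a restarting FISTA is configured: 'greedy' restarts require min_beta to be set and s_greedy > 1, and every strategy requires xi_restart < 1. A minimal usage sketch consistent with these checks (the import path is an assumption, not shown in this diff):

# Hypothetical usage sketch; the parameter values are chosen only to
# satisfy the validation in _check_restart_params above.
from modopt.opt.algorithms import FISTA

fista_scheme = FISTA(
    restart_strategy='greedy',  # must appear in __restarting_strategies__
    min_beta=1e-10,             # required for greedy restarts
    s_greedy=1.1,               # must be > 1
    xi_restart=0.96,            # must be < 1
)
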
@@ -393,18 +380,22 @@ def is_restart(self, z_old, x_new, x_old):
        """
        if self.restart_strategy is None:
            return False
+
        criterion = np.vdot(z_old - x_new, x_new - x_old) >= 0
+
        if criterion:
            if 'adaptive' in self.restart_strategy:
                self.r_lazy *= self.xi_restart
            if self.restart_strategy in ['adaptive-ii', 'adaptive-2']:
                self._t_now = 1
+
        if self.restart_strategy == 'greedy':
            cur_delta = np.linalg.norm(x_new - x_old)
            if self._delta_0 is None:
                self._delta_0 = self.s_greedy * cur_delta
            else:
                self._safeguard = cur_delta >= self._delta_0
+
        return criterion

    def update_beta(self, beta):
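
The test in is_restart is the standard gradient-style adaptive restart criterion: restart when the momentum direction z_old - x_new no longer aligns with the last update step x_new - x_old. A self-contained sketch of just that test, assuming NumPy array inputs:

import numpy as np

# Standalone form of the restart test used in is_restart above: a
# non-negative inner product means the momentum opposes the last step.
def gradient_restart_test(z_old, x_new, x_old):
    return np.vdot(z_old - x_new, x_new - x_old) >= 0
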
@@ -422,9 +413,11 @@ def update_beta(self, beta):
        -------
        float: the new value for the beta parameter
        """
+
        if self._safeguard:
            beta *= self.xi_restart
            beta = max(beta, self.min_beta)
+
        return beta

    def update_lambda(self, *args, **kwargs):
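
In words, update_beta shrinks beta by xi_restart once the greedy safeguard has tripped, with min_beta as a floor. An equivalent standalone sketch (hypothetical helper, not part of the class):

# Equivalent standalone form of the safeguarded beta update above.
def safeguarded_beta(beta, safeguard, xi_restart, min_beta):
    if safeguard:
        beta = max(beta * xi_restart, min_beta)
    return beta
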
@@ -441,12 +434,17 @@ def update_lambda(self, *args, **kwargs):
        Implements steps 3 and 4 from algorithm 10.7 in [B2011]_

        """
+
        if self.restart_strategy == 'greedy':
            return 2
+
        # Steps 3 and 4 from alg. 10.7.
        self._t_prev = self._t_now
+
        if self.mode == 'regular':
-            self._t_now = (self.p_lazy + np.sqrt(self.r_lazy * self._t_prev ** 2 + self.q_lazy)) * 0.5
+            self._t_now = (self.p_lazy + np.sqrt(self.r_lazy *
+                           self._t_prev ** 2 + self.q_lazy)) * 0.5
+
        elif self.mode == 'CD':
            self._t_now = (self._n + self.a_cd - 1) / self.a_cd
            self._n += 1
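
With the default lazy parameters (p_lazy = q_lazy = 1, r_lazy = 4), the regular-mode update above is exactly the classical FISTA momentum recursion t_{k+1} = (1 + sqrt(1 + 4 * t_k**2)) / 2; this is also why the adaptive strategies shrink r_lazy in is_restart, since that damps the growth of the sequence. A quick numerical check:

import numpy as np

# The regular-mode recursion with default lazy parameters reproduces
# the classical FISTA t-sequence: 1.0 -> 1.618 -> 2.194 -> 2.750 -> ...
def next_t(t_prev, p_lazy=1, q_lazy=1, r_lazy=4):
    return (p_lazy + np.sqrt(r_lazy * t_prev ** 2 + q_lazy)) * 0.5

t = 1.0
for _ in range(3):
    t = next_t(t)
print(round(t, 3))  # 2.75
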
@@ -538,7 +536,7 @@ def __init__(self, x, grad, prox, cost='auto', beta_param=1.0,
        else:
            self._check_param_update(lambda_update)
            self._lambda_update = lambda_update
-            self._is_restart = lambda *args, **kwargs:False
+            self._is_restart = lambda *args, **kwargs: False

        # Automatically run the algorithm
        if auto_iterate:
@@ -688,8 +686,8 @@ def __init__(self, x, grad, prox_list, cost='auto', gamma_param=1.0,
        self._x_old = np.copy(x)

        # Set the algorithm operators
-        (self._check_operator(operator) for operator in [grad, cost]
-         + prox_list)
+        (self._check_operator(operator) for operator in [grad, cost] +
+         prox_list)
        self._grad = grad
        self._prox_list = np.array(prox_list)
        self._linear = linear
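
One thing worth flagging in this hunk (true on both sides of the diff): the operator check is a bare generator expression, and generator expressions are lazy, so _check_operator is never actually called unless the generator is consumed. A minimal standalone demonstration of the pitfall:

# A bare generator expression runs nothing:
calls = []
(calls.append(op) for op in ['grad', 'cost'])  # lazy, never evaluated
assert calls == []

# Consuming it (or using a plain for-loop) executes the calls:
for op in ['grad', 'cost']:
    calls.append(op)
assert calls == ['grad', 'cost']
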
@@ -910,7 +908,7 @@ class Condat(SetUp):
        """

    def __init__(self, x, y, grad, prox, prox_dual, linear=None, cost='auto',
-                 reweight=None, rho=0.5, sigma=1.0, tau=1.0, rho_update=None,
+                 reweight=None, rho=0.5, sigma=1.0, tau=1.0, rho_update=None,
                 sigma_update=None, tau_update=None, auto_iterate=True,
                 max_iter=150, n_rewightings=1, metric_call_period=5,
                 metrics={}):
@@ -1070,6 +1068,7 @@ def retrieve_outputs(self):
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics

+
class POGM(SetUp):
    r"""Proximal Optimised Gradient Method
@@ -1103,28 +1102,13 @@ class POGM(SetUp):
        Option to automatically begin iterations upon initialisation (default
        is 'True')
    """
-    def __init__(
-        self,
-        u,
-        x,
-        y,
-        z,
-        grad,
-        prox,
-        cost='auto',
-        linear=None,
-        beta_param=1.0,
-        sigma_bar=1.0,
-        auto_iterate=True,
-        metric_call_period=5,
-        metrics={},
-    ):
+    def __init__(self, u, x, y, z, grad, prox, cost='auto', linear=None,
+                 beta_param=1.0, sigma_bar=1.0, auto_iterate=True,
+                 metric_call_period=5, metrics={}):
+
        # Set default algorithm properties
-        super(POGM, self).__init__(
-            metric_call_period=metric_call_period,
-            metrics=metrics,
-            linear=linear,
-        )
+        super(POGM, self).__init__(metric_call_period=metric_call_period,
+                                   metrics=metrics, linear=linear)

        # set the initial variable values
        (self._check_input_data(data) for data in (u, x, y, z))
@@ -1145,7 +1129,7 @@ def __init__(
        # Set the algorithm parameters
        (self._check_param(param) for param in (beta_param, sigma_bar))
-        if not (0 <= sigma_bar <= 1):
+        if not (0 <= sigma_bar <= 1):
            raise ValueError('The sigma bar parameter needs to be in [0, 1]')
        self._beta = beta_param
        self._sigma_bar = sigma_bar
@@ -1169,7 +1153,7 @@ def _update(self):
        """
        # Step 4 from alg. 3
        self._grad.get_grad(self._x_old)
-        self._u_new = self._x_old - self._beta * self._grad.grad
+        self._u_new = self._x_old - self._beta * self._grad.grad

        # Step 5 from alg. 3
        self._t_new = 0.5 * (1 + np.sqrt(1 + 4 * self._t_old ** 2))
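
Steps 4 and 5 shown here are a plain gradient step with step size beta followed by the usual Nesterov/FISTA momentum update. As a standalone sketch (hypothetical helper, not part of the class):

import numpy as np

# Standalone form of POGM steps 4 and 5 as they appear above.
def pogm_steps_4_and_5(x_old, grad_val, beta, t_old):
    u_new = x_old - beta * grad_val                   # step 4: gradient step
    t_new = 0.5 * (1 + np.sqrt(1 + 4 * t_old ** 2))   # step 5: momentum term
    return u_new, t_new
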
@@ -1218,7 +1202,6 @@ def _update(self):
        self.converge = self.any_convergence_flag() or \
            self._cost_func.get_cost(self._x_new)

-
    def iterate(self, max_iter=150):
        r"""Iterate