1
1
import inspect
2
2
import itertools
3
+ from dataclasses import dataclass
3
4
from pathlib import Path
4
5
from typing import Any
5
6
@@ -59,16 +60,7 @@ def criterion_plot(
59
60
60
61
results = _harmonize_inputs_to_dict (results , names )
61
62
62
- if template is None :
63
- template = PLOT_DEFAULTS [backend ]["template" ]
64
- if palette is None :
65
- palette = PLOT_DEFAULTS [backend ]["palette" ]
66
-
67
- if isinstance (palette , mpl .colors .Colormap ):
68
- palette = [palette (i ) for i in range (palette .N )]
69
- if not isinstance (palette , list ):
70
- palette = [palette ]
71
- palette = itertools .cycle (palette )
63
+ template , palette = _get_template_and_palette (backend , template , palette )
72
64
73
65
fun_or_monotone_fun = "monotone_fun" if monotone else "fun"
74
66
@@ -98,7 +90,16 @@ def criterion_plot(
98
90
# Create figure
99
91
# ==================================================================================
100
92
101
- fig , plot_func , label_func = _get_plot_backend (backend )
93
+ plot_config = PlotConfig (
94
+ template = template ,
95
+ xlabel = "No. of criterion evaluations" ,
96
+ ylabel = "Criterion value" ,
97
+ plotly_legend = {"yanchor" : "top" , "xanchor" : "right" , "y" : 0.95 , "x" : 0.95 },
98
+ matplotlib_legend = {"loc" : "upper right" },
99
+ )
100
+
101
+ _backend_wrapper = _get_plot_backend (backend )
102
+ backend = _backend_wrapper (plot_config )
102
103
103
104
plot_multistart = (
104
105
len (data ) == 1 and data [0 ]["is_multistart" ] and not stack_multistart
@@ -119,8 +120,7 @@ def criterion_plot(
119
120
if max_evaluations is not None and len (history ) > max_evaluations :
120
121
history = history [:max_evaluations ]
121
122
122
- plot_func (
123
- fig ,
123
+ backend .plot (
124
124
x = np .arange (len (history )),
125
125
y = history ,
126
126
name = None ,
@@ -149,25 +149,17 @@ def criterion_plot(
149
149
150
150
_color = next (palette )
151
151
152
- plot_func (
153
- fig ,
152
+ backend .plot (
154
153
x = np .arange (len (history )),
155
154
y = history ,
156
155
name = "best result" if plot_multistart else _data ["name" ],
157
156
color = _color ,
158
157
plotly_scatter_kws = scatter_kws ,
159
158
)
160
159
161
- label_func (
162
- fig ,
163
- template = template ,
164
- xlabel = "No. of criterion evaluations" ,
165
- ylabel = "Criterion value" ,
166
- plotly_legend = {"yanchor" : "top" , "xanchor" : "right" , "y" : 0.95 , "x" : 0.95 },
167
- matplotlib_legend = {"loc" : "upper right" },
168
- )
160
+ backend .post_plot ()
169
161
170
- return fig
162
+ return backend . return_fig ()
171
163
172
164
173
165
def _harmonize_inputs_to_dict (results , names ):
@@ -463,19 +455,35 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
463
455
)
464
456
465
457
458
+ def _get_template (backend , template ):
459
+ if template is None :
460
+ template = PLOT_DEFAULTS [backend ]["template" ]
461
+
462
+ return template
463
+
464
+
465
+ def _get_palette (backend , palette ):
466
+ if palette is None :
467
+ palette = PLOT_DEFAULTS [backend ]["palette" ]
468
+
469
+ if isinstance (palette , mpl .colors .Colormap ):
470
+ palette = [palette (i ) for i in range (palette .N )]
471
+ if not isinstance (palette , list ):
472
+ palette = list (palette )
473
+ palette = itertools .cycle (palette )
474
+
475
+ return palette
476
+
477
+
478
+ def _get_template_and_palette (backend , template , palette ):
479
+ template = _get_template (backend , template )
480
+ palette = _get_palette (backend , palette )
481
+
482
+ return template , palette
483
+
484
+
466
485
def _get_plot_backend (backend ):
467
- backends = {
468
- "plotly" : (
469
- go .Figure (),
470
- _plot_plotly ,
471
- _label_plotly ,
472
- ),
473
- "matplotlib" : (
474
- plt .subplots ()[1 ],
475
- _plot_matplotlib ,
476
- _label_matplotlib ,
477
- ),
478
- }
486
+ backends = {"plotly" : PlotlyBackend , "matplotlib" : MatplotlibBackend }
479
487
480
488
if backend not in backends :
481
489
msg = (
@@ -487,28 +495,72 @@ def _get_plot_backend(backend):
487
495
return backends [backend ]
488
496
489
497
490
- def _plot_plotly (fig , * , x , y , name , color , plotly_scatter_kws , ** kwargs ):
491
- trace = go .Scatter (
492
- x = x , y = y , mode = "lines" , name = name , line_color = color , ** plotly_scatter_kws
493
- )
494
- fig .add_trace (trace )
495
- return fig
498
+ @dataclass (frozen = True )
499
+ class PlotConfig :
500
+ template : str
501
+ xlabel : str
502
+ ylabel : str
503
+ plotly_legend : dict [str , Any ]
504
+ matplotlib_legend : dict [str , Any ]
496
505
497
506
498
- def _label_plotly (fig , * , template , xlabel , ylabel , plotly_legend , ** kwargs ):
499
- fig .update_layout (
500
- template = template ,
501
- xaxis_title_text = xlabel ,
502
- yaxis_title_text = ylabel ,
503
- legend = plotly_legend ,
504
- )
507
+ class BackendWrapper :
508
+ def __init__ (self , plot_config : PlotConfig ):
509
+ self .plot_config = plot_config
510
+
511
+ def create_figure (self ):
512
+ raise NotImplementedError
513
+
514
+ def plot (self , ** kwargs ):
515
+ raise NotImplementedError
516
+
517
+ def post_plot (self ):
518
+ raise NotImplementedError
519
+
520
+
521
+ class PlotlyBackend (BackendWrapper ):
522
+ def __init__ (self , plot_config : PlotConfig ):
523
+ super ().__init__ (plot_config )
524
+ self .fig = self .create_figure ()
525
+
526
+ def create_figure (self ):
527
+ fig = go .Figure ()
528
+ return fig
529
+
530
+ def plot (self , * , x , y , name , color , plotly_scatter_kws , ** kwargs ):
531
+ trace = go .Scatter (
532
+ x = x , y = y , mode = "lines" , name = name , line_color = color , ** plotly_scatter_kws
533
+ )
534
+ self .fig .add_trace (trace )
535
+
536
+ def post_plot (self ):
537
+ self .fig .update_layout (
538
+ template = self .plot_config .template ,
539
+ xaxis_title_text = self .plot_config .xlabel ,
540
+ yaxis_title_text = self .plot_config .ylabel ,
541
+ legend = self .plot_config .plotly_legend ,
542
+ )
543
+
544
+ def return_fig (self ):
545
+ return self .fig
546
+
547
+
548
+ class MatplotlibBackend (BackendWrapper ):
549
+ def __init__ (self , plot_config : PlotConfig ):
550
+ super ().__init__ (plot_config )
551
+ self .fig , self .ax = self .create_figure ()
505
552
553
+ def create_figure (self ):
554
+ plt .style .use (self .plot_config .template )
555
+ fig , ax = plt .subplots ()
556
+ return fig , ax
506
557
507
- def _plot_matplotlib (ax , * , x , y , name , color , ** kwargs ):
508
- ax .plot (x , y , label = name , color = color )
509
- return ax
558
+ def plot (self , * , x , y , name , color , ** kwargs ):
559
+ self .ax .plot (x , y , label = name , color = color )
510
560
561
+ def post_plot (self ):
562
+ self .ax .set (xlabel = self .plot_config .xlabel , ylabel = self .plot_config .ylabel )
563
+ self .ax .legend (** self .plot_config .matplotlib_legend )
511
564
512
- def _label_matplotlib (ax , * , xlabel , ylabel , matplotlib_legend , ** kwargs ):
513
- ax .set (xlabel = xlabel , ylabel = ylabel )
514
- ax .legend (** matplotlib_legend )
565
+ def return_fig (self ):
566
+ return self .ax
0 commit comments