Skip to content

Commit df7f712

Browse files
committed
implement BackendWrapper for unified usage across backends.
1 parent acd7852 commit df7f712

File tree

1 file changed

+107
-55
lines changed

1 file changed

+107
-55
lines changed

src/optimagic/visualization/history_plots.py

Lines changed: 107 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import itertools
3+
from dataclasses import dataclass
34
from pathlib import Path
45
from typing import Any
56

@@ -59,16 +60,7 @@ def criterion_plot(
5960

6061
results = _harmonize_inputs_to_dict(results, names)
6162

62-
if template is None:
63-
template = PLOT_DEFAULTS[backend]["template"]
64-
if palette is None:
65-
palette = PLOT_DEFAULTS[backend]["palette"]
66-
67-
if isinstance(palette, mpl.colors.Colormap):
68-
palette = [palette(i) for i in range(palette.N)]
69-
if not isinstance(palette, list):
70-
palette = [palette]
71-
palette = itertools.cycle(palette)
63+
template, palette = _get_template_and_palette(backend, template, palette)
7264

7365
fun_or_monotone_fun = "monotone_fun" if monotone else "fun"
7466

@@ -98,7 +90,16 @@ def criterion_plot(
9890
# Create figure
9991
# ==================================================================================
10092

101-
fig, plot_func, label_func = _get_plot_backend(backend)
93+
plot_config = PlotConfig(
94+
template=template,
95+
xlabel="No. of criterion evaluations",
96+
ylabel="Criterion value",
97+
plotly_legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
98+
matplotlib_legend={"loc": "upper right"},
99+
)
100+
101+
_backend_wrapper = _get_plot_backend(backend)
102+
backend = _backend_wrapper(plot_config)
102103

103104
plot_multistart = (
104105
len(data) == 1 and data[0]["is_multistart"] and not stack_multistart
@@ -119,8 +120,7 @@ def criterion_plot(
119120
if max_evaluations is not None and len(history) > max_evaluations:
120121
history = history[:max_evaluations]
121122

122-
plot_func(
123-
fig,
123+
backend.plot(
124124
x=np.arange(len(history)),
125125
y=history,
126126
name=None,
@@ -149,25 +149,17 @@ def criterion_plot(
149149

150150
_color = next(palette)
151151

152-
plot_func(
153-
fig,
152+
backend.plot(
154153
x=np.arange(len(history)),
155154
y=history,
156155
name="best result" if plot_multistart else _data["name"],
157156
color=_color,
158157
plotly_scatter_kws=scatter_kws,
159158
)
160159

161-
label_func(
162-
fig,
163-
template=template,
164-
xlabel="No. of criterion evaluations",
165-
ylabel="Criterion value",
166-
plotly_legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
167-
matplotlib_legend={"loc": "upper right"},
168-
)
160+
backend.post_plot()
169161

170-
return fig
162+
return backend.return_fig()
171163

172164

173165
def _harmonize_inputs_to_dict(results, names):
@@ -463,19 +455,35 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
463455
)
464456

465457

458+
def _get_template(backend, template):
459+
if template is None:
460+
template = PLOT_DEFAULTS[backend]["template"]
461+
462+
return template
463+
464+
465+
def _get_palette(backend, palette):
466+
if palette is None:
467+
palette = PLOT_DEFAULTS[backend]["palette"]
468+
469+
if isinstance(palette, mpl.colors.Colormap):
470+
palette = [palette(i) for i in range(palette.N)]
471+
if not isinstance(palette, list):
472+
palette = list(palette)
473+
palette = itertools.cycle(palette)
474+
475+
return palette
476+
477+
478+
def _get_template_and_palette(backend, template, palette):
479+
template = _get_template(backend, template)
480+
palette = _get_palette(backend, palette)
481+
482+
return template, palette
483+
484+
466485
def _get_plot_backend(backend):
467-
backends = {
468-
"plotly": (
469-
go.Figure(),
470-
_plot_plotly,
471-
_label_plotly,
472-
),
473-
"matplotlib": (
474-
plt.subplots()[1],
475-
_plot_matplotlib,
476-
_label_matplotlib,
477-
),
478-
}
486+
backends = {"plotly": PlotlyBackend, "matplotlib": MatplotlibBackend}
479487

480488
if backend not in backends:
481489
msg = (
@@ -487,28 +495,72 @@ def _get_plot_backend(backend):
487495
return backends[backend]
488496

489497

490-
def _plot_plotly(fig, *, x, y, name, color, plotly_scatter_kws, **kwargs):
491-
trace = go.Scatter(
492-
x=x, y=y, mode="lines", name=name, line_color=color, **plotly_scatter_kws
493-
)
494-
fig.add_trace(trace)
495-
return fig
498+
@dataclass(frozen=True)
499+
class PlotConfig:
500+
template: str
501+
xlabel: str
502+
ylabel: str
503+
plotly_legend: dict[str, Any]
504+
matplotlib_legend: dict[str, Any]
496505

497506

498-
def _label_plotly(fig, *, template, xlabel, ylabel, plotly_legend, **kwargs):
499-
fig.update_layout(
500-
template=template,
501-
xaxis_title_text=xlabel,
502-
yaxis_title_text=ylabel,
503-
legend=plotly_legend,
504-
)
507+
class BackendWrapper:
508+
def __init__(self, plot_config: PlotConfig):
509+
self.plot_config = plot_config
510+
511+
def create_figure(self):
512+
raise NotImplementedError
513+
514+
def plot(self, **kwargs):
515+
raise NotImplementedError
516+
517+
def post_plot(self):
518+
raise NotImplementedError
519+
520+
521+
class PlotlyBackend(BackendWrapper):
522+
def __init__(self, plot_config: PlotConfig):
523+
super().__init__(plot_config)
524+
self.fig = self.create_figure()
525+
526+
def create_figure(self):
527+
fig = go.Figure()
528+
return fig
529+
530+
def plot(self, *, x, y, name, color, plotly_scatter_kws, **kwargs):
531+
trace = go.Scatter(
532+
x=x, y=y, mode="lines", name=name, line_color=color, **plotly_scatter_kws
533+
)
534+
self.fig.add_trace(trace)
535+
536+
def post_plot(self):
537+
self.fig.update_layout(
538+
template=self.plot_config.template,
539+
xaxis_title_text=self.plot_config.xlabel,
540+
yaxis_title_text=self.plot_config.ylabel,
541+
legend=self.plot_config.plotly_legend,
542+
)
543+
544+
def return_fig(self):
545+
return self.fig
546+
547+
548+
class MatplotlibBackend(BackendWrapper):
549+
def __init__(self, plot_config: PlotConfig):
550+
super().__init__(plot_config)
551+
self.fig, self.ax = self.create_figure()
505552

553+
def create_figure(self):
554+
plt.style.use(self.plot_config.template)
555+
fig, ax = plt.subplots()
556+
return fig, ax
506557

507-
def _plot_matplotlib(ax, *, x, y, name, color, **kwargs):
508-
ax.plot(x, y, label=name, color=color)
509-
return ax
558+
def plot(self, *, x, y, name, color, **kwargs):
559+
self.ax.plot(x, y, label=name, color=color)
510560

561+
def post_plot(self):
562+
self.ax.set(xlabel=self.plot_config.xlabel, ylabel=self.plot_config.ylabel)
563+
self.ax.legend(**self.plot_config.matplotlib_legend)
511564

512-
def _label_matplotlib(ax, *, xlabel, ylabel, matplotlib_legend, **kwargs):
513-
ax.set(xlabel=xlabel, ylabel=ylabel)
514-
ax.legend(**matplotlib_legend)
565+
def return_fig(self):
566+
return self.ax

0 commit comments

Comments
 (0)