Skip to content

Commit

Permalink
marker_args dict for plot_1d; allow list marker_args when adding mult…
Browse files Browse the repository at this point in the history
…iple markers
  • Loading branch information
cmbant committed Sep 30, 2024
1 parent 8589dc2 commit f1230a4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion getdist/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def last_modified(files):


def slice_or_none(x, start=None, end=None):
return getattr(x, "__getitem__", lambda _: None)(slice(start, end))
return x[start:end] if hasattr(x, "__getitem__") else None


def findChainFileRoot(chain_dir, root, search_subdirectories=True):
Expand Down
33 changes: 23 additions & 10 deletions getdist/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,8 +1224,7 @@ def add_x_marker(self, marker: Union[float, Sequence[float]], color=None, ls=Non
ls = self.settings.axis_marker_ls
if lw is None:
lw = self.settings.axis_marker_lw
marker = makeList(marker)
for m in marker:
for m in makeList(marker):
self.get_axes(ax).axvline(m, ls=ls, color=color, lw=lw, **kwargs)

def add_y_marker(self, marker: Union[float, Iterable[float]], color=None, ls=None, lw=None, ax=None, **kwargs):
Expand All @@ -1246,8 +1245,7 @@ def add_y_marker(self, marker: Union[float, Iterable[float]], color=None, ls=Non
ls = self.settings.axis_marker_ls
if lw is None:
lw = self.settings.axis_marker_lw
marker = makeList(marker)
for m in marker:
for m in makeList(marker):
self.get_axes(ax).axhline(m, ls=ls, color=color, lw=lw, **kwargs)

def add_param_markers(self, param_value_dict: Dict[str, Union[Iterable[float], float]], *,
Expand Down Expand Up @@ -1583,6 +1581,7 @@ def plot_1d(self, roots, param, marker=None, marker_color=None, label_right=Fals
* **lws**: list of line widths for the different lines plotted
* **alphas**: list of alphas for the different lines plotted
* **line_args**: a list of dictionaries with settings for each set of lines
* **marker_args**: a dictionary with settings for the marker(s)
* arguments for :func:`~GetDistPlotter.set_axes`
.. plot::
Expand Down Expand Up @@ -1623,7 +1622,7 @@ def plot_1d(self, roots, param, marker=None, marker_color=None, label_right=Fals
if plotparam is None:
raise GetDistPlotError('No roots have parameter: ' + str(param))
if marker is not None:
self.add_x_marker(marker, marker_color, ax=ax)
self._add_marker_list(marker, ax, kwargs.get('marker_args', {}), color=marker_color)
if 'lims' in kwargs and kwargs['lims'] is not None:
xmin, xmax = kwargs['lims']
else:
Expand Down Expand Up @@ -2341,6 +2340,18 @@ class SampleNames:
setattr(obj, par.name, samples[:, i])
return obj

def _add_marker_list(self, markers, ax, marker_args, y=False, color=None):
add_proc = self.add_y_marker if y else self.add_x_marker
if isinstance(marker_args, (list, tuple)):
for marker, marker_arg in zip(markers, marker_args):
if color:
marker_arg['color'] = color
add_proc(marker, ax=ax, **marker_arg)
else:
if color:
marker_args['color'] = color
add_proc(markers, ax=ax, **marker_args)

# noinspection PyUnboundLocalVariable
def triangle_plot(self, roots, params=None, legend_labels=None, plot_3d_with_param=None, filled=False, shaded=False,
contour_args=None, contour_colors=None, contour_ls=None, contour_lws=None, line_args=None,
Expand Down Expand Up @@ -2382,8 +2393,9 @@ def triangle_plot(self, roots, params=None, legend_labels=None, plot_3d_with_par
(splits labels between left and right, but avoids labelling 1D y axes top left)
:param diag1d_kwargs: list of dict for arguments when making 1D plots on grid diagonal
:param markers: optional dict giving marker values indexed by parameter, or a list of marker values for
each parameter plotted
:param marker_args: dictionary of optional arguments for adding markers (passed to axvline and/or axhline)
each parameter plotted. Can have list values for multiple markers for each parameter.
:param marker_args: dictionary of optional arguments for adding markers (passed to axvline and/or axhline),
or list of dictionaries when multiple marker values are given for each parameter.
:param param_limits: a dictionary holding a mapping from parameter names to axis limits for that parameter
:param kwargs: optional keyword arguments for :func:`~GetDistPlotter.plot_2d`
or :func:`~GetDistPlotter.plot_3d` (lower triangle only)
Expand Down Expand Up @@ -2492,7 +2504,8 @@ def def_line_args(cont_args, cont_colors):
xlim = self.plot_1d(roots1d, param, marker=marker, do_xlabel=i == plot_col - 1,
no_label_no_numbers=self.settings.no_triangle_axis_labels, title_limit=title_limit,
label_right=True, no_zero=True, no_ylabel=True, no_ytick=True, line_args=line_args,
lims=param_limits.get(param.name), ax=ax, _ret_range=True, **diag1d_kwargs)
lims=param_limits.get(param.name), ax=ax, _ret_range=True,
marker_args=marker_args, **diag1d_kwargs)
lims[i] = xlim
if i > 0:
ax._shared_y_axis = self.subplots[i, 0]
Expand Down Expand Up @@ -2537,9 +2550,9 @@ def def_line_args(cont_args, cont_colors):
no_label_no_numbers=self.settings.no_triangle_axis_labels, shaded=shaded,
add_legend_proxy=i == 0 and i2 == 1, contour_args=contour_args, ax=ax, **kwargs)
if marker is not None:
self.add_x_marker(marker, ax=ax, **marker_args)
self._add_marker_list(marker, ax, marker_args)
if marker2 is not None:
self.add_y_marker(marker2, ax=ax, **marker_args)
self._add_marker_list(marker2, ax, marker_args, y=True)
self._inner_ticks(ax)
if i != i2:
ax.set_ylim(lims[i2])
Expand Down

0 comments on commit f1230a4

Please sign in to comment.