Skip to content

Add autocorrelation plot #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,25 @@ A complementary introduction and guide to ``plot_...`` functions is available at
.. autosummary::
:toctree: generated/

plot_autocorr
plot_bf
plot_compare
plot_convergence_dist
plot_dist
plot_energy
plot_ecdf_pit
plot_ess
plot_ess_evolution
plot_forest
plot_loo_pit
plot_ppc_dist
plot_ppc_pava
plot_ppc_pit
plot_ppc_rootogram
plot_prior_posterior
plot_psense_dist
plot_psense_quantities
plot_rank
plot_ridge
plot_trace
plot_trace_dist
plot_trace_dist
23 changes: 23 additions & 0 deletions docs/source/gallery/inference_diagnostics/07_plot_autocorr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
# Autocorrelation Plot

faceted plot with autocorrelation for each variable

---

:::{seealso}
API Documentation: {func}`~arviz_plots.plot_autocorr`
:::
"""
from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
pc = azp.plot_autocorr(
data,
backend="none" # change to preferred backend
)
pc.show()
2 changes: 2 additions & 0 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Batteries-included ArviZ plots."""

from .autocorr_plot import plot_autocorr
from .bf_plot import plot_bf
from .compare_plot import plot_compare
from .convergence_dist_plot import plot_convergence_dist
Expand All @@ -23,6 +24,7 @@
from .trace_plot import plot_trace

__all__ = [
"plot_autocorr",
"plot_bf",
"plot_compare",
"plot_convergence_dist",
Expand Down
246 changes: 246 additions & 0 deletions src/arviz_plots/plots/autocorr_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
"""Autocorrelation plot code."""

from copy import copy
from importlib import import_module

import numpy as np
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords, set_figure_layout
from arviz_plots.visuals import fill_between_y, labelled_title, labelled_x, line, line_xy


def plot_autocorr(
dt,
var_names=None,
filter_vars=None,
group="posterior",
coords=None,
sample_dims=None,
max_lag=None,
plot_collection=None,
backend=None,
labeller=None,
aes_map=None,
plot_kwargs=None,
pc_kwargs=None,
):
"""Autocorrelation plots for the given dataset.

Line plot of the autocorrelation function (ACF)

The ACF plots can be used as a convergence diagnostic for posteriors from MCMC
samples.

Parameters
----------
dt : DataTree
Input data
var_names : str or list of str, optional
One or more variables to be plotted. Currently only one variable is supported.
Prefix the variables by ~ when you want to exclude them from the plot.
filter_vars : {None, “like”, “regex”}, optional, default=None
If None (default), interpret var_names as the real variables names.
If “like”, interpret var_names as substrings of the real variables names.
If “regex”, interpret var_names as regular expressions on the real variables names.
group : str, optional
Which group to use. Defaults to "posterior".
coords : dict, optional
Coordinates to plot.
sample_dims : str or sequence of hashable, optional
Dimensions to reduce unless mapped to an aesthetic.
Defaults to ``rcParams["data.sample_dims"]``
max_lag : int, optional
Maximum lag to compute the ACF. Defaults to 100.
plot_collection : PlotCollection, optional
backend : {"matplotlib", "bokeh", "plotly"}, optional
labeller : labeller, optional
aes_map : mapping of {str : sequence of str}, optional
Mapping of artists to aesthetics that should use their mapping in `plot_collection`
when plotted. Valid keys are the same as for `plot_kwargs`.

plot_kwargs : mapping of {str : mapping or False}, optional
Valid keys are:

* lines -> passed to :func:`~arviz_plots.visuals.ecdf_line`
* ref_line -> passed to :func:`~arviz_plots.visuals.line_xy`
* ci -> passed to :func:`~arviz_plots.visuals.fill_between_y`
* xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
* title -> passed to :func:`~arviz_plots.visuals.labelled_title`

pc_kwargs : mapping
Passed to :class:`arviz_plots.PlotCollection.grid`

Returns
-------
PlotCollection

Examples
--------
Autocorrelation plot for mu variable in the centered eight dataset.

.. plot::
:context: close-figs

>>> from arviz_plots import plot_autocorr, style
>>> style.use("arviz-variat")
>>> from arviz_base import load_arviz_data
>>> dt = load_arviz_data('centered_eight')
>>> plot_autocorr(dt, var_names=["mu"])


.. minigallery:: plot_autocorr

"""
if sample_dims is None:
sample_dims = rcParams["data.sample_dims"]
if isinstance(sample_dims, str):
sample_dims = [sample_dims]
sample_dims = list(sample_dims)
if plot_kwargs is None:
plot_kwargs = {}
else:
plot_kwargs = plot_kwargs.copy()
if pc_kwargs is None:
pc_kwargs = {}
else:
pc_kwargs = pc_kwargs.copy()

if backend is None:
if plot_collection is None:
backend = rcParams["plot.backend"]
else:
backend = plot_collection.backend

labeller = BaseLabeller()

# Default max lag to 100
if max_lag is None:
max_lag = 100

distribution = process_group_variables_coords(
dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
)

acf_dataset = distribution.azstats.autocorr(dims=sample_dims).sel(draw=slice(0, max_lag - 1))
c_i = 1.96 / acf_dataset.sizes["draw"] ** 0.5
x_ci = np.arange(0, max_lag).astype(float)

plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
default_linestyle = plot_bknd.get_default_aes("linestyle", 2, {})[1]

if plot_collection is None:
pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy()
pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
pc_kwargs.setdefault("col_wrap", 5)
pc_kwargs.setdefault(
"cols",
["__variable__"]
+ [dim for dim in distribution.dims if dim not in {"model"}.union(sample_dims)],
)
pc_kwargs.setdefault("rows", None)

if "chain" in distribution:
pc_kwargs["aes"].setdefault("color", ["chain"])
pc_kwargs["aes"].setdefault("overlay", ["chain"])

pc_kwargs = set_figure_layout(pc_kwargs, plot_bknd, distribution)
pc_kwargs["plot_grid_kws"].setdefault("sharex", True)
pc_kwargs["plot_grid_kws"].setdefault("sharey", True)

plot_collection = PlotCollection.wrap(
distribution,
backend=backend,
**pc_kwargs,
)

if aes_map is None:
aes_map = {}
else:
aes_map = aes_map.copy()
aes_map.setdefault("lines", plot_collection.aes_set)

## reference line
ref_ls_kwargs = copy(plot_kwargs.get("ref_line", {}))

if ref_ls_kwargs is not False:
_, _, ac_ls_ignore = filter_aes(plot_collection, aes_map, "ref_line", sample_dims)
ref_ls_kwargs.setdefault("color", "gray")
ref_ls_kwargs.setdefault("linestyle", default_linestyle)

plot_collection.map(
line_xy,
"ref_line",
data=acf_dataset,
x=x_ci,
y=0,
ignore_aes=ac_ls_ignore,
**ref_ls_kwargs,
)

## autocorrelation line
acf_ls_kwargs = copy(plot_kwargs.get("lines", {}))

if acf_ls_kwargs is not False:
_, _, ac_ls_ignore = filter_aes(plot_collection, aes_map, "lines", sample_dims)

plot_collection.map(
line,
"lines",
data=acf_dataset,
ignore_aes=ac_ls_ignore,
**acf_ls_kwargs,
)

# Plot confidence intervals
ci_kwargs = copy(plot_kwargs.get("ci", {}))
_, _, ci_ignore = filter_aes(plot_collection, aes_map, "ci", "draw")
if ci_kwargs is not False:
ci_kwargs.setdefault("color", "black")
ci_kwargs.setdefault("alpha", 0.1)

plot_collection.map(
fill_between_y,
"ci",
data=acf_dataset,
x=x_ci,
y=0,
y_bottom=-c_i,
y_top=c_i,
ignore_aes=ci_ignore,
**ci_kwargs,
)

# set xlabel
_, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
if xlabel_kwargs is not False:
if "color" not in xlabels_aes:
xlabel_kwargs.setdefault("color", "black")

xlabel_kwargs.setdefault("text", "Lag")
plot_collection.map(
labelled_x,
"xlabel",
ignore_aes=xlabels_ignore,
subset_info=True,
**xlabel_kwargs,
)

# title
title_kwargs = copy(plot_kwargs.get("title", {}))
_, _, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims)

if title_kwargs is not False:
plot_collection.map(
labelled_title,
"title",
ignore_aes=title_ignore,
subset_info=True,
labeller=labeller,
**title_kwargs,
)

return plot_collection
1 change: 1 addition & 0 deletions src/arviz_plots/visuals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def fill_between_y(da, target, backend, *, x=None, y_bottom=None, y=None, y_top=
if np.ndim(np.squeeze(y_bottom)) == 0:
y_bottom = np.full_like(x, y_bottom)
plot_backend = import_module(f"arviz_plots.backend.{backend}")

return plot_backend.fill_between_y(x, y_bottom, y_top, target, **kwargs)


Expand Down
7 changes: 7 additions & 0 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,10 @@ def test_plot_prior_posterior(self, datatree, backend):
pc = plot_prior_posterior(datatree, backend=backend)
assert "chart" in pc.viz.data_vars
assert "Groups" in pc.viz["mu"].coords

def test_autocorr(self, datatree, backend):
pc = plot_trace(datatree, backend=backend)
assert "chart" in pc.viz.data_vars
assert "plot" not in pc.viz.data_vars
assert "hierarchy" not in pc.viz["mu"].dims
assert "hierarchy" in pc.viz["theta"].dims