From 0bc64944c3e258d91b7251ca7966171879578071 Mon Sep 17 00:00:00 2001 From: Dilawar Singh Date: Sat, 24 Jul 2021 08:52:34 +0530 Subject: [PATCH] Fix: epimargin/plots.py:62:0: E0102: function already defined line 14 ************* Module epimargin.plots (function-redefined) epimargin/plots.py:194:0: E0102: function already defined line 14 (function-redefined) Fixes: Use `import matplotlib.pyplot as plt` rather than `from matplotlib.pyplot import *`. Related: https://github.com/COVID-IWG/epimargin/issues/123 --- epimargin/plots.py | 136 +++++++++++++++++++++++---------------------- poetry.lock | 30 ++++++++-- pyproject.toml | 3 +- 3 files changed, 98 insertions(+), 71 deletions(-) diff --git a/epimargin/plots.py b/epimargin/plots.py index 2c69d08..7890cab 100644 --- a/epimargin/plots.py +++ b/epimargin/plots.py @@ -11,12 +11,16 @@ import seaborn as sns import tikzplotlib from matplotlib.patheffects import Normal, Stroke -from matplotlib.pyplot import * + +# pylint will complain about grid overwriting matplotlib.pyplot.grid +# from matplotlib.pyplot import * + +import matplotlib.pyplot as plt from .models import SIR def normalize_dates(dates): - try: + try: return [_.to_pydatetime().date() for _ in dates] except AttributeError: return dates @@ -36,7 +40,7 @@ def normalize_dates(dates): BLK = "#292f36" BLK_CI = "#aeb7c2" -### stoplight +### stoplight RED = "#D63231" YLW = "#FD8B5A" GRN = "#38AE66" @@ -48,16 +52,18 @@ def normalize_dates(dates): ANOMALY_RED = "#D63231" PRED_PURPLE = "#554B68" -## policy simulations +## policy simulations SIM_PALETTE = ["#437034", "#7D4343", "#43587D", "#7D4370"] -# typography -def rebuild_font_cache(): - import matplotlib.font_manager - matplotlib.font_manager._rebuild() +# FIXME: Maptlotlib 3>.1.x as rebuild_if_missing set to True. +# Also this is not used anywhere. +## typography +#def rebuild_font_cache(): +# import matplotlib.font_manager +# matplotlib.font_manager._rebuild() def despine(**kwargs): - pass + pass def grid(flag): if flag: @@ -67,7 +73,7 @@ def grid(flag): # container class for different theme Aesthetics = namedtuple( - "Aesthetics", + "Aesthetics", ["title", "label", "note", "ticks", "style", "palette", "accent", "despine", "framealpha", "handlelength"] ) @@ -135,7 +141,7 @@ def set_theme(name): theme = substack_settings elif name == "minimal": theme = minimal_settings - else: # default + else: # default theme = default_settings sns.set(style = theme.style, palette = theme.palette, font = theme.ticks["family"]) mpl.rcParams.update({"font.size": 22}) @@ -166,20 +172,20 @@ def rgb_to_dec(value): def get_continuous_cmap(hex_list, float_list=None): ''' creates and returns a color map that can be used in heat map figures. If float_list is not provided, colour map graduates linearly between each color in hex_list. - If float_list is provided, each color in hex_list is mapped to the respective location in float_list. - + If float_list is provided, each color in hex_list is mapped to the respective location in float_list. + Parameters ---------- hex_list: list of hex code strings float_list: list of floats between 0 and 1, same length as hex_list. Must start with 0 and end with 1. - + Returns ---------- colour map''' rgb_list = [rgb_to_dec(hex_to_rgb(i)) for i in hex_list] if not float_list: float_list = list(np.linspace(0,1,len(rgb_list))) - + cdict = dict() for num, col in enumerate(['red', 'green', 'blue']): col_list = [[float_list[i], rgb_list[i][num], rgb_list[i][num]] for i in range(len(float_list))] @@ -193,7 +199,7 @@ def get_continuous_cmap(hex_list, float_list=None): default_cmap = get_continuous_cmap([GRN, YLW, RED, RED], [0, 0.8, 0.9, 1]) def get_cmap(vmin = 0, vmax = 3, cmap = default_cmap): return mpl.cm.ScalarMappable( - norm = mpl.colors.Normalize(vmin, vmax), + norm = mpl.colors.Normalize(vmin, vmax), cmap = cmap ) @@ -213,7 +219,7 @@ def __init__(self, fig: Optional[mpl.figure.Figure] = None): def axis_labels(self, x, y, enforce_spacing = True, **kwargs): kwargs["fontdict"] = kwargs.get("fontdict", theme.label) if enforce_spacing and not x.startswith("\n"): - x = "\n" + x + x = "\n" + x if enforce_spacing and not y.endswith("\n"): y = y + "\n" return self.xlabel(x, **kwargs).ylabel(y, **kwargs) @@ -222,13 +228,13 @@ def xlabel(self, xl: str, **kwargs): kwargs["fontdict"] = kwargs.get("fontdict", theme.label) plt.xlabel(xl, **kwargs) plt.gca().xaxis.label.set_color("dimgray") - return self + return self def ylabel(self, yl: str, **kwargs): kwargs["fontdict"] = kwargs.get("fontdict", theme.label) plt.ylabel(yl, **kwargs) plt.gca().yaxis.label.set_color("dimgray") - return self + return self # stack title/subtitle vertically def title(self, text: str, **kwargs): @@ -242,15 +248,15 @@ def title(self, text: str, **kwargs): kwargs["fontdict"] = kwargs.get("fontdict", theme.title) kwargs["fontweight"] = kwargs.get("fontweight", theme.title["weight"]) plt.suptitle(text, **kwargs) - return self - + return self + def annotate(self, text: str, **kwargs): kwargs["fontdict"] = kwargs.get("fontdict", theme.note) kwargs["loc"] = kwargs.get("loc", "left") plt.title(text, **kwargs) return self - # stack title/subtitle horizontally + # stack title/subtitle horizontally def l_title(self, text: str, **kwargs): kwargs["loc"] = "left" kwargs["ha"] = kwargs.get("ha", "left") @@ -259,8 +265,8 @@ def l_title(self, text: str, **kwargs): kwargs["fontdict"] = kwargs.get("fontdict", theme.title) kwargs["fontweight"] = kwargs.get("fontweight", theme.title["weight"]) plt.title(text, **kwargs) - return self - + return self + def r_title(self, text: str, **kwargs): kwargs["loc"] = "right" kwargs["ha"] = kwargs.get("ha", "right") @@ -268,38 +274,38 @@ def r_title(self, text: str, **kwargs): kwargs["fontdict"] = kwargs.get("fontdict", theme.note) kwargs["color"] = theme.accent plt.title(text, **kwargs) - return self - + return self + def size(self, w, h): self.figure.set_size_inches(w, h) return self - + def legend(self, *args, **kwargs): kwargs["framealpha"] = kwargs.get("framealpha", theme.framealpha) kwargs["handlelength"] = kwargs.get("handlelength", theme.handlelength) plt.legend(*args, **kwargs) return self - + def format_xaxis(self, fmt = DATE_FMT): plt.gca().xaxis.set_major_formatter(DATE_FMT) plt.gca().xaxis.set_minor_formatter(DATE_FMT) - return self + return self def save(self, filename: Path, **kwargs): if str(filename).endswith("tex"): tikzplotlib.save(filename, **kwargs) - return self + return self kwargs["transparent"] = kwargs.get("transparent", str(filename).endswith("svg")) plt.savefig(filename, **kwargs) - return self + return self def adjust(self, **kwargs): plt.subplots_adjust(**kwargs) - return self + return self def show(self, **kwargs): plt.show(**kwargs) - return self + return self def plot_SIRD(model: SIR, layout = (1, 4)) -> PlotDevice: """ plot all 4 available curves (S, I, R, D) for a given SIR model """ @@ -311,7 +317,7 @@ def plot_SIRD(model: SIR, layout = (1, 4)) -> PlotDevice: d = ax.semilogy(t, model.D, alpha=0.75, label="Deaths", ) r = ax.semilogy(t, model.R, alpha=0.75, label="Recovered", ) ax.label_outer() - + fig.legend([s, i, r, d], labels = ["S", "I", "R", "D"], loc = "center right", borderaxespad = 0.1) return PlotDevice(fig) @@ -320,7 +326,7 @@ def plot_curve(models: Sequence[SIR], labels: Sequence[str], curve: str = "I"): fig = plt.figure() for (model, label) in zip(models, labels): plt.semilogy(model.aggregate(curve), label = label, figure = fig) - plt.legend() + plt.legend() plt.tight_layout() return PlotDevice(fig) @@ -336,21 +342,21 @@ def gantt_chart(gantt_data, start_date: Optional[str] = None, show_cbar = True): else: xticklabels = sorted(gantt_df.day.unique()) xlabel = "Days Since Beginning of Adaptive Control" - ax = sns.heatmap(gantt_pv["beta"], linewidths = 2, alpha = 0.8, + ax = sns.heatmap(gantt_pv["beta"], linewidths = 2, alpha = 0.8, annot = gantt_pv["R"], annot_kws={"size": 8}, cmap = ["#38AE66", "#FFF3B4", "#FD8B5A", "#D63231"], cbar = show_cbar, yticklabels = gantt_df["district"].unique(), xticklabels = xticklabels, cbar_kws = { - "ticks":[0.5, 1, 2, 2.5], - "label": "Mobility", - "format": mpl.ticker.FuncFormatter(lambda x, pos: {0.5:"voluntary", 1:"cautionary", 2:"partial", 2.5:"restricted"}[x]), - "orientation": "horizontal", - "aspect": 50, + "ticks":[0.5, 1, 2, 2.5], + "label": "Mobility", + "format": mpl.ticker.FuncFormatter(lambda x, pos: {0.5:"voluntary", 1:"cautionary", 2:"partial", 2.5:"restricted"}[x]), + "orientation": "horizontal", + "aspect": 50, "drawedges": True, "fraction": 0.05, - "pad": 0.10, + "pad": 0.10, "shrink": 0.5 } ) @@ -368,12 +374,12 @@ def predictions(date_range, model, color, bounds = [2.5, 97.5], curve = "dT"): return [(range_marker, median_marker), model.name] def simulations( - simulation_results: Sequence[Tuple[SIR]], - labels: Sequence[str], - historical: Optional[pd.Series] = None, - historical_label: str = "Empirical Case Data", - curve: str = "dT", - smoothing: Optional[np.ndarray] = None, + simulation_results: Sequence[Tuple[SIR]], + labels: Sequence[str], + historical: Optional[pd.Series] = None, + historical_label: str = "Empirical Case Data", + curve: str = "dT", + smoothing: Optional[np.ndarray] = None, semilog: bool = True) -> PlotDevice: """ plot simulation results for new daily cases and optionally show historical trends """ @@ -408,13 +414,13 @@ def simulations( p = plt.plot([pd.Timestamp(t) for t in smoothing[:, 0]], smoothing[:, 1], 'k-', linewidth = 1) legends.append(p) legend_labels.append("smoothed_data") - + for (rng, label, color) in zip(ranges, labels, SIM_PALETTE): p, = plt.plot(t, rng["avg"], color = color, linewidth = 2) f = plt.fill_between(t, rng["min"], rng["max"], color = color, alpha = 0.2) legends.append((p, f)) legend_labels.append(label) - + plt.gca().xaxis.set_major_formatter(DATE_FMT) plt.gca().xaxis.set_minor_formatter(DATE_FMT) plt.legend(legends, legend_labels, prop = dict(size = 20), handlelength = theme.handlelength, framealpha = theme.framealpha, loc = "best") @@ -429,7 +435,7 @@ def Rt(dates, Rt_pred, Rt_CI_upper, Rt_CI_lower, CI, ymin = 0.5, ymax = 3, yaxis """ plot Rt and associated confidence intervals over time """ CI_marker = plt.fill_between(dates, Rt_CI_lower, Rt_CI_upper, color = BLK, alpha = 0.3) Rt_marker, = plt.plot(dates, Rt_pred, color = BLK, linewidth = 2, zorder = 5, solid_capstyle = "butt") - if yaxis_colors: + if yaxis_colors: plt.plot([dates[0], dates[0]], [2.5, ymax], color = RED, linewidth = 6, alpha = 0.9, solid_capstyle="butt", zorder = 10) plt.plot([dates[0], dates[0]], [1, 2.5], color = YLW, linewidth = 6, alpha = 0.9, solid_capstyle="butt", zorder = 10) plt.plot([dates[0], dates[0]], [ymin, 1], color = GRN, linewidth = 6, alpha = 0.9, solid_capstyle="butt", zorder = 10) @@ -447,13 +453,13 @@ def Rt(dates, Rt_pred, Rt_CI_upper, Rt_CI_lower, CI, ymin = 0.5, ymax = 3, yaxis plt.gca().xaxis.set_minor_formatter(DATE_FMT) set_tick_size(theme.ticks["size"]) pd.markers = {"Rt" : (CI_marker, Rt_marker)} - return pd + return pd -def daily_cases(dates, T_pred, T_CI_upper, T_CI_lower, new_cases_ts, anomaly_dates, anomalies, CI, prediction_ts = None): +def daily_cases(dates, T_pred, T_CI_upper, T_CI_lower, new_cases_ts, anomaly_dates, anomalies, CI, prediction_ts = None): """ plots expected, smoothed cases from simulated annealing training """ new_cases_dates = dates[-len(new_cases_ts):] exp_cases_dates = dates[-len(T_pred):] - valid_idx = [i for i in range(len(dates)) if dates[i] not in anomaly_dates] + valid_idx = [i for i in range(len(dates)) if dates[i] not in anomaly_dates] T_CI_lower_rect = [min(l, u) for (l, u) in zip(T_CI_lower, T_CI_upper)] T_CI_upper_rect = [max(l, u) for (l, u) in zip(T_CI_lower, T_CI_upper)] observed_marker, = plt.plot([d for d in new_cases_dates if d not in anomaly_dates], [new_cases_ts[i] for i in range(len(new_cases_ts)) if new_cases_dates[i] not in anomaly_dates], color = OBS_BLK, linewidth = 2, zorder = 8) @@ -473,12 +479,12 @@ def daily_cases(dates, T_pred, T_CI_upper, T_CI_lower, new_cases_ts, anomaly_dat plt.vlines(dates[-1], ymin = 0, ymax = top, color = "black", linestyles = "dotted") legends += [(predicted_CI_marker, predicted_marker)] labels += [label] - else: + else: end = dates[-1] plt.ylim(bottom = 0, top = top) legends += [anomalies_marker] labels += ["anomalies"] - xlim(left = dates[0], right = end) + plt.xlim(left = dates[0], right = end) plt.legend(legends, labels, prop = {'size': 14}, framealpha = theme.framealpha, handlelength = theme.handlelength, loc = "best") plt.gca().xaxis.set_major_formatter(DATE_FMT) plt.gca().xaxis.set_minor_formatter(DATE_FMT) @@ -494,25 +500,25 @@ def choropleth(gdf, label_fn = lambda _: "", col = "Rt", title = "$R_t$", label_ ax.set_xticks([]) ax.set_yticks([]) if title: - ax.set_title(title, loc="left", fontdict = theme.label) + ax.set_title(title, loc="left", fontdict = theme.label) gdf.plot(color=[mappable.to_rgba(_) for _ in gdf[col]], ax = ax, edgecolors="black", linewidth=0.5, missing_kwds = {"color": theme.accent, "edgecolor": "white"}) if label_fn is not None: for (_, row) in gdf.iterrows(): label = label_fn(row) value = round(row[col], 2) ax.annotate( - s = f"{label}{value}", - xy = list(row["pt"].coords)[0], - ha = "center", - fontfamily = theme.note["family"], - color = "black", **label_kwargs, + s = f"{label}{value}", + xy = list(row["pt"].coords)[0], + ha = "center", + fontfamily = theme.note["family"], + color = "black", **label_kwargs, fontweight = "semibold", size = 12)\ .set_path_effects([Stroke(linewidth = 2, foreground = "white"), Normal()]) cbar_ax = fig.add_axes([0.90, 0.25, 0.01, 0.5]) cb = fig.colorbar(mappable = mappable, orientation = "vertical", cax = cbar_ax) cbar_ax.set_title("$R_t$", fontdict = theme.note) - + return PlotDevice(fig) def double_choropleth(gdf, label_fn = lambda _: "", Rt_col = "Rt", Rt_proj_col = "Rt_proj", titles = ["Current $R_t$", "Projected $R_t$ (1 Week)"], arrangement = (1, 2), label_kwargs = {}, mappable = sm): @@ -523,7 +529,7 @@ def double_choropleth(gdf, label_fn = lambda _: "", Rt_col = "Rt", Rt_proj_col = ax.grid(False) ax.set_xticks([]) ax.set_yticks([]) - ax.set_title(title, loc="left", fontdict = theme.label) + ax.set_title(title, loc="left", fontdict = theme.label) gdf.plot(color=[mappable.to_rgba(_) for _ in gdf[col]], ax = ax, edgecolors="black", linewidth=0.5, missing_kwds = {"color": theme.accent, "edgecolor": "white"}) if label_fn is not None: for (_, row) in gdf.iterrows(): @@ -536,7 +542,7 @@ def double_choropleth(gdf, label_fn = lambda _: "", Rt_col = "Rt", Rt_proj_col = cbar_ax = fig.add_axes([0.95, 0.25, 0.01, 0.5]) cb = fig.colorbar(mappable = mappable, orientation = "vertical", cax = cbar_ax) cbar_ax.set_title("$R_t$", fontdict = theme.note) - + return PlotDevice(fig) def double_choropleth_v(*args, **kwargs): diff --git a/poetry.lock b/poetry.lock index 04eace9..2190a47 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,6 +1,6 @@ [[package]] name = "arviz" -version = "0.11.2" +version = "0.11.0" description = "Exploratory analysis of Bayesian models" category = "main" optional = false @@ -17,7 +17,7 @@ typing-extensions = ">=3.7.4.3,<4" xarray = ">=0.16.1" [package.extras] -all = ["numba", "bokeh (>=1.4.0)", "ujson", "dask", "zarr (>=2.5.0)"] +all = ["numba", "bokeh (>=1.4.0)", "ujson", "dask"] [[package]] name = "astroid" @@ -719,6 +719,22 @@ category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +[[package]] +name = "tqdm" +version = "4.61.2" +description = "Fast, Extensible Progress Meter" +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["py-make (>=0.1.0)", "twine", "wheel"] +notebook = ["ipywidgets (>=6)"] +telegram = ["requests"] + [[package]] name = "typed-ast" version = "1.4.3" @@ -791,12 +807,12 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" python-versions = "^3.6.1" -content-hash = "e6fb73ae657393212f2cfccac1faac3f7927e5100c434c9180a1e27cd481acbf" +content-hash = "9b183a789c74b6daad05051ce3142d987552a2408e0e83d9b90e904be6604dc5" [metadata.files] arviz = [ - {file = "arviz-0.11.2-py3-none-any.whl", hash = "sha256:f6a1389a90b53335f248d282c8142b8209150b9625459a85ec6d3d38786797c1"}, - {file = "arviz-0.11.2.tar.gz", hash = "sha256:a9d0eb32e84a0472aa78a488ba9b12b05e7be8c2c8fb34a1ba6286cc1254ee0d"}, + {file = "arviz-0.11.0-py3-none-any.whl", hash = "sha256:62f302e784d0bf4b498c7fa2b92e808696aa361771107342d38f4f8610e27034"}, + {file = "arviz-0.11.0.tar.gz", hash = "sha256:a8b49affe6735093c8f0cb7a96c5267214590ff5d4e3e5fb89dfd2cb3b0f2d9a"}, ] astroid = [ {file = "astroid-2.6.5-py3-none-any.whl", hash = "sha256:7b963d1c590d490f60d2973e57437115978d3a2529843f160b5003b721e1e925"}, @@ -1410,6 +1426,10 @@ toml = [ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] +tqdm = [ + {file = "tqdm-4.61.2-py2.py3-none-any.whl", hash = "sha256:5aa445ea0ad8b16d82b15ab342de6b195a722d75fc1ef9934a46bba6feafbc64"}, + {file = "tqdm-4.61.2.tar.gz", hash = "sha256:8bb94db0d4468fea27d004a0f1d1c02da3cdedc00fe491c0de986b76a04d6b0a"}, +] typed-ast = [ {file = "typed_ast-1.4.3-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:2068531575a125b87a41802130fa7e29f26c09a2833fea68d9a40cf33902eba6"}, {file = "typed_ast-1.4.3-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:c907f561b1e83e93fad565bac5ba9c22d96a54e7ea0267c708bffe863cbe4075"}, diff --git a/pyproject.toml b/pyproject.toml index 5124236..d885787 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,6 @@ python = "^3.6.1" numpy = "^1.19.5" pandas = "^1.1.5" matplotlib = "^3.1.3" -arviz = "^0.11.2" pymc3 = "^3.9.3" statsmodels = "0.12.2" geopandas = "^0.9.0" @@ -20,11 +19,13 @@ seaborn = "^0.11.1" tikzplotlib = "^0.9.9" semver = "2.12.0" scikit-learn = "^0.24.2" +arviz = {version = "0.11.0", optional = true, extras = ["studies"]} [tool.poetry.dev-dependencies] mypy = "^0.910" pytest = "^6.2.4" pylint = "^2.9.5" +tqdm = "^4.61.2" [build-system] requires = ["poetry-core>=1.0.0"]