Skip to content
This repository has been archived by the owner on Sep 1, 2021. It is now read-only.

Commit

Permalink
Fix: epimargin/plots.py:62:0: E0102: function already defined line 14
Browse files Browse the repository at this point in the history
************* Module epimargin.plots
(function-redefined)
epimargin/plots.py:194:0: E0102: function already defined line 14
(function-redefined)

Fixes: Use `import matplotlib.pyplot as plt` rather than `from matplotlib.pyplot import *`.
Related: COVID-IWG#123
  • Loading branch information
Dilawar Singh committed Jul 24, 2021
1 parent ce000ee commit 0bc6494
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 71 deletions.
136 changes: 71 additions & 65 deletions epimargin/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
import seaborn as sns
import tikzplotlib
from matplotlib.patheffects import Normal, Stroke
from matplotlib.pyplot import *

# pylint will complain about grid overwriting matplotlib.pyplot.grid
# from matplotlib.pyplot import *

import matplotlib.pyplot as plt

from .models import SIR

def normalize_dates(dates):
try:
try:
return [_.to_pydatetime().date() for _ in dates]
except AttributeError:
return dates
Expand All @@ -36,7 +40,7 @@ def normalize_dates(dates):
BLK = "#292f36"
BLK_CI = "#aeb7c2"

### stoplight
### stoplight
RED = "#D63231"
YLW = "#FD8B5A"
GRN = "#38AE66"
Expand All @@ -48,16 +52,18 @@ def normalize_dates(dates):
ANOMALY_RED = "#D63231"
PRED_PURPLE = "#554B68"

## policy simulations
## policy simulations
SIM_PALETTE = ["#437034", "#7D4343", "#43587D", "#7D4370"]

# typography
def rebuild_font_cache():
import matplotlib.font_manager
matplotlib.font_manager._rebuild()
# FIXME: Maptlotlib 3>.1.x as rebuild_if_missing set to True.
# Also this is not used anywhere.
## typography
#def rebuild_font_cache():
# import matplotlib.font_manager
# matplotlib.font_manager._rebuild()

def despine(**kwargs):
pass
pass

def grid(flag):
if flag:
Expand All @@ -67,7 +73,7 @@ def grid(flag):

# container class for different theme
Aesthetics = namedtuple(
"Aesthetics",
"Aesthetics",
["title", "label", "note", "ticks", "style", "palette", "accent", "despine", "framealpha", "handlelength"]
)

Expand Down Expand Up @@ -135,7 +141,7 @@ def set_theme(name):
theme = substack_settings
elif name == "minimal":
theme = minimal_settings
else: # default
else: # default
theme = default_settings
sns.set(style = theme.style, palette = theme.palette, font = theme.ticks["family"])
mpl.rcParams.update({"font.size": 22})
Expand Down Expand Up @@ -166,20 +172,20 @@ def rgb_to_dec(value):
def get_continuous_cmap(hex_list, float_list=None):
''' creates and returns a color map that can be used in heat map figures.
If float_list is not provided, colour map graduates linearly between each color in hex_list.
If float_list is provided, each color in hex_list is mapped to the respective location in float_list.
If float_list is provided, each color in hex_list is mapped to the respective location in float_list.
Parameters
----------
hex_list: list of hex code strings
float_list: list of floats between 0 and 1, same length as hex_list. Must start with 0 and end with 1.
Returns
----------
colour map'''
rgb_list = [rgb_to_dec(hex_to_rgb(i)) for i in hex_list]
if not float_list:
float_list = list(np.linspace(0,1,len(rgb_list)))

cdict = dict()
for num, col in enumerate(['red', 'green', 'blue']):
col_list = [[float_list[i], rgb_list[i][num], rgb_list[i][num]] for i in range(len(float_list))]
Expand All @@ -193,7 +199,7 @@ def get_continuous_cmap(hex_list, float_list=None):
default_cmap = get_continuous_cmap([GRN, YLW, RED, RED], [0, 0.8, 0.9, 1])
def get_cmap(vmin = 0, vmax = 3, cmap = default_cmap):
return mpl.cm.ScalarMappable(
norm = mpl.colors.Normalize(vmin, vmax),
norm = mpl.colors.Normalize(vmin, vmax),
cmap = cmap
)

Expand All @@ -213,7 +219,7 @@ def __init__(self, fig: Optional[mpl.figure.Figure] = None):
def axis_labels(self, x, y, enforce_spacing = True, **kwargs):
kwargs["fontdict"] = kwargs.get("fontdict", theme.label)
if enforce_spacing and not x.startswith("\n"):
x = "\n" + x
x = "\n" + x
if enforce_spacing and not y.endswith("\n"):
y = y + "\n"
return self.xlabel(x, **kwargs).ylabel(y, **kwargs)
Expand All @@ -222,13 +228,13 @@ def xlabel(self, xl: str, **kwargs):
kwargs["fontdict"] = kwargs.get("fontdict", theme.label)
plt.xlabel(xl, **kwargs)
plt.gca().xaxis.label.set_color("dimgray")
return self
return self

def ylabel(self, yl: str, **kwargs):
kwargs["fontdict"] = kwargs.get("fontdict", theme.label)
plt.ylabel(yl, **kwargs)
plt.gca().yaxis.label.set_color("dimgray")
return self
return self

# stack title/subtitle vertically
def title(self, text: str, **kwargs):
Expand All @@ -242,15 +248,15 @@ def title(self, text: str, **kwargs):
kwargs["fontdict"] = kwargs.get("fontdict", theme.title)
kwargs["fontweight"] = kwargs.get("fontweight", theme.title["weight"])
plt.suptitle(text, **kwargs)
return self
return self

def annotate(self, text: str, **kwargs):
kwargs["fontdict"] = kwargs.get("fontdict", theme.note)
kwargs["loc"] = kwargs.get("loc", "left")
plt.title(text, **kwargs)
return self

# stack title/subtitle horizontally
# stack title/subtitle horizontally
def l_title(self, text: str, **kwargs):
kwargs["loc"] = "left"
kwargs["ha"] = kwargs.get("ha", "left")
Expand All @@ -259,47 +265,47 @@ def l_title(self, text: str, **kwargs):
kwargs["fontdict"] = kwargs.get("fontdict", theme.title)
kwargs["fontweight"] = kwargs.get("fontweight", theme.title["weight"])
plt.title(text, **kwargs)
return self
return self

def r_title(self, text: str, **kwargs):
kwargs["loc"] = "right"
kwargs["ha"] = kwargs.get("ha", "right")
kwargs["va"] = kwargs.get("va", "bottom")
kwargs["fontdict"] = kwargs.get("fontdict", theme.note)
kwargs["color"] = theme.accent
plt.title(text, **kwargs)
return self
return self

def size(self, w, h):
self.figure.set_size_inches(w, h)
return self

def legend(self, *args, **kwargs):
kwargs["framealpha"] = kwargs.get("framealpha", theme.framealpha)
kwargs["handlelength"] = kwargs.get("handlelength", theme.handlelength)
plt.legend(*args, **kwargs)
return self

def format_xaxis(self, fmt = DATE_FMT):
plt.gca().xaxis.set_major_formatter(DATE_FMT)
plt.gca().xaxis.set_minor_formatter(DATE_FMT)
return self
return self

def save(self, filename: Path, **kwargs):
if str(filename).endswith("tex"):
tikzplotlib.save(filename, **kwargs)
return self
return self
kwargs["transparent"] = kwargs.get("transparent", str(filename).endswith("svg"))
plt.savefig(filename, **kwargs)
return self
return self

def adjust(self, **kwargs):
plt.subplots_adjust(**kwargs)
return self
return self

def show(self, **kwargs):
plt.show(**kwargs)
return self
return self

def plot_SIRD(model: SIR, layout = (1, 4)) -> PlotDevice:
""" plot all 4 available curves (S, I, R, D) for a given SIR model """
Expand All @@ -311,7 +317,7 @@ def plot_SIRD(model: SIR, layout = (1, 4)) -> PlotDevice:
d = ax.semilogy(t, model.D, alpha=0.75, label="Deaths", )
r = ax.semilogy(t, model.R, alpha=0.75, label="Recovered", )
ax.label_outer()

fig.legend([s, i, r, d], labels = ["S", "I", "R", "D"], loc = "center right", borderaxespad = 0.1)
return PlotDevice(fig)

Expand All @@ -320,7 +326,7 @@ def plot_curve(models: Sequence[SIR], labels: Sequence[str], curve: str = "I"):
fig = plt.figure()
for (model, label) in zip(models, labels):
plt.semilogy(model.aggregate(curve), label = label, figure = fig)
plt.legend()
plt.legend()
plt.tight_layout()
return PlotDevice(fig)

Expand All @@ -336,21 +342,21 @@ def gantt_chart(gantt_data, start_date: Optional[str] = None, show_cbar = True):
else:
xticklabels = sorted(gantt_df.day.unique())
xlabel = "Days Since Beginning of Adaptive Control"
ax = sns.heatmap(gantt_pv["beta"], linewidths = 2, alpha = 0.8,
ax = sns.heatmap(gantt_pv["beta"], linewidths = 2, alpha = 0.8,
annot = gantt_pv["R"], annot_kws={"size": 8},
cmap = ["#38AE66", "#FFF3B4", "#FD8B5A", "#D63231"],
cbar = show_cbar,
yticklabels = gantt_df["district"].unique(),
xticklabels = xticklabels,
cbar_kws = {
"ticks":[0.5, 1, 2, 2.5],
"label": "Mobility",
"format": mpl.ticker.FuncFormatter(lambda x, pos: {0.5:"voluntary", 1:"cautionary", 2:"partial", 2.5:"restricted"}[x]),
"orientation": "horizontal",
"aspect": 50,
"ticks":[0.5, 1, 2, 2.5],
"label": "Mobility",
"format": mpl.ticker.FuncFormatter(lambda x, pos: {0.5:"voluntary", 1:"cautionary", 2:"partial", 2.5:"restricted"}[x]),
"orientation": "horizontal",
"aspect": 50,
"drawedges": True,
"fraction": 0.05,
"pad": 0.10,
"pad": 0.10,
"shrink": 0.5
}
)
Expand All @@ -368,12 +374,12 @@ def predictions(date_range, model, color, bounds = [2.5, 97.5], curve = "dT"):
return [(range_marker, median_marker), model.name]

def simulations(
simulation_results: Sequence[Tuple[SIR]],
labels: Sequence[str],
historical: Optional[pd.Series] = None,
historical_label: str = "Empirical Case Data",
curve: str = "dT",
smoothing: Optional[np.ndarray] = None,
simulation_results: Sequence[Tuple[SIR]],
labels: Sequence[str],
historical: Optional[pd.Series] = None,
historical_label: str = "Empirical Case Data",
curve: str = "dT",
smoothing: Optional[np.ndarray] = None,
semilog: bool = True) -> PlotDevice:
""" plot simulation results for new daily cases and optionally show historical trends """

Expand Down Expand Up @@ -408,13 +414,13 @@ def simulations(
p = plt.plot([pd.Timestamp(t) for t in smoothing[:, 0]], smoothing[:, 1], 'k-', linewidth = 1)
legends.append(p)
legend_labels.append("smoothed_data")

for (rng, label, color) in zip(ranges, labels, SIM_PALETTE):
p, = plt.plot(t, rng["avg"], color = color, linewidth = 2)
f = plt.fill_between(t, rng["min"], rng["max"], color = color, alpha = 0.2)
legends.append((p, f))
legend_labels.append(label)

plt.gca().xaxis.set_major_formatter(DATE_FMT)
plt.gca().xaxis.set_minor_formatter(DATE_FMT)
plt.legend(legends, legend_labels, prop = dict(size = 20), handlelength = theme.handlelength, framealpha = theme.framealpha, loc = "best")
Expand All @@ -429,7 +435,7 @@ def Rt(dates, Rt_pred, Rt_CI_upper, Rt_CI_lower, CI, ymin = 0.5, ymax = 3, yaxis
""" plot Rt and associated confidence intervals over time """
CI_marker = plt.fill_between(dates, Rt_CI_lower, Rt_CI_upper, color = BLK, alpha = 0.3)
Rt_marker, = plt.plot(dates, Rt_pred, color = BLK, linewidth = 2, zorder = 5, solid_capstyle = "butt")
if yaxis_colors:
if yaxis_colors:
plt.plot([dates[0], dates[0]], [2.5, ymax], color = RED, linewidth = 6, alpha = 0.9, solid_capstyle="butt", zorder = 10)
plt.plot([dates[0], dates[0]], [1, 2.5], color = YLW, linewidth = 6, alpha = 0.9, solid_capstyle="butt", zorder = 10)
plt.plot([dates[0], dates[0]], [ymin, 1], color = GRN, linewidth = 6, alpha = 0.9, solid_capstyle="butt", zorder = 10)
Expand All @@ -447,13 +453,13 @@ def Rt(dates, Rt_pred, Rt_CI_upper, Rt_CI_lower, CI, ymin = 0.5, ymax = 3, yaxis
plt.gca().xaxis.set_minor_formatter(DATE_FMT)
set_tick_size(theme.ticks["size"])
pd.markers = {"Rt" : (CI_marker, Rt_marker)}
return pd
return pd

def daily_cases(dates, T_pred, T_CI_upper, T_CI_lower, new_cases_ts, anomaly_dates, anomalies, CI, prediction_ts = None):
def daily_cases(dates, T_pred, T_CI_upper, T_CI_lower, new_cases_ts, anomaly_dates, anomalies, CI, prediction_ts = None):
""" plots expected, smoothed cases from simulated annealing training """
new_cases_dates = dates[-len(new_cases_ts):]
exp_cases_dates = dates[-len(T_pred):]
valid_idx = [i for i in range(len(dates)) if dates[i] not in anomaly_dates]
valid_idx = [i for i in range(len(dates)) if dates[i] not in anomaly_dates]
T_CI_lower_rect = [min(l, u) for (l, u) in zip(T_CI_lower, T_CI_upper)]
T_CI_upper_rect = [max(l, u) for (l, u) in zip(T_CI_lower, T_CI_upper)]
observed_marker, = plt.plot([d for d in new_cases_dates if d not in anomaly_dates], [new_cases_ts[i] for i in range(len(new_cases_ts)) if new_cases_dates[i] not in anomaly_dates], color = OBS_BLK, linewidth = 2, zorder = 8)
Expand All @@ -473,12 +479,12 @@ def daily_cases(dates, T_pred, T_CI_upper, T_CI_lower, new_cases_ts, anomaly_dat
plt.vlines(dates[-1], ymin = 0, ymax = top, color = "black", linestyles = "dotted")
legends += [(predicted_CI_marker, predicted_marker)]
labels += [label]
else:
else:
end = dates[-1]
plt.ylim(bottom = 0, top = top)
legends += [anomalies_marker]
labels += ["anomalies"]
xlim(left = dates[0], right = end)
plt.xlim(left = dates[0], right = end)
plt.legend(legends, labels, prop = {'size': 14}, framealpha = theme.framealpha, handlelength = theme.handlelength, loc = "best")
plt.gca().xaxis.set_major_formatter(DATE_FMT)
plt.gca().xaxis.set_minor_formatter(DATE_FMT)
Expand All @@ -494,25 +500,25 @@ def choropleth(gdf, label_fn = lambda _: "", col = "Rt", title = "$R_t$", label_
ax.set_xticks([])
ax.set_yticks([])
if title:
ax.set_title(title, loc="left", fontdict = theme.label)
ax.set_title(title, loc="left", fontdict = theme.label)
gdf.plot(color=[mappable.to_rgba(_) for _ in gdf[col]], ax = ax, edgecolors="black", linewidth=0.5, missing_kwds = {"color": theme.accent, "edgecolor": "white"})
if label_fn is not None:
for (_, row) in gdf.iterrows():
label = label_fn(row)
value = round(row[col], 2)
ax.annotate(
s = f"{label}{value}",
xy = list(row["pt"].coords)[0],
ha = "center",
fontfamily = theme.note["family"],
color = "black", **label_kwargs,
s = f"{label}{value}",
xy = list(row["pt"].coords)[0],
ha = "center",
fontfamily = theme.note["family"],
color = "black", **label_kwargs,
fontweight = "semibold",
size = 12)\
.set_path_effects([Stroke(linewidth = 2, foreground = "white"), Normal()])
cbar_ax = fig.add_axes([0.90, 0.25, 0.01, 0.5])
cb = fig.colorbar(mappable = mappable, orientation = "vertical", cax = cbar_ax)
cbar_ax.set_title("$R_t$", fontdict = theme.note)

return PlotDevice(fig)

def double_choropleth(gdf, label_fn = lambda _: "", Rt_col = "Rt", Rt_proj_col = "Rt_proj", titles = ["Current $R_t$", "Projected $R_t$ (1 Week)"], arrangement = (1, 2), label_kwargs = {}, mappable = sm):
Expand All @@ -523,7 +529,7 @@ def double_choropleth(gdf, label_fn = lambda _: "", Rt_col = "Rt", Rt_proj_col =
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(title, loc="left", fontdict = theme.label)
ax.set_title(title, loc="left", fontdict = theme.label)
gdf.plot(color=[mappable.to_rgba(_) for _ in gdf[col]], ax = ax, edgecolors="black", linewidth=0.5, missing_kwds = {"color": theme.accent, "edgecolor": "white"})
if label_fn is not None:
for (_, row) in gdf.iterrows():
Expand All @@ -536,7 +542,7 @@ def double_choropleth(gdf, label_fn = lambda _: "", Rt_col = "Rt", Rt_proj_col =
cbar_ax = fig.add_axes([0.95, 0.25, 0.01, 0.5])
cb = fig.colorbar(mappable = mappable, orientation = "vertical", cax = cbar_ax)
cbar_ax.set_title("$R_t$", fontdict = theme.note)

return PlotDevice(fig)

def double_choropleth_v(*args, **kwargs):
Expand Down
Loading

0 comments on commit 0bc6494

Please sign in to comment.