From 983c3ba33e20027a6c28e2c3a51f3b049935c41f Mon Sep 17 00:00:00 2001 From: Dilawar Singh Date: Sat, 24 Jul 2021 10:05:18 +0530 Subject: [PATCH] potential fix: https://github.com/COVID-IWG/epimargin/issues/124 and potential fix: https://github.com/COVID-IWG/epimargin/issues/125 --- epimargin/plots.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/epimargin/plots.py b/epimargin/plots.py index a6156fd..7003b44 100644 --- a/epimargin/plots.py +++ b/epimargin/plots.py @@ -17,7 +17,7 @@ import matplotlib.pyplot as plt -from .models import SIR +from .models import SIR, NetworkedSIR def normalize_dates(dates): try: @@ -307,21 +307,22 @@ def show(self, **kwargs): plt.show(**kwargs) return self -def plot_SIRD(model: SIR, layout = (1, 4)) -> PlotDevice: +def plot_SIRD(model: NetworkedSIR, layout = (1, 4)) -> PlotDevice: """ plot all 4 available curves (S, I, R, D) for a given SIR model """ fig, axes = plt.subplots(layout[0], layout[1], sharex = True, sharey = True) - t = list(range(len(model.Rt))) - for (ax, model) in zip(axes.flat, model.units): - s = ax.semilogy(t, model.S, alpha=0.75, label="Susceptibles") - i = ax.semilogy(t, model.I, alpha=0.75, label="Infectious", ) - d = ax.semilogy(t, model.D, alpha=0.75, label="Deaths", ) - r = ax.semilogy(t, model.R, alpha=0.75, label="Recovered", ) + # FIXME: commented out because model can not be of type SIR (it doesn't have units). + for (ax, _model) in zip(axes.flat, model.units): + t = list(range(len(_model.Rt))) + s = ax.semilogy(t, _model.S, alpha=0.75, label="Susceptibles") + i = ax.semilogy(t, _model.I, alpha=0.75, label="Infectious", ) + d = ax.semilogy(t, _model.D, alpha=0.75, label="Deaths", ) + r = ax.semilogy(t, _model.R, alpha=0.75, label="Recovered", ) ax.label_outer() fig.legend([s, i, r, d], labels = ["S", "I", "R", "D"], loc = "center right", borderaxespad = 0.1) return PlotDevice(fig) -def plot_curve(models: Sequence[SIR], labels: Sequence[str], curve: str = "I"): +def plot_curve(models: Sequence[NetworkedSIR], labels: Sequence[str], curve: str = "I"): """ plot specific epidemic curve """ fig = plt.figure() for (model, label) in zip(models, labels): @@ -374,7 +375,7 @@ def predictions(date_range, model, color, bounds = [2.5, 97.5], curve = "dT"): return [(range_marker, median_marker), model.name] def simulations( - simulation_results: Sequence[Tuple[SIR]], + simulation_results: Sequence[Tuple[NetworkedSIR]], labels: Sequence[str], historical: Optional[pd.Series] = None, historical_label: str = "Empirical Case Data", @@ -425,7 +426,8 @@ def simulations( plt.gca().xaxis.set_minor_formatter(DATE_FMT) plt.legend(legends, legend_labels, prop = dict(size = 20), handlelength = theme.handlelength, framealpha = theme.framealpha, loc = "best") - plt.xlim(left = historical.index[0], right = t[-1]) + if historical is not None: + plt.xlim(left = historical.index[0], right = t[-1]) if semilog: plt.semilogy() set_tick_size(14)