Skip to content

Commit

Permalink
resolve linting issues (#129)
Browse files Browse the repository at this point in the history
* resolve pylint issues

* renames, refactors, and type annotations

* restructure class hierarchy to reflect different parallelizing mechanisms

* second round of linting checks
  • Loading branch information
satejsoman committed Aug 5, 2021
1 parent 7170f5f commit 0b0b88a
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 89 deletions.
2 changes: 1 addition & 1 deletion epimargin/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def rollingOLS(totals: pd.DataFrame, window: int = 3, infectious_period: float =
model = RollingOLS.from_formula(formula = "logdelta ~ time", window = window, data = totals)
rolling = model.fit(method = "lstsq")

growthrates = rolling.params.join(rolling.bse, rsuffix="_stderr")
growthrates = pd.DataFrame(rolling.params).join(rolling.bse, rsuffix="_stderr")
growthrates["rsq"] = rolling.rsquared
growthrates.rename(lambda s: s.replace("time", "gradient").replace("const", "intercept"), axis = 1, inplace = True)

Expand Down
2 changes: 1 addition & 1 deletion epimargin/etl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .commons import download_data
__all__ = [download_data]
__all__ = ["download_data"]
86 changes: 54 additions & 32 deletions epimargin/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, Iterator, Optional, Sequence, Tuple, Union
from typing import Dict, Iterator, Optional, Sequence, Tuple, Union, List

import geopandas as gpd
import numpy as np
Expand Down Expand Up @@ -124,7 +124,7 @@ def forward_epi_step(self, dB: int = 0):
# parallel poisson draws for infection
def parallel_forward_epi_step(self, dB: int = 0, num_sims = 10000):
# get previous state
S, I, R, D, N = (vector[-1].copy() for vector in (self.S, self.I, self.R, self.D, self.N))
S, I, R, D, N = (vector[-1] for vector in (self.S, self.I, self.R, self.D, self.N))

# update state
Rt = self.Rt0 * S/N
Expand All @@ -148,9 +148,9 @@ def parallel_forward_epi_step(self, dB: int = 0, num_sims = 10000):

I -= (num_dead + num_recov)

S = S.clip(0)
I = I.clip(0)
D = D.clip(0)
S = max(0, S)
I = max(0, I)
D = max(0, D)

N = S + I + R
beta = (num_cases * N)/(b * S * I)
Expand All @@ -172,15 +172,15 @@ def parallel_forward_epi_step(self, dB: int = 0, num_sims = 10000):
# parallel binomial draws for infection
def parallel_forward_binom_step(self, dB: int = 0, num_sims = 10000):
# get previous state
S, I, R, D, N = (vector[-1].copy() for vector in (self.S, self.I, self.R, self.D, self.N))
S, I, R, D, N = (vector[-1] for vector in (self.S, self.I, self.R, self.D, self.N))

# update state
Rt = self.Rt0 * S/N
p = self.gamma * Rt * I/N

num_cases = binom.rvs(n = S.astype(int), p = p, size = num_sims)
self.upper_CI.append(binom.ppf(self.CI, n = S.astype(int), p = p))
self.lower_CI.append(binom.ppf(1 - self.CI, n = S.astype(int), p = p))
num_cases = binom.rvs(n = S, p = p, size = num_sims)
self.upper_CI.append(binom.ppf(self.CI, n = S, p = p))
self.lower_CI.append(binom.ppf(1 - self.CI, n = S, p = p))

I += num_cases
S -= num_cases
Expand All @@ -195,12 +195,11 @@ def parallel_forward_binom_step(self, dB: int = 0, num_sims = 10000):

I -= (num_dead + num_recov)

S = S.clip(0)
I = I.clip(0)
D = D.clip(0)
S = max(0, S)
I = max(0, I)
D = max(0, D)

N = S + I + R
# beta = (num_cases * N)/(b * S * I)

# update state vectors
self.Rt.append(Rt)
Expand All @@ -222,17 +221,20 @@ def run(self, days: int):
def __repr__(self) -> str:
return f"[{self.name}]"

class Age_SIRVD(SIR):
""" age-structured compartmental model with a vaccinated class for each age bin """
class Age_SIRVD():
""" age-structured compartmental model with a vaccinated class for each age bin
note that the underlying parallelizing mechanism is different from that of SIR and NetworkedSIR
"""
def __init__(self,
name: str, # name of unit
population: int, # unit population
dT0: Optional[int] = None, # last change in cases, None -> Poisson random intro
Rt0: float = 1.9, # initial reproductive rate,
S0: int = 0, # initial susceptible
I0: int = 0, # initial infected
R0: int = 0, # initial recovered
D0: int = 0, # initial dead
S0: np.array = np.array(0), # initial susceptibles
I0: np.array = np.array(0), # initial infected
R0: np.array = np.array(0), # initial recovered
D0: np.array = np.array(0), # initial dead
infectious_period: int = 5, # how long disease is communicable in days
introduction_rate: float = 5.0, # parameter for new community transmissions (lambda)
mortality: float = 0.02, # I -> D transition probability
Expand All @@ -245,7 +247,36 @@ def __init__(self,
ve: float = 0.7, # vaccine effectiveness
random_seed: int = 0 # random seed,
):
super().__init__(name, population, dT0=dT0, Rt0=Rt0, I0=I0, R0=R0, D0=D0, infectious_period=infectious_period, introduction_rate=introduction_rate, mortality=mortality, mobility=mobility, upper_CI=upper_CI, lower_CI=lower_CI, CI=CI, random_seed=random_seed)
self.name = name
self.pop0 = population
self.gamma = 1.0/infectious_period
self.ll = introduction_rate
self.m = mortality
self.mu = mobility
self.Rt0 = Rt0
self.CI = CI

# state and delta vectors
if dT0 is None:
dT0 = np.random.poisson(self.ll) # initial number of new cases
self.dT = [dT0] # case change rate, initialized with the first introduction, if any
self.Rt = [Rt0]
self.b = [np.exp(self.gamma * (Rt0 - 1.0))]
self.S = [S0 if S0 is not None else population - R0 - D0 - I0]
self.I = [I0]
self.R = [R0]
self.D = [D0]
self.dR = [0]
self.dD = [0]
self.N = [population - D0] # total population = S + I + R
self.beta = [Rt0 * self.gamma] # initial contact rate
self.total_cases = [I0] # total cases
self.upper_CI = [upper_CI]
self.lower_CI = [lower_CI]

np.random.seed(random_seed)


self.N = [S0 + I0 + R0]
shape = (sims, bins) = S0.shape

Expand Down Expand Up @@ -276,7 +307,7 @@ def __init__(self,

self.dT_total = [np.zeros(sims)]
self.dD_total = [np.zeros(sims)]
self.dV = []
self.dV: List[np.array] = []

self.rng = np.random.default_rng(random_seed)

Expand Down Expand Up @@ -440,21 +471,12 @@ def set_parameters(self, **kwargs):
unit.__setattr__(attr, val)
return self

def aggregate(self, curves: Union[Sequence[str], str] = ["Rt", "b", "S", "I", "R", "D", "P", "beta"]) -> Union[Dict[str, Sequence[float]], Sequence[float]]:
if isinstance(curves, str):
single_curve = curves
curves = [curves]
else:
single_curve = False
aggs = {
def aggregate(self, curves: Union[Sequence[str], str] = ["Rt", "b", "S", "I", "R", "D", "P", "beta"]) -> Dict[str, Sequence[float]]:
return {
curve: list(map(sum, zip(*(unit.__getattribute__(curve) for unit in self.units))))
for curve in curves
}

if single_curve:
return aggs[single_curve]
return aggs

class SEIR():
""" stochastic SEIR model without external introductions """
def __init__(self,
Expand Down
33 changes: 12 additions & 21 deletions epimargin/plots.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from build.lib.epimargin.models import NetworkedSIR
import datetime
from collections import namedtuple
from pathlib import Path
from typing import Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, List, Dict

import matplotlib as mpl
import matplotlib.dates as mdates
Expand All @@ -11,7 +12,6 @@
import seaborn as sns
import tikzplotlib
from matplotlib.patheffects import Normal, Stroke
from matplotlib.pyplot import *

from .models import SIR

Expand Down Expand Up @@ -56,15 +56,6 @@ def rebuild_font_cache():
import matplotlib.font_manager
matplotlib.font_manager._rebuild()

def despine(**kwargs):
pass

def grid(flag):
if flag:
pass
else:
pass

# container class for different theme
Aesthetics = namedtuple(
"Aesthetics",
Expand Down Expand Up @@ -301,7 +292,7 @@ def show(self, **kwargs):
plt.show(**kwargs)
return self

def plot_SIRD(model: SIR, layout = (1, 4)) -> PlotDevice:
def plot_SIRD(model: NetworkedSIR, layout = (1, 4)) -> PlotDevice:
""" plot all 4 available curves (S, I, R, D) for a given SIR model """
fig, axes = plt.subplots(layout[0], layout[1], sharex = True, sharey = True)
t = list(range(len(model.Rt)))
Expand All @@ -315,7 +306,7 @@ def plot_SIRD(model: SIR, layout = (1, 4)) -> PlotDevice:
fig.legend([s, i, r, d], labels = ["S", "I", "R", "D"], loc = "center right", borderaxespad = 0.1)
return PlotDevice(fig)

def plot_curve(models: Sequence[SIR], labels: Sequence[str], curve: str = "I"):
def plot_curve(models: Sequence[NetworkedSIR], labels: Sequence[str], curve: str = "I"):
""" plot specific epidemic curve """
fig = plt.figure()
for (model, label) in zip(models, labels):
Expand Down Expand Up @@ -368,7 +359,7 @@ def predictions(date_range, model, color, bounds = [2.5, 97.5], curve = "dT"):
return [(range_marker, median_marker), model.name]

def simulations(
simulation_results: Sequence[Tuple[SIR]],
simulation_results: Sequence[Tuple[NetworkedSIR]],
labels: Sequence[str],
historical: Optional[pd.Series] = None,
historical_label: str = "Empirical Case Data",
Expand All @@ -384,18 +375,18 @@ def simulations(
num_sims = len(simulation_results)
total_time = len(policy_outcomes[0][0])

ranges = [{"max": [], "min": [], "mdn": [], "avg": []} for _ in range(len(policy_outcomes))]
ranges: List[Dict[str, List]] = [{"max": [], "min": [], "mdn": [], "avg": []} for _ in range(len(policy_outcomes))]

for (i, policy) in enumerate(policy_outcomes):
for t in range(total_time):
curve_sorted = sorted([curve[t] for curve in policy])
for _ in range(total_time):
curve_sorted = sorted([curve[_] for curve in policy])
ranges[i]["min"].append(curve_sorted[0])
ranges[i]["max"].append(curve_sorted[-1])
ranges[i]["mdn"].append(curve_sorted[num_sims//2])
ranges[i]["avg"].append(np.mean(curve_sorted))

legends = []
legend_labels = []
legend_labels = []
if historical is not None:
p, = plt.plot(historical.index, historical, 'k-', alpha = 0.8, zorder = 10)
t = [historical.index.max() + datetime.timedelta(days = n) for n in range(total_time)]
Expand Down Expand Up @@ -478,7 +469,7 @@ def daily_cases(dates, T_pred, T_CI_upper, T_CI_lower, new_cases_ts, anomaly_dat
plt.ylim(bottom = 0, top = top)
legends += [anomalies_marker]
labels += ["anomalies"]
xlim(left = dates[0], right = end)
plt.xlim(left = dates[0], right = end)
plt.legend(legends, labels, prop = {'size': 14}, framealpha = theme.framealpha, handlelength = theme.handlelength, loc = "best")
plt.gca().xaxis.set_major_formatter(DATE_FMT)
plt.gca().xaxis.set_minor_formatter(DATE_FMT)
Expand Down Expand Up @@ -544,5 +535,5 @@ def double_choropleth_v(*args, **kwargs):
kwargs["arrangement"] = (2, 1)
return double_choropleth(*args, **kwargs)

double_choropleth.horizontal = double_choropleth
double_choropleth.vertical = double_choropleth_v
double_choropleth.horizontal = double_choropleth # type: ignore
double_choropleth.vertical = double_choropleth_v # type: ignore
Loading

0 comments on commit 0b0b88a

Please sign in to comment.