
Commit

pylint: fixes to DataFrame-related errors
1. `read_csv` returns Union[DataFrame, TextParser], but a few functions assume
   that it returns a DataFrame only. This is fixed by casting the result to a
   DataFrame (see the sketch below).
2. Similarly, in estimators.py, `params` is wrapped in a DataFrame so that
   `join` can be applied.

ref: COVID-IWG#123
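
A minimal, self-contained sketch of the casting pattern described above
("cases.csv" is a hypothetical stand-in for the project's actual data paths):

import pandas as pd

# Type stubs annotate pd.read_csv as returning Union[DataFrame, TextParser]
# (the TextParser arm applies when chunksize/iterator is used). Without
# chunksize the runtime value is already a DataFrame, so pd.DataFrame(...)
# is effectively a no-op at runtime that narrows the type for pylint.
raw = pd.read_csv("cases.csv")  # pylint infers: Union[DataFrame, TextParser]
df = pd.DataFrame(raw)          # pylint infers: DataFrame

# The same wrapping makes DataFrame-only methods such as `join` type-check,
# e.g. in estimators.py:
#   growthrates = pd.DataFrame(rolling.params).join(rolling.bse, rsuffix="_stderr")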
Dilawar Singh committed Jul 24, 2021
1 parent b215386 commit 69dac1a
Showing 4 changed files with 114 additions and 111 deletions.
48 changes: 24 additions & 24 deletions epimargin/estimators.py
@@ -21,8 +21,8 @@ def rollingOLS(totals: pd.DataFrame, window: int = 3, infectious_period: float =
# run rolling regressions and get parameters
model = RollingOLS.from_formula(formula = "logdelta ~ time", window = window, data = totals)
rolling = model.fit(method = "lstsq")
- growthrates = rolling.params.join(rolling.bse, rsuffix="_stderr")
+ growthrates = pd.DataFrame(rolling.params).join(rolling.bse, rsuffix="_stderr")
growthrates["rsq"] = rolling.rsquared
growthrates.rename(lambda s: s.replace("time", "gradient").replace("const", "intercept"), axis = 1, inplace = True)

@@ -38,13 +38,13 @@ def rollingOLS(totals: pd.DataFrame, window: int = 3, infectious_period: float =
return growthrates

def analytical_MPVS(
infection_ts: pd.DataFrame,
smoothing: Callable,
alpha: float = 3.0, # shape
beta: float = 2.0, # rate
CI: float = 0.95, # confidence interval
infectious_period: int = 5*days, # inf period = 1/gamma,
variance_shift: float = 0.99, # how much to scale variance parameters by when anomaly detected
totals: bool = True # are these case totals or daily new cases?
):
"""Estimates Rt ~ Gamma(alpha, 1/beta), and implements an analytical expression for a mean-preserving variance increase whenever case counts fall outside the CI defined by a negative binomial distribution"""
@@ -53,8 +53,8 @@ def analytical_MPVS(
if totals:
# daily_cases = np.diff(infection_ts.clip(lower = 0)).clip(min = 0) # infection_ts clipped because COVID19India API does weird stuff
daily_cases = infection_ts.clip(lower = 0).diff().clip(lower = 0).iloc[1:]
else:
daily_cases = infection_ts
total_cases = np.cumsum(smoothing(np.squeeze(daily_cases)))

v_alpha, v_beta = [], []
@@ -105,18 +105,18 @@ def analytical_MPVS(
T_CI_lower.append(T_lower)

_np = p
_nr = r
anomaly_noted = False
counter = 0
while not (T_lower < new_cases < T_upper):
if not anomaly_noted:
anomalies.append(new_cases)
anomaly_dates.append(dates[i])

# logger.debug("anomaly identified at time %s: %s < %s < %s, r: %s, p: %s, annealing iteration: %s", i, T_lower, new_cases, T_upper, _nr, _np, counter+1)
# nnp = 0.95 *_np # <- where does this come from
_nr = variance_shift * _nr * ((1-_np)/(1-variance_shift*_np) )
_np = variance_shift * _np
T_upper = nbinom.ppf(CI, _nr, _np)
T_lower = nbinom.ppf(1-CI, _nr, _np)
T_lower, T_upper = sorted((T_lower, T_upper))
@@ -130,7 +130,7 @@
raise ValueError("Number of iterations exceeded")
else:
if anomaly_noted:
alpha = _nr # update distribution on R with new parameters that enclose the anomaly
beta = _np/(1-_np) * old_new_cases

T_pred[-1] = nbinom.mean(_nr, _np)
@@ -141,22 +141,22 @@
Rt_upper = max(0, 1 + infectious_period * np.log(Gamma.ppf(CI , a = alpha, scale = 1/beta)))
Rt_lower = max(0, 1 + infectious_period * np.log(Gamma.ppf(1 - CI, a = alpha, scale = 1/beta)))

# replace latest CI time series entries with adjusted CI
Rt_CI_upper[-1] = Rt_upper
Rt_CI_lower[-1] = Rt_lower
return (
dates[2:],
Rt_pred, Rt_CI_upper, Rt_CI_lower,
T_pred, T_CI_upper, T_CI_lower,
total_cases, new_cases_ts,
anomalies, anomaly_dates
)

def parametric_scheme_mcmc(daily_cases, CI = 0.95, gamma = 0.2, chains = 4, tune = 1000, draws = 1000, **kwargs):
""" Implements the Bettencourt/Soman parametric scheme via MCMC sampling """
if isinstance(daily_cases, (pd.DataFrame, pd.Series)):
case_values = daily_cases.values
else:
case_values = np.array(daily_cases)
with pm.Model() as mcmc_model:
# lag new case counts
@@ -167,22 +167,22 @@ def parametric_scheme_mcmc(daily_cases, CI = 0.95, gamma = 0.2, chains = 4, tune
dT = pm.Poisson("dT", mu = dT_lag0, shape = (n,))
bt = pm.Gamma("bt", alpha = dT_lag0.cumsum(), beta = 0.0001 + dT_lag1.cumsum(), shape = (n,))
Rt = pm.Deterministic("Rt", 1 + pm.math.log(bt)/gamma)

trace = pm.sample(model = mcmc_model, chains = chains, tune = tune, draws = draws, cores = 1, **kwargs)
return (mcmc_model, trace, pm.summary(trace, hdi_prob = CI))

def branching_random_walk(daily_cases, CI = 0.95, gamma = 0.2, chains = 4, tune = 1000, draws = 1000, **kwargs):
""" estimate Rt using a random walk for branch parameter, adapted from old Rt.live code """
if isinstance(daily_cases, (pd.DataFrame, pd.Series)):
case_values = daily_cases.values
else:
case_values = np.array(daily_cases)
with pm.Model() as mcmc_model:
# lag new case counts
dT_lag0 = case_values[1:]
dT_lag1 = case_values[:-1]
n = len(dT_lag0)

# Random walk magnitude
step_size = pm.HalfNormal('step_size', sigma = 0.03)
theta_raw_init = pm.Normal('theta_raw_init', 0.1, 0.1)
@@ -192,15 +192,15 @@

Rt = pm.Deterministic("Rt", 1 + theta/gamma)
expected_cases = pm.Poisson('dT', mu = dT_lag1 * pm.math.exp(theta), observed = dT_lag0)

trace = pm.sample(model = mcmc_model, chains = chains, tune = tune, draws = draws, cores = 1, **kwargs)
return (mcmc_model, trace, pm.summary(trace, hdi_prob = CI))

def linear_projection(dates, R_values, smoothing, period = 7*days):
""" return 7-day linear projection """
julian_dates = [_.to_julian_date() for _ in dates[-smoothing//2:None]]
return OLS(
R_values[-smoothing//2:None],
add_constant(julian_dates)
)\
.fit()\
159 changes: 81 additions & 78 deletions epimargin/etl/covid19india.py
@@ -12,7 +12,7 @@
new_states = set("Telangana")

# states renamed in 2011
renamed_states = {
"Orissa" : "Odisha",
"Pondicherry" : "Puducherry"
}
@@ -106,7 +106,7 @@
"Gender",
"Detected City",
"Notes",
'Contracted from which Patient (Suspected)',
'Nationality',
"Source_1",
"Source_2",
@@ -119,40 +119,40 @@
}

columns_v4 = v4 = [
'Entry_ID',
'State Patient Number',
'Date Announced',
'Age Bracket',
'Gender',
'Detected City',
'Detected District',
'Detected State',
'State code',
'Num Cases',
'Current Status',
'Contracted from which Patient (Suspected)',
'Notes',
'Source_1',
'Source_2',
'Source_3',
'Nationality',
'Type of transmission',
'Status Change Date',
'Patient Number'
]

drop_cols_v4 = {
"Entry_ID",
'Age Bracket',
'Gender',
'Detected City',
'State code',
'Contracted from which Patient (Suspected)',
'Notes',
'Source_1',
'Source_2',
'Source_3',
'Nationality',
'Type of transmission',
"State Patient Number"
}
@@ -180,7 +180,7 @@
'BR' : 'Bihar',
'CH' : 'Chandigarh',
'CT' : 'Chhattisgarh',
'DD'  : 'Daman & Diu',
'DNDD': 'Dadra & Nagar Haveli and Daman & Diu',
'DL' : 'Delhi',
'DN' : 'Dadra & Nagar Haveli',
@@ -215,52 +215,52 @@
'WB' : 'West Bengal',
}

state_name_lookup = {
'Andaman & Nicobar Islands'  : 'AN',
'Andaman And Nicobar Islands': 'AN',
'Andaman and Nicobar Islands': 'AN',
'Andhra Pradesh'             : 'AP',
'Arunachal Pradesh'          : 'AR',
'Assam'                      : 'AS',
'Bihar'                      : 'BR',
'Chandigarh'                 : 'CH',
'Chhattisgarh'               : 'CT',
'Daman & Diu'                : 'DD',
'Daman And Diu'              : 'DD',
'Daman and Diu'              : 'DD',
'Delhi'                      : 'DL',
'Dadra & Nagar Haveli'       : 'DN',
'Dadra And Nagar Haveli'     : 'DN',
'Dadra and Nagar Haveli'     : 'DN',
'Goa'                        : 'GA',
'Gujarat'                    : 'GJ',
'Himachal Pradesh'           : 'HP',
'Haryana'                    : 'HR',
'Jharkhand'                  : 'JH',
'Jammu & Kashmir'            : 'JK',
'Jammu and Kashmir'          : 'JK',
'Jammu And Kashmir'          : 'JK',
'Karnataka'                  : 'KA',
'Kerala'                     : 'KL',
'Ladakh'                     : 'LA',
'Lakshadweep'                : 'LD',
'Maharashtra'                : 'MH',
'Meghalaya'                  : 'ML',
'Manipur'                    : 'MN',
'Madhya Pradesh'             : 'MP',
'Mizoram'                    : 'MZ',
'Nagaland'                   : 'NL',
'Odisha'                     : 'OR',
'Punjab'                     : 'PB',
'Puducherry'                 : 'PY',
'Rajasthan'                  : 'RJ',
'Sikkim'                     : 'SK',
'Telangana'                  : 'TG',
'Tamil Nadu'                 : 'TN',
'Tripura'                    : 'TR',
'India'                      : 'TT',
'State Unassigned'           : 'UN',
'Uttar Pradesh'              : 'UP',
'Uttarakhand' : 'UT',
'West Bengal' : 'WB',
'Dadra And Nagar Haveli And Daman And Diu' : "DNDD"
@@ -275,16 +275,16 @@ def standardize_column_headers(df: pd.DataFrame):

# load data until April 26
def load_data_v3(path: Path, drop = drop_cols_v3):
cases = pd.read_csv(path,
usecols = set(columns_v3) - drop,
dayfirst = True, # source data does not have consistent date format so cannot rely on inference
parse_dates = ["Date Announced", "Status Change Date"])
standardize_column_headers(cases)
return cases

# load data for April 27 - May 09
def load_data_v4(path: Path, drop = drop_cols_v3):
cases = pd.read_csv(path,
usecols = set(columns_v4) - drop,
dayfirst = True, # source data does not have consistent date format so cannot rely on inference
parse_dates = ["Date Announced", "Status Change Date"])
@@ -299,7 +299,7 @@ def add_time_col(grp_df):
def get_time_series(df: pd.DataFrame, group_col: Optional[Sequence[str]] = None, drop_negatives = True) -> pd.DataFrame:
if group_col:
group_cols = (group_col if isinstance(group_col, list) else [group_col]) + ["status_change_date", "current_status"]
else:
group_cols = ["status_change_date", "current_status"]
if drop_negatives:
df = df[df["num_cases"] >= 0]
@@ -319,16 +319,16 @@ def load_all_data(v3_paths: Sequence[Path], v4_paths: Sequence[Path]) -> pd.DataFrame:
all_cases = pd.concat(cases_v3 + cases_v4)
all_cases["status_change_date"] = all_cases["status_change_date"].fillna(all_cases["date_announced"])
all_cases["detected_state"] = all_cases["detected_state"].str.strip().str.title()
all_cases["detected_district"] = all_cases["detected_district"].str.strip().str.title()
all_cases["detected_district"] = all_cases["detected_district"].str.strip().str.title()
return all_cases.dropna(subset = ["detected_state"])

# assuming analysis for data structure from COVID19-India saved as resaved, properly-quoted file (v1 and v2)
def load_data(datapath: Path, reduced: bool = False, schema: Optional[Sequence[str]] = None) -> pd.DataFrame:
if not schema:
schema = columns_v1
df = pd.read_csv(datapath,
skiprows = 1, # supply fixed header in order to deal with Google Sheets export issues
names = schema,
usecols = (lambda _: _ not in drop_cols) if reduced else None,
dayfirst = True, # source data does not have consistent date format so cannot rely on inference
parse_dates = ["Date Announced", "Status Change Date"])
@@ -345,7 +345,10 @@ def replace_district_names(df_state: pd.DataFrame, state_district_maps: pd.DataFrame
def load_statewise_data(statewise_data_path: Path, drop_unassigned: bool = True) -> pd.DataFrame:
df_raw = pd.read_csv(statewise_data_path, parse_dates = ["Date"])
df_raw.rename(columns=state_code_lookup, inplace=True)
- df = pd.DataFrame(df_raw.set_index(["Date","Status"]).unstack().unstack()).reset_index()
+ # NOTE: df_raw could be DataFrame or TextParser. Force it to be DataFrame to keep
+ # pylint happy.
+ df_raw = pd.DataFrame(df_raw).set_index(["Date", "Status"]).unstack().unstack()
+ df = df_raw.reset_index()
df.columns = ["state", "current_status", "status_change_date", "num_cases"]
df.replace("Confirmed", "Hospitalized", inplace=True)
# drop negative cases and cases with no state assigned
