diff --git a/epimargin/policy.py b/epimargin/policy.py index fc84dba..eeec659 100644 --- a/epimargin/policy.py +++ b/epimargin/policy.py @@ -246,7 +246,7 @@ def __init__(self, daily_doses: int, effectiveness: float, bin_populations: np.a def name(self) -> str: return f"{self.label}prioritized" - def distribute_doses(self, model: SIR, num_sims: int = 10_000) -> Tuple[np.ndarray]: + def distribute_doses(self, model: SIR, num_sims: int = 10_000) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: if self.exhausted(model): return (None, None, None) # return (np.zeros(self.age_ratios.shape), np.zeros(self.age_ratios.shape), np.zeros(self.age_ratios.shape)) @@ -254,7 +254,7 @@ def distribute_doses(self, model: SIR, num_sims: int = 10_000) -> Tuple[np.ndarr model.S[-1] -= dV model.parallel_forward_epi_step(num_sims = num_sims) - dVx = np.zeros(self.bin_populations.shape) + dVx : np.ndarray = np.zeros(self.bin_populations.shape) bin_idx, age_bin = next(((i, age_bin) for (i, age_bin) in enumerate(self.prioritization) if self.bin_populations[age_bin] > 0), (None, None)) if age_bin is not None: if self.bin_populations[age_bin] > self.daily_doses: @@ -269,6 +269,7 @@ def distribute_doses(self, model: SIR, num_sims: int = 10_000) -> Tuple[np.ndarr self.bin_populations[self.prioritization[bin_idx + 1]] -= leftover else: print("vaccination exhausted", self.bin_populations, self.prioritization) + return ( dVx, dVx * self.effectiveness,