
spang-lab/TenSIR


TenSIR

Parameter estimation on epidemic data under the SIR model using a novel algorithm for differentiated uniformization of Markov transition rate matrices in tensor representation.

This repository contains the code for the manuscript Differentiated uniformization: A new method for inferring Markov chains on combinatorial state spaces including stochastic epidemic models.
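To illustrate the core idea behind (differentiated) uniformization, here is a minimal dense-matrix sketch for a tiny stochastic SIR model. It is not the repository's tensor-based implementation in `source/tensir/uniformization/`. Assumptions: a state is (S, I) with R = N - S - I implied, infection (S, I) -> (S-1, I+1) happens at rate alpha·S·I/N and recovery (S, I) -> (S, I-1) at rate beta·I, and all function names are hypothetical. With P = I + Q/lambda for a rate lambda at least as large as every exit rate, p(t) = sum_k Poisson(k; lambda·t) · p(0) P^k; differentiating each term with the product rule (lambda held fixed, since exp(Qt) does not depend on its choice) gives the parameter derivative.

```python
import math


def sir_states(n):
    """All states (S, I) with S + I <= n; R = n - S - I is implied."""
    return [(s, i) for s in range(n + 1) for i in range(n + 1 - s)]


def generator(n, alpha, beta):
    """Dense generator Q of the (assumed) stochastic SIR chain plus dQ/dalpha and dQ/dbeta."""
    states = sir_states(n)
    index = {st: k for k, st in enumerate(states)}
    m = len(states)
    q = [[0.0] * m for _ in range(m)]
    dqa = [[0.0] * m for _ in range(m)]
    dqb = [[0.0] * m for _ in range(m)]
    for (s, i), k in index.items():
        if s > 0 and i > 0:  # infection (S, I) -> (S-1, I+1) at rate alpha*S*I/n (assumed)
            j, r = index[(s - 1, i + 1)], s * i / n
            q[k][j] += alpha * r
            q[k][k] -= alpha * r
            dqa[k][j] += r
            dqa[k][k] -= r
        if i > 0:  # recovery (S, I) -> (S, I-1) at rate beta*I (assumed)
            j = index[(s, i - 1)]
            q[k][j] += beta * i
            q[k][k] -= beta * i
            dqb[k][j] += i
            dqb[k][k] -= i
    return states, q, dqa, dqb


def vecmat(v, a):
    """Row vector v times dense square matrix a."""
    out = [0.0] * len(a)
    for k, vk in enumerate(v):
        if vk != 0.0:
            for j, akj in enumerate(a[k]):
                out[j] += vk * akj
    return out


def uniformized_transient(p0, q, dq, t, tol=1e-12, max_terms=100_000):
    """Return p(t) = p0 exp(Qt) and d p(t)/d theta, given dq = dQ/d theta."""
    m = len(q)
    lam = max(-q[k][k] for k in range(m)) or 1.0  # uniformization rate >= every exit rate
    # P = I + Q/lam is a stochastic matrix; lam is held fixed, so dP = dQ/lam.
    p = [[(1.0 if k == j else 0.0) + q[k][j] / lam for j in range(m)] for k in range(m)]
    dp = [[dq[k][j] / lam for j in range(m)] for k in range(m)]
    v, dv = list(p0), [0.0] * m  # v_k = p0 P^k and its derivative
    weight = math.exp(-lam * t)  # Poisson(0; lam*t) weight
    res = [weight * x for x in v]
    dres = [0.0] * m
    cum, k = weight, 0
    while 1.0 - cum > tol and k < max_terms:
        k += 1
        # product rule: d(v P) = dv P + v dP (the right-hand side uses the old v)
        v, dv = vecmat(v, p), [a + b for a, b in zip(vecmat(dv, p), vecmat(v, dp))]
        weight *= lam * t / k  # Poisson(k; lam*t) weight
        cum += weight
        res = [r + weight * x for r, x in zip(res, v)]
        dres = [r + weight * x for r, x in zip(dres, dv)]
    return res, dres
```

For example, `generator(4, 1.5, 0.7)` builds a 15-state chain, and the derivative returned by `uniformized_transient` can be checked against central finite differences. The naive Poisson weights underflow for large lambda·t, and dense matrices are only feasible for tiny populations, which is why the repository works with a tensor representation instead.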

Data

We used data from the Austrian BMSGPK (Federal Ministry of Social Affairs, Health, Care and Consumer Protection) on the COVID-19 pandemic from March 2020 to August 2020. In case the API changes in the future, a CSV file containing the data we used can be found here.

Figure (timeline plot): Susceptible and infected people in Austria during the early months of the COVID-19 pandemic.

Results

Figure (HMC plot): Kernel density estimate of the points generated by Hamiltonian Monte Carlo simulation.

The x marks the least-squares estimate obtained by grid search under the classical deterministic SIR model (a system of ODEs).
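A deterministic least-squares baseline of this kind can be sketched as follows: integrate the SIR ODEs with a fixed-step RK4 scheme and keep the grid point with the smallest sum of squared errors on S and I. This is a hypothetical stand-in for the routines in `source/tensir/optimize/`, not the repository's code; the rate convention (infection alpha·S·I/N, recovery beta·I), the step size, the unweighted error, and the grid are assumptions.

```python
def sir_rhs(s, i, alpha, beta, n):
    """Right-hand side of the classical SIR ODEs (R is implied by S + I + R = n)."""
    infection = alpha * s * i / n
    return -infection, infection - beta * i


def simulate(s0, i0, alpha, beta, n, days, steps_per_day=10):
    """Integrate with classical RK4; return (S, I) at integer days 0..days."""
    h = 1.0 / steps_per_day
    s, i = float(s0), float(i0)
    out = [(s, i)]
    for _ in range(days):
        for _ in range(steps_per_day):
            k1 = sir_rhs(s, i, alpha, beta, n)
            k2 = sir_rhs(s + h / 2 * k1[0], i + h / 2 * k1[1], alpha, beta, n)
            k3 = sir_rhs(s + h / 2 * k2[0], i + h / 2 * k2[1], alpha, beta, n)
            k4 = sir_rhs(s + h * k3[0], i + h * k3[1], alpha, beta, n)
            s += h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
            i += h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        out.append((s, i))
    return out


def grid_search(observed, n, alphas, betas):
    """Return (sse, alpha, beta) with the smallest sum of squared errors over the grid."""
    s0, i0 = observed[0]
    days = len(observed) - 1
    best = None
    for a in alphas:
        for b in betas:
            sim = simulate(s0, i0, a, b, n, days)
            sse = sum((ss - so) ** 2 + (si - io) ** 2
                      for (ss, si), (so, io) in zip(sim, observed))
            if best is None or sse < best[0]:
                best = (sse, a, b)
    return best
```

On synthetic data simulated with a grid point's own parameters, `grid_search` recovers that point with an error of zero; on real data the estimate is only as fine as the grid.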

Reproducing the results

Prerequisites

  • Python 3.9+ (tested with Python 3.9). For installation instructions, please refer to the Python documentation. Also make sure that Python and pip are on your PATH and can be invoked from a terminal/PowerShell as python3 and pip3.

Setup

We recommend using a virtual environment for running the code. See setup and activation instructions here. After activating it, run the following command to install the dependencies:

pip3 install -r requirements.txt

Generating points

To reproduce our results exactly, use the generate-points.py script. For example, to generate points with the NUTS sampler for March (month 3) in run 0 with 16 threads, using 100 tuning steps followed by 900 draws, run:

python3 source/generate-points.py --sampler nuts --month 3 --run 0 --tune 100 --draws 900 --threads 16

Use the --help option to list all available and required parameters:

python3 source/generate-points.py --help

The random number generator is seeded uniquely for each run by seed = month * 1000 + run.
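The formula gives every (month, run) pair a distinct seed as long as run < 1000; for example, March run 0 gets seed 3000 and August run 9 gets seed 8009. The sketch below uses Python's standard-library `random.Random` purely for illustration; which RNG generate-points.py actually seeds is an assumption here, and `run_seed` and `rng_for` are hypothetical names.

```python
import random


def run_seed(month: int, run: int) -> int:
    """Seed for one run: seed = month * 1000 + run (unique while run < 1000)."""
    return month * 1000 + run


def rng_for(month: int, run: int) -> random.Random:
    """Hypothetical: a reproducible generator per run (the repository's RNG may differ)."""
    return random.Random(run_seed(month, run))


# Months 3..8 with runs 0..9 yield 60 distinct seeds.
seeds = {run_seed(m, r) for m in range(3, 9) for r in range(10)}
```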

Technically, any thread count can be set, but diminishing returns are expected beyond 60 threads. The runs are independent and can be computed in parallel, e.g. on different compute nodes.

The paper's results were produced with the default parameter values.

Leveraging HPC clusters

For March, April and August in particular, the simulation can take a long time. We suggest using a compute cluster if you have access to one. Please refer to your cluster's documentation or administrators for setup instructions.
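Since the runs are independent, one common pattern on a Slurm cluster is a job array with one task per (month, run) pair. The sketch below assumes an array submitted with indices 0 to 59 (e.g. `--array=0-59`) and maps the `SLURM_ARRAY_TASK_ID` environment variable, which Slurm sets for each array task, to one generate-points.py invocation. The wrapper itself is hypothetical; the command-line flags are the ones used in this README's reproduction commands.

```python
import os
import subprocess
import sys

MONTHS = range(3, 9)  # March (3) to August (8)
RUNS_PER_MONTH = 10   # runs 0..9


def task_to_job(task_id: int) -> tuple[int, int]:
    """Map an array index 0..59 to (month, run)."""
    if not 0 <= task_id < len(MONTHS) * RUNS_PER_MONTH:
        raise ValueError(f"task id out of range: {task_id}")
    return MONTHS[task_id // RUNS_PER_MONTH], task_id % RUNS_PER_MONTH


def build_command(sampler: str, month: int, run: int, threads: int = 60) -> list[str]:
    """Command line for one run, with the settings used in the reproduction commands."""
    return [sys.executable, "source/generate-points.py", "--sampler", sampler,
            "--month", str(month), "--run", str(run),
            "--tune", "100", "--draws", "900", "--threads", str(threads)]


def main() -> None:
    """Run the job for this array task; the sampler (mh or nuts) is the first argument."""
    month, run = task_to_job(int(os.environ["SLURM_ARRAY_TASK_ID"]))
    sampler = sys.argv[1] if len(sys.argv) > 1 else "nuts"
    subprocess.run(build_command(sampler, month, run), check=True)


# Only act when launched as a Slurm array task.
if __name__ == "__main__" and "SLURM_ARRAY_TASK_ID" in os.environ:
    main()
```

Keeping the index-to-(month, run) mapping in one function makes it easy to verify that the 60 tasks cover every run exactly once.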

Create plots

First, make sure that the venv is still activated.

Once the points are in the results directory, you can produce the evaluations via the subcommands of the evaluation.py script. List the available subcommands with the --help option:

python3 source/evaluation.py --help

Exact reproducibility

Executing the following commands reproduces exactly the same results (complete the setup above first and make sure the venv is activated).

# generate all MH points
python source/generate-points.py --sampler mh --month 3 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 3 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 4 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 5 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 6 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 7 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler mh --month 8 --run 9 --tune 100 --draws 900 --threads 60

# generate all HMC points
python source/generate-points.py --sampler nuts --month 3 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 3 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 4 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 5 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 6 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 7 --run 9 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 0 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 1 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 2 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 3 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 4 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 5 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 6 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 7 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 8 --tune 100 --draws 900 --threads 60
python source/generate-points.py --sampler nuts --month 8 --run 9 --tune 100 --draws 900 --threads 60

# generate all plots and evaluations
python source/evaluation.py tables
python source/evaluation.py density-plot 
python source/evaluation.py timeline-plot 
python source/evaluation.py trace-plot --month 5 --chain 0

# for supplement additionally
python source/evaluation.py trace-plot --month 3 --chain 0
python source/evaluation.py trace-plot --month 4 --chain 0
python source/evaluation.py trace-plot --month 6 --chain 0
python source/evaluation.py trace-plot --month 7 --chain 0
python source/evaluation.py trace-plot --month 8 --chain 0
python source/evaluation.py autocorrelation-plot --month 3
python source/evaluation.py autocorrelation-plot --month 4
python source/evaluation.py autocorrelation-plot --month 5
python source/evaluation.py autocorrelation-plot --month 6
python source/evaluation.py autocorrelation-plot --month 7
python source/evaluation.py autocorrelation-plot --month 8

Note: These computations are highly intensive and take a long time. Generating the points anywhere other than on a high-performance computing cluster is infeasible.
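The 120 generate-points.py commands listed above follow a regular pattern: two samplers × six months × ten runs, always with --tune 100 --draws 900 --threads 60. A small helper can regenerate the list, e.g. to hand it to a cluster scheduler or to check that no run was skipped. This is a hypothetical convenience script, not part of the repository; it only prints the commands and does not execute them.

```python
import itertools

SAMPLERS = ("mh", "nuts")
MONTHS = range(3, 9)  # March (3) to August (8)
RUNS = range(10)      # runs 0..9


def reproduction_commands():
    """Yield the generate-points.py commands in the order listed in this README."""
    for sampler, month, run in itertools.product(SAMPLERS, MONTHS, RUNS):
        yield (f"python source/generate-points.py --sampler {sampler} --month {month} "
               f"--run {run} --tune 100 --draws 900 --threads 60")


if __name__ == "__main__":
    for cmd in reproduction_commands():
        print(cmd)
```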

Technical notes for understanding/reviewing the code

In our framework we use the convention Theta = (alpha, beta) and theta = (log(alpha), log(beta)), where alpha and beta are the parameters of the SIR model.
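Working in theta-space keeps alpha and beta positive for any real-valued theta. When a gradient with respect to Theta is available, the chain rule converts it: d f / d log(alpha) = alpha · d f / d alpha, and likewise for beta. A minimal sketch; the function names are hypothetical and not taken from the tensir module.

```python
import math


def to_theta(alpha: float, beta: float) -> tuple[float, float]:
    """Theta = (alpha, beta) -> theta = (log(alpha), log(beta)); needs alpha, beta > 0."""
    return math.log(alpha), math.log(beta)


def to_Theta(theta: tuple[float, float]) -> tuple[float, float]:
    """theta -> Theta; exp() yields positive rates for any real theta."""
    return math.exp(theta[0]), math.exp(theta[1])


def grad_theta(grad_Theta: tuple[float, float], Theta: tuple[float, float]) -> tuple[float, float]:
    """Chain rule d f / d log(x) = x * d f / d x, applied to each component."""
    return tuple(g * x for g, x in zip(grad_Theta, Theta))
```

For example, for f(alpha, beta) = alpha² · beta at (0.5, 2.0), the Theta-gradient (2·alpha·beta, alpha²) = (2.0, 0.25) becomes the theta-gradient (1.0, 0.5).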

For details on a specific function or module, please see the docstrings in the code. The repository is structured as follows:

Project structure

.
├── cache/                     Directory for temporary caching (e.g. the downloaded data or the states
│   │                          of the RNG when generating points)
│   └── data-austria.csv       Cached CSV file containing the raw data
├── logs/                      Target directory for logs which are output during generation of points
├── results/                   Target directory for simulation/plotting results
│   ├── mh-points/             Target for generated points with the MH (Metropolis-Hastings) sampler
│   ├── nuts-points/           Target for generated points with the HMC (NUTS) sampler
│   └── plots/                 Target for generated plots
├── source/                    Directory containing all source code
│   ├── tensir/                Module containing all the library functionality (not intended to be used via public CLI)
│   │   ├── optimize/          Module containing optimization routines for debugging (deterministic ODE solver,
│   │   │                      grid search, gradient ascent)
│   │   ├── uniformization/    Implementation of the derivative and forward evaluation via uniformization (research
│   │   │                      object of this paper)
│   │   └── data.py            Functions for loading data
│   ├── evaluation.py          Main script to generate the plots
│   └── generate-points.py     Main script to generate, i.e. sample, points
└── requirements.txt           Text file containing all dependencies (adhering to the Python convention)