Predicts whether a patient will discontinue medication within 100 days using a compact embedding‑based neural network for categorical features, plus baselines and rich interpretability. Includes a polished Streamlit dashboard (dark theme) for metrics, global feature importance, plots, and a live prediction form with local explanations.
Target mapping: MMA_score_cat_new = "0. poor adherence" -> y=1 (discontinue); "1. good adherence" -> y=0 (continue).
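Here is a minimal sketch of this target mapping in plain Python. It assumes the label column has already been read as strings. The helper name `encode_target` is hypothetical and does not come from the repo. Unexpected values raise an error instead of being silently dropped.

```python
# Hypothetical helper: maps MMA_score_cat_new labels to the binary target.
LABEL_MAP = {
    "0. poor adherence": 1,  # y=1: discontinue
    "1. good adherence": 0,  # y=0: continue
}


def encode_target(values):
    """Return a list of 0/1 targets for an iterable of label strings."""
    targets = []
    for value in values:
        key = str(value).strip()  # tolerate stray surrounding whitespace
        if key not in LABEL_MAP:
            raise ValueError(f"Unexpected MMA_score_cat_new value: {value!r}")
        targets.append(LABEL_MAP[key])
    return targets


print(encode_target(["0. poor adherence", "1. good adherence"]))  # [1, 0]
```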
From the repo root (hackathon):
python -m venv venv
source venv/bin/activate # Windows: venv\Scripts\activate
pip install --upgrade pip
pip install -r content/med-discontinue/requirements.txt
# Train + CV (saves models and reports)
python content/med-discontinue/main.py --mode cv --config content/med-discontinue/config/config.yaml
# Generate interpretability artifacts (plots + JSON)
python content/med-discontinue/main.py --mode interpret --config content/med-discontinue/config/config.yaml
# Launch Streamlit dashboard
streamlit run content/med-discontinue/app.py
If running commands from inside content/med-discontinue/:
python main.py --mode cv --config config/config.yaml
python main.py --mode interpret --config config/config.yaml
streamlit run app.py
- Entry point: `content/med-discontinue/main.py`
  - `--mode cv`: k‑fold CV for the embedding NN; also runs the baselines; saves `outputs/models/nn_fold*.pt` and `outputs/reports/nn_cv.json`.
  - `--mode interpret`: permutation importance, Integrated Gradients (IG), PDP/ICE; saves under `outputs/reports/` and `outputs/plots/`.
  - `--mode test`: placeholder.
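You can think of the `--mode` switch as a small dispatcher. The sketch below is only an illustration. The handler functions are hypothetical stubs, and the real `main.py` may be organized differently. Only the `--mode` values and the `--config` flag come from this README.

```python
import argparse

MODES = ("cv", "interpret", "test")


def run_cv(config_path):
    # Stub: k-fold CV for the embedding NN plus the baselines would run here.
    return f"cv: {config_path}"


def run_interpret(config_path):
    # Stub: permutation importance, IG and PDP/ICE would run here.
    return f"interpret: {config_path}"


def run_test(config_path):
    # Mirrors the README: test mode is only a placeholder.
    return f"test (placeholder): {config_path}"


def main(argv=None):
    """Parse --mode/--config and dispatch to the matching (stub) handler."""
    parser = argparse.ArgumentParser(description="Medication discontinuation pipeline")
    parser.add_argument("--mode", choices=MODES, required=True)
    parser.add_argument("--config", required=True, help="path to config.yaml")
    args = parser.parse_args(argv)
    handlers = {"cv": run_cv, "interpret": run_interpret, "test": run_test}
    return handlers[args.mode](args.config)
```

For example, `main(["--mode", "cv", "--config", "config/config.yaml"])` returns `"cv: config/config.yaml"`. An unknown mode makes argparse exit with a usage error.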
- Config: `content/med-discontinue/config/config.yaml` holds `data.stata_path` (path to the Stata `.dta` file), training hyperparameters, model architecture, evaluation settings, and output paths.
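One way to make `data.stata_path` independent of the working directory is to resolve relative paths against the repo root. This sketch rests on a few assumptions:
- the config has already been loaded into a dict, for example with PyYAML's `yaml.safe_load`;
- only the `data.stata_path` key comes from this README;
- `resolve_stata_path` is a hypothetical helper.

```python
from pathlib import Path


def resolve_stata_path(cfg, repo_root):
    """Return an absolute path for cfg["data"]["stata_path"], or raise FileNotFoundError."""
    path = Path(cfg["data"]["stata_path"]).expanduser()
    if not path.is_absolute():
        # Hypothetical convention: relative paths are anchored at the repo root,
        # not at whatever directory the command happens to be run from.
        path = Path(repo_root) / path
    if not path.is_file():
        raise FileNotFoundError(f"stata_path not found: {path}")
    return path.resolve()
```

Because `pathlib` treats the whole string as one path, file names containing spaces need no quoting here.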
- Data & preprocessing
  - `src/data_loading.py` reads the `.dta` file.
  - `src/preprocessing.py` fills missing values with "Unknown" and label‑encodes each categorical feature.
  - `src/dataset.py` builds a torch `Dataset` of categorical indices.
  - The preprocessing schema is saved to `outputs/reports/preproc_artifacts.json` during CV.
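The fill-then-encode step can be sketched in plain Python as below. The JSON layout (a per-column mapping from category string to index) and the function names are assumptions made for illustration. The actual structure of `preproc_artifacts.json` may differ. Every vocabulary reserves an "Unknown" index, so categories not seen during fitting can fall back to it at prediction time.

```python
import json

UNKNOWN = "Unknown"


def _clean(value):
    """Map missing values (None, NaN, blank strings) to "Unknown"."""
    if value is None or (isinstance(value, float) and value != value):
        return UNKNOWN
    text = str(value).strip()
    return text if text else UNKNOWN


def fit_label_encoders(rows, columns):
    """Build {column: {category: index}} with "Unknown" always present."""
    schema = {}
    for col in columns:
        categories = {_clean(row.get(col)) for row in rows}
        categories.add(UNKNOWN)
        schema[col] = {cat: i for i, cat in enumerate(sorted(categories))}
    return schema


def transform(rows, schema):
    """Encode each row as a list of indices, one per column in schema order."""
    encoded = []
    for row in rows:
        indices = []
        for col, vocab in schema.items():
            cat = _clean(row.get(col))
            indices.append(vocab.get(cat, vocab[UNKNOWN]))  # unseen -> Unknown
        encoded.append(indices)
    return encoded


schema = fit_label_encoders([{"a": "x"}, {"a": None}, {"a": "y"}], ["a"])
print(json.dumps(schema))  # {"a": {"Unknown": 0, "x": 1, "y": 2}}
```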
- Model & training
  - `src/model.py`: `CategoricalEmbeddingNN` (per‑feature embeddings → MLP → logit).
  - `src/train.py`: training loop with early stopping, LR scheduling, BCE/focal loss, and PR‑based threshold selection.
  - `src/evaluate.py`: PR/ROC metrics and threshold selection.
  - `src/run_cv.py`: orchestrates `StratifiedKFold`, saves fold checkpoints and the aggregated report.
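One way to illustrate PR-based threshold selection is to sweep the distinct predicted probabilities and keep the cutoff with the best F-beta score. This sketch assumes the criterion is F2 (beta = 2, weighting recall over precision). The `f2_drop` field elsewhere in this README suggests F2 but does not confirm it. `select_threshold` is a hypothetical name.

```python
def fbeta(precision, recall, beta=2.0):
    """F-beta from precision and recall; 0.0 when both are zero."""
    b2 = beta * beta
    denom = b2 * precision + recall
    return 0.0 if denom == 0 else (1 + b2) * precision * recall / denom


def select_threshold(y_true, y_prob, beta=2.0):
    """Return (threshold, fbeta) maximizing F-beta, predicting positive when prob >= threshold."""
    total_pos = sum(y_true)
    if total_pos == 0:
        raise ValueError("need at least one positive example")
    pairs = sorted(zip(y_prob, y_true), key=lambda p: -p[0])  # highest prob first
    best_t, best_f = None, -1.0
    tp = fp = 0
    i, n = 0, len(pairs)
    while i < n:
        t = pairs[i][0]
        # Consume every sample tied at this probability before scoring the cutoff.
        while i < n and pairs[i][0] == t:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        f = fbeta(tp / (tp + fp), tp / total_pos, beta)
        if f > best_f:
            best_t, best_f = t, f
    return best_t, best_f
```

With `y_true = [1, 0, 1, 0]` and `y_prob = [0.9, 0.8, 0.4, 0.1]`, the best cutoff is 0.4. At that cutoff precision is 2/3, recall is 1, and F2 = 10/11.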
- Baselines
  - `src/baselines.py`: Logistic Regression and RandomForest; saves `outputs/reports/baselines_cv.json`.
- Interpretability
  - `src/interpretability/permutation.py`: permutation importance. The output schema is either:
    - `{"importances": [{"feature": str, "importance": float}, ...]}`, or
    - `{"columns": [...], "f2_drop": [...]}` (F2 drop per feature).
  - `src/interpretability/captum_ig.py`: global/local Integrated Gradients (`global_ig.json`, plots).
  - `src/interpretability/pdp_ice.py`: categorical PDP/ICE plots and a summary JSON.
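Permutation importance measures how much a metric degrades when one feature's column is shuffled, which breaks that feature's link to the target. The sketch below computes the mean F2 drop per feature and returns the `columns`/`f2_drop` layout described above. `predict_proba` stands in for any model that maps a list of encoded rows to probabilities. The defaults (threshold 0.5, five repeats) are assumptions, not the repo's settings.

```python
import random


def f2_score(y_true, y_pred):
    """F2 = 5*TP / (5*TP + 4*FN + FP); returns 0.0 when undefined."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    denom = 5 * tp + 4 * fn + fp
    return 5 * tp / denom if denom else 0.0


def permutation_importance(predict_proba, X, y, columns, threshold=0.5, n_repeats=5, seed=0):
    """Mean drop in F2 when each column of X (a list of row lists) is shuffled."""
    rng = random.Random(seed)

    def score(rows):
        preds = [1 if p >= threshold else 0 for p in predict_proba(rows)]
        return f2_score(y, preds)

    baseline = score(X)
    drops = []
    for j in range(len(columns)):
        total = 0.0
        for _ in range(n_repeats):
            shuffled = [row[j] for row in X]
            rng.shuffle(shuffled)
            # Replace only column j; all other features keep their values.
            permuted = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, shuffled)]
            total += baseline - score(permuted)
        drops.append(total / n_repeats)
    return {"columns": list(columns), "f2_drop": drops}
```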
- Artifacts
  - `outputs/models/nn_fold*.pt`: best model per fold.
  - `outputs/reports/*.json`: CV, baselines, permutation importance, IG, PDP/ICE, preprocessing schema.
  - `outputs/plots/*.png`: permutation, IG, PDP/ICE, local example.
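A plain `sorted(glob("nn_fold*.pt"))` puts `nn_fold10.pt` before `nn_fold2.pt`. If the fold checkpoints need to be processed in fold order, a sketch like the following sorts them numerically instead. It assumes each file is named `nn_fold<k>.pt` with an integer `k`. Loading the files themselves (for example with `torch.load`) is left out.

```python
import re
from pathlib import Path

FOLD_RE = re.compile(r"^nn_fold(\d+)\.pt$")


def fold_checkpoints(models_dir):
    """Return nn_fold<k>.pt paths in models_dir sorted by the integer k."""
    found = []
    for path in Path(models_dir).glob("nn_fold*.pt"):
        match = FOLD_RE.match(path.name)
        if match:  # skip names whose suffix is not a plain integer fold index
            found.append((int(match.group(1)), path))
    return [path for _, path in sorted(found)]
```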
Streamlit dashboard
Path: `content/med-discontinue/app.py`
Features
- Dark theme UI with high contrast and subtle depth.
- Dashboard tab: CV metrics (NN + baselines), global feature importance (Altair bars), and plots gallery.
- Predict tab: dropdowns for all categorical features, populated from `preproc_artifacts.json`; loads a checkpoint; shows the live probability, the decision (at the mean CV threshold), and permutation‑based local deltas.
- Diagnostics expander: shows the resolved artifact paths and whether each file is present; the Reload button clears the cache and reruns the app.
- Handles both permutation report schemas automatically (list or columns+f2_drop).
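Supporting both permutation report layouts usually comes down to normalizing them into one list before plotting. Below is a sketch of such a normalizer. The function name is hypothetical and does not come from `app.py`. It returns records sorted by importance, highest first.

```python
def normalize_permutation_report(report):
    """Convert either report schema into [{"feature": str, "importance": float}, ...]."""
    if "importances" in report:
        items = [(str(d["feature"]), float(d["importance"])) for d in report["importances"]]
    elif "columns" in report and "f2_drop" in report:
        if len(report["columns"]) != len(report["f2_drop"]):
            raise ValueError("columns and f2_drop have different lengths")
        items = [(str(c), float(v)) for c, v in zip(report["columns"], report["f2_drop"])]
    else:
        raise ValueError("unrecognized permutation report schema")
    items.sort(key=lambda kv: kv[1], reverse=True)  # most important first
    return [{"feature": f, "importance": v} for f, v in items]
```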
Tips
- Run both the `cv` and `interpret` modes before launching the app so that all sections are populated.
- If you change artifacts, click “Reload data” in the sidebar to refresh.
- If you run into path issues, run from the repo root or adjust `stata_path` and the command paths accordingly.
Project structure
med-discontinue/
├─ data/ # Place your .dta file here
├─ config/
│ └─ config.yaml # Paths & hyperparameters
├─ src/
│ ├─ data_loading.py # Load Stata
│ ├─ preprocessing.py # Fill 'Unknown', label-encode categoricals
│ ├─ dataset.py # PyTorch Dataset for categorical indices
│ ├─ model.py # Embedding NN (categorical)
│ ├─ losses.py # BCEWithLogits + optional focal
│ ├─ train.py # Train loop + early stopping
│ ├─ evaluate.py # Metrics, PR/ROC, threshold selection
│ ├─ calibrate.py # Temperature scaling
│ ├─ interpretability/
│ │ ├─ permutation.py # Permutation importance
│ │ ├─ captum_ig.py # Integrated Gradients
│ │ └─ pdp_ice.py # PDP/ICE for categoricals
│ ├─ baselines.py # Logistic Regression & RandomForest
│ ├─ utils.py # Seed, logging, plotting
│ └─ run_cv.py # k-fold orchestration
├─ outputs/
│ ├─ models/
│ ├─ plots/
│ └─ reports/
├─ requirements.txt
├─ app.py
└─ main.py
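Temperature scaling, listed under `calibrate.py`, divides the model's logits by a single scalar T > 0. T is fitted on held-out data to minimize log loss. T > 1 softens overconfident probabilities, and the ranking of samples (and therefore ROC AUC) stays unchanged. The sketch below fits T by a simple grid search in plain Python. The grid and function names are assumptions, and the repo's `calibrate.py` may use a gradient-based optimizer instead.

```python
import math


def binary_nll(logits, labels, temperature):
    """Mean binary cross-entropy of sigmoid(logit / T), computed stably."""
    total = 0.0
    for z, y in zip(logits, labels):
        s = z / temperature
        # log(1 + exp(s)) - y*s, rewritten to avoid overflow for large |s|
        total += max(s, 0.0) - y * s + math.log1p(math.exp(-abs(s)))
    return total / len(logits)


def fit_temperature(logits, labels, grid=None):
    """Pick the T on a grid (default 0.05..10.0) that minimizes validation NLL."""
    if grid is None:
        grid = [0.05 * k for k in range(1, 201)]
    return min(grid, key=lambda t: binary_nll(logits, labels, t))


def calibrated_probability(logit, temperature):
    """Sigmoid of the temperature-scaled logit."""
    return 1.0 / (1.0 + math.exp(-logit / temperature))
```

Suppose logits of ±4 are right only 3 times out of 4. Then the fitted T lands near 4 / ln 3 ≈ 3.64, which maps a logit of 4 to a probability of about 0.75.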
Troubleshooting
- “File not found” for `stata_path`: make sure the path in the config matches your file name exactly (including spaces), or run from the repo root.
- Dashboard says “Permutation importance report not found”: run `--mode interpret`, then click “Reload data” in the sidebar.
- Streamlit caching: the “Reload data” button clears the cache and reruns the app.