Madrigal is an open-source model for predicting drug combination outcomes from multimodal preclinical data. This repository provides the implementation of the model as described on our project page and in our paper.
- First, clone this GitHub repository and install it by following the section below.
- Set up data directories and create a `.env` file (see below).
- [Optional] Download datasets from our data repository on Harvard Dataverse and reorganize them according to your `.env` setup.
- [Optional] Download pretrained checkpoints from our checkpoint repository on Hugging Face and reorganize them according to your `.env` setup.
We provide sample scripts for model pretraining (second-stage modality alignment) and training in `scripts/`. Specifically, the second-stage pretraining scripts are in `./scripts/cl_pretrain/`, and the fine-tuning scripts are in `./scripts/ddi_finetune/`. The scripts will need to be adapted to your machine.
The first-stage modality adaptation training scripts (or notebooks) and checkpoints can be found in `modality_pretraining/`. You can also run inference with model checkpoints using sample Jupyter notebooks (to be uploaded).
Currently, the codebase must be modified to adapt the model to your own dataset. Below is an outline of the preparations you may need.
- Certain arguments in `./madrigal/parse_args.py` need to be modified if you are incorporating a new dataset:
  - `data_source`: affects the path used to load data, as well as the training and evaluation strategy.
  - `split_method`: affects the path used to load data and the evaluation strategy.
  - `task`: depending on the nature of your dataset, you might want to change this.
  - `loss_fn_name`: depending on `task`, you might want to change or reimplement this.
- Preparing data: Please refer to our provided data files for the exact format of each file.
  - Drugs
    - Metadata: the key to all other files.
    - Modality data
      - Structure: Use `torchdrug` to generate molecular graphs, ordered in the same way as the molecules in the metadata.
      - KG: Use `PyG` to generate `HeteroData` objects, making sure drug node indices are ordered in the same way as in the metadata.
      - Cell viability: mainly tables.
      - Transcriptomics: mainly tables.
        - Note that you will need to regenerate a file (hard-coded as `rdkit2D_embeddings_combined_all_normalized.parquet`) for chemCPA usage/pretraining.
  - Drug combination outcomes
    - Tables of (`label_indexed`, `head` (drug 1), `tail` (drug 2), `negs*`); the meaning of the negative columns depends on the dataset splitting strategy.
    - A mapping between outcome label indices and outcome information.
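Since the metadata is the key to all other files and every modality must follow its drug order, it can help to check alignment before training. The following is a minimal standard-library sketch, not part of `madrigal`: the function names are hypothetical, the DrugBank-style IDs are illustrative, and it assumes that the `head`/`tail`/negative columns of the outcome tables hold drug IDs present in the metadata (skip the lookup if your tables already store integer indices) and that negative columns share a `neg` prefix.

```python
import csv
import io


def index_metadata(drug_ids):
    """Map each drug ID to its row position in the metadata (the key to all other files)."""
    index = {}
    for position, drug_id in enumerate(drug_ids):
        if drug_id in index:
            raise ValueError(f"duplicate drug ID in metadata: {drug_id}")
        index[drug_id] = position
    return index


def check_modality_order(meta_index, modality_ids):
    """Raise if a modality lists its drugs in a different order than the metadata."""
    expected = sorted(meta_index, key=meta_index.get)
    if list(modality_ids) != expected:
        raise ValueError("modality rows are not ordered like the metadata")


def read_outcomes(csv_text, meta_index):
    """Turn (label_indexed, head, tail, neg*) rows into integer index tuples.

    Hypothetical format: head/tail/neg* hold drug IDs from the metadata, and
    negative columns are recognized by a 'neg' prefix in their header.
    """
    rows = []
    for record in csv.DictReader(io.StringIO(csv_text)):
        negatives = [meta_index[v] for k, v in record.items() if k.startswith("neg")]
        rows.append((
            int(record["label_indexed"]),
            meta_index[record["head"]],
            meta_index[record["tail"]],
            negatives,
        ))
    return rows


if __name__ == "__main__":
    meta = index_metadata(["DB00001", "DB00002", "DB00003"])
    check_modality_order(meta, ["DB00001", "DB00002", "DB00003"])  # same order: passes
    table = "label_indexed,head,tail,neg_head,neg_tail\n5,DB00001,DB00003,DB00002,DB00002\n"
    print(read_outcomes(table, meta))  # [(5, 0, 2, [1, 1])]
```

Failing loudly on a misordered modality is deliberate: a silent mismatch would pair the wrong drug features with each outcome.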
Before installing `madrigal`, please set up a new conda environment with `mamba env create -f env_new.yaml` (this process might take 1 to 2 hours; see the `mamba` installation guidelines here). By default, our environment uses CUDA 11.7 (gcc 9.2). Please edit `env_new.yaml` accordingly if you are installing with a different CUDA version. We welcome contributions of instructions for setting up the environment with other package managers such as `uv`.
Then, activate this environment with `mamba activate primekg`. To install a global reference to the `madrigal` package in your interpreter (e.g., `from madrigal.X import Y`), run the following:
```shell
cd /path/to/Madrigal
python -m pip install -e .
```
Then, test the installation by trying to import `madrigal` in the interpreter:
```shell
python -c "import madrigal; print('Imported')"
```
Now you should be able to `import madrigal` from anywhere on your system, as long as you use the same Python interpreter.
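To confirm that the interpreter resolves `madrigal` to your editable checkout rather than some other copy, you can ask the import system where the package would be loaded from. This is a minimal sketch using the standard library's `importlib.util.find_spec`; the helper name `locate` is hypothetical. After `pip install -e .`, the reported path should point inside `/path/to/Madrigal/`.

```python
import importlib.util
import sys


def locate(module_name):
    """Return the file a top-level module would be imported from, or None if it is not importable."""
    spec = importlib.util.find_spec(module_name)
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # Run with the interpreter from the activated environment.
    print("interpreter:", sys.executable)
    print("madrigal:", locate("madrigal"))
```

If `madrigal` prints `None`, the current interpreter is probably not the one from the activated environment.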
We organize our data and model output folders in the following way:
```
Madrigal_Data
|-- processed_data
| |-- polypharmacy_new
| | |-- DrugBank
| | | |-- split_by_*
| | | | |-- data tables
| |-- views_features_new
| | |-- metadata tables
| | |-- str
| | | |-- torchdrug-generated molecular graphs
| | |-- kg
| | | |-- PyG-generated KGs
| | |-- cv
| | | |-- cell viability tables
| | |-- tx
| | | |-- transcriptomics tables
|-- model_output
| |-- pretrain
| | |-- DrugBank
| | | |-- split_by_*
| |-- DrugBank
| | |-- split_by_*
```
This structure is reflected in the model code. Please make the necessary edits if you use a different organization.
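To start from an empty disk, the folder skeleton of this tree can be created in one go. The following is a minimal `pathlib` sketch: the folder names are copied from the tree, while the helper name `make_layout` is hypothetical. Its `data_source` and `split_method` parameters are named after the arguments in `parse_args.py`, but the assumption that they map directly onto these folder names is ours. The tree leaves the split names elided (`split_by_*`), so you supply your own.

```python
from pathlib import Path


def make_layout(base_dir, data_source, split_method):
    """Create the Madrigal_Data folder skeleton for one dataset and split.

    Assumption: data_source and split_method map directly onto folder names,
    mirroring the DrugBank/split_by_* entries in the tree above.
    """
    base = Path(base_dir)
    processed = base / "processed_data"
    views = processed / "views_features_new"  # metadata tables live here
    folders = [
        processed / "polypharmacy_new" / data_source / split_method,  # data tables
        views / "str",  # torchdrug-generated molecular graphs
        views / "kg",   # PyG-generated KGs
        views / "cv",   # cell viability tables
        views / "tx",   # transcriptomics tables
        base / "model_output" / "pretrain" / data_source / split_method,
        base / "model_output" / data_source / split_method,
    ]
    for folder in folders:
        folder.mkdir(parents=True, exist_ok=True)  # safe to re-run
    return folders
```

For example, `make_layout("/path/to/Madrigal_Data", "DrugBank", "split_by_<your split>")` creates every folder for one dataset and split without touching existing files.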
Then, add a `.env` file to the project directory (the root of this project) and specify the following paths, each ending with `/`:
```
PROJECT_DIR=/path/to/Madrigal/
BASE_DIR=/path/to/Madrigal_Data/
DATA_DIR=/path/to/Madrigal_Data/processed_data/
ENCODER_CKPT_DIR=/path/to/Madrigal/modality_pretraining/
CL_CKPT_DIR=/path/to/Madrigal_Data/model_output/pretrain/
```
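A missing key or a path without its trailing `/` is easy to overlook, and the trailing slash presumably matters because the code appends sub-paths to these strings. The following is a minimal standard-library sketch that checks the file. It is only an illustration with hypothetical helper names, not how Madrigal itself loads `.env`; a library such as `python-dotenv` handles quoting and other cases that this parser ignores.

```python
from pathlib import Path

# Keys listed in the .env example above.
REQUIRED_KEYS = ("PROJECT_DIR", "BASE_DIR", "DATA_DIR", "ENCODER_CKPT_DIR", "CL_CKPT_DIR")


def parse_env(text):
    """Parse simple KEY=VALUE lines; blank lines and # comments are skipped, quotes are not handled."""
    values = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            raise ValueError(f"malformed line: {raw!r}")
        values[key.strip()] = value.strip()
    return values


def check_env(values):
    """Report required keys that are missing or whose path lacks a trailing '/'."""
    problems = []
    for key in REQUIRED_KEYS:
        if key not in values:
            problems.append(f"{key} is missing")
        elif not values[key].endswith("/"):
            problems.append(f"{key} should end with '/'")
    return problems


if __name__ == "__main__":
    env_file = Path(".env")  # run from the project root
    if env_file.exists():
        print(check_env(parse_env(env_file.read_text(encoding="utf-8"))) or "all paths set")
    else:
        print("no .env file in the current directory")
```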
Currently, hard-coded paths to embedding checkpoints exist in the `get_str_encoder`, `get_kg_encoder`, `get_cv_encoder`, and `get_tx_encoder` functions in `./madrigal/model/models.py`. The corresponding pretrained modality checkpoints are provided in `./modality_pretraining/`.
The code in this package is licensed under the MIT License.
- The `torchdrug` module needs to be imported after importing `torch_geometric` modules. `torchdrug>=0.2.0.post1` is required, as earlier versions cause an issue in the LR scheduler.
- We use `pytorch=1.13.1`, which requires `cuda<12.0`.
- (We have updated `env_new.yaml` to resolve this issue.) If you encounter `TypeError: canonicalize_version() got an unexpected keyword argument 'strip_trailing_zero'` while installing, please check out this post. In summary, either `setuptools<71` or `packaging>=22` is required.
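The `torchdrug` and `setuptools`/`packaging` constraints listed above can be checked in the active environment with a short script. This sketch uses `importlib.metadata` from the standard library together with a deliberately crude version parser that handles only `X.Y.Z` and `.postN` and ignores suffixes such as `rc` or `+cu117`. It is not full PEP 440; the `packaging` library's `Version` class is the robust choice. The helper names are hypothetical, and the thresholds come from the list above.

```python
import re
from importlib import metadata


def parse_version(text):
    """Crude comparable key for 'X.Y.Z[.postN]' strings (not full PEP 440)."""
    match = re.match(r"(\d+(?:\.\d+)*)(?:\.post(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    release = [int(part) for part in match.group(1).split(".")]
    while len(release) > 1 and release[-1] == 0:
        release.pop()  # so that "0.2" and "0.2.0" compare equal
    post = int(match.group(2)) if match.group(2) is not None else -1
    return tuple(release), post


def installed(name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_known_issues():
    """List violations of the torchdrug and setuptools/packaging constraints above."""
    problems = []
    torchdrug = installed("torchdrug")
    if torchdrug is not None and parse_version(torchdrug) < parse_version("0.2.0.post1"):
        problems.append(f"torchdrug {torchdrug} is older than 0.2.0.post1")
    setuptools, packaging = installed("setuptools"), installed("packaging")
    setuptools_ok = setuptools is None or parse_version(setuptools) < parse_version("71")
    packaging_ok = packaging is not None and parse_version(packaging) >= parse_version("22")
    if not (setuptools_ok or packaging_ok):
        problems.append(f"need setuptools<71 or packaging>=22 (found {setuptools}, {packaging})")
    return problems


if __name__ == "__main__":
    print(check_known_issues() or "no known version conflicts detected")
```

The import-order requirement (`torch_geometric` before `torchdrug`) cannot be detected this way and still has to be respected in your own scripts.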