Drug repositioning finds new uses for existing drugs, offering a cost-effective alternative to traditional drug development. Graph Neural Networks (GNNs) have shown promise in predicting drug-disease associations but often lack explainability, making it difficult to validate predictions.
DBR-X addresses this challenge by integrating:
✅ A link prediction module to identify potential drug-disease relationships.
✅ A path-identification module to provide interpretable explanations
Run your own subgraph collection procedure using the following steps:
Collect chains around train drug queries, joining the drug query entity to the disease answer.
python src/00_collect_subgraphs/00_find_paths.py --data_name <dataset_name> \
--data_dir_name <path_to_train_graph_files> \
--output_dir <path_to_save_output_file> \
--cutoff 3 <Number of hops to consider when collecting chains> \
--paths_to_collect 1000 <Number of total paths to collect for each query>
For a given query in train/dev/test set, retrieve its KNN queries from the training set. Then, gather the collected paths from step 1 and traverse the knowledge graph.
python src/00_collect_subgraphs/01_graph_collection.py --data_name <dataset_name> \
--data_dir_name <path_to_train_test_dev_files> \
--knn 5 <Number of KNN to consider for each query> \
--collected_chains_name <path_to_collected_chains_file_in_step_1> \
--branch_size 1000 <Max considered nodes when traversing the graph> \
--output_dir <path_to_save_output_file>
The 00_link_prediction_module.py
file is the main file for training the link-prediction model. The module provides functionality to:
- Load and process knowledge graph datasets
- Train a HeteroRGCN model for link prediction
- Evaluate model performance on validation and test sets
- Log training metrics and save predictions
data_name
(str, default: "MIND"): Name of the knowledge graph datasetdata_dir
(str, default: "data/MIND/"): Directory containing train, test, and dev datapaths_file_dir
(str, default: "MIND_cbr_subgraph_knn-15_branch-1000.pkl"): Name of the paths filenum_neighbors_train
(int, default: 5): Number of neighbor entities for trainingnum_neighbors_eval
(int, default: 5): Number of neighbor entities for evaluation
emb_dim
(int, default: 128): Dimension of node embeddingshidden_dim
(int, default: 128): Dimension of initial GCN layerout_dim
(int, default: 128): Dimension of output GCN layerdevice
(torch.device, default: None): Device for training (auto-detected if None)
use_wandb
(int, default: 0): Enable WandB logging (1 to enable)dist_metric
(str, default: "l2"): Distance metric for similarity ("l2" or "cosine")dist_aggr1
(str, default: "mean"): Aggregation function for neighbor queries ("none", "mean", "sum")dist_aggr2
(str, default: "mean"): Aggregation function across all neighbor queries ("mean", "sum")sampling_loss
(float, default: 1.0): Fraction of negative samples to usetemperature
(float, default: 1.0): Temperature for cross-entropy loss scalinglearning_rate
(float, default: 0.001): Initial learning rate for trainingwarmup_step
(int, default: 0): Number of warmup steps for schedulerweight_decay
(float, default: 0.0): Weight decay for AdamW optimizernum_train_epochs
(int, default: 20): Total number of training epochsgradient_accumulation_steps
(int, default: 1): Steps to accumulate gradients before updatecheck_steps
(float, default: 5.0): sFrequency of training progress checksres_name
(str, default: "mind_test_predictions"): Name of output predictions fileoutput_dir
(str, default: "link_prediction_results/MIND/"): Directory to save resultsmodel_path
(str, default: "link_prediction_results/MIND/model"): Directory to save best model
The 01_important_paths_module.py
file is the main file for extracting important paths that explain the predictions made by the trained link-prediction module.
data_dir
(str, default: "data/MIND/"): Directory containing the dataset.data_name
(str, default: "MIND"): Name of the dataset.paths_file_dir
(str, default: "MIND_cbr_subgraph_knn-15_branch-1000.pkl"): Name of the paths file from subgraph collection.model_path
(str, default: "link_prediction_results/MIND/model/"): Directory containing the model checkpoint.model_name
(str, default: "best_model_MIND_20250330_075232.pt"): Model checkpoint filename.output_dir
(str, default: "important_paths_results/MIND/"): Directory to save important paths.device
(str, default: "cuda"): Device to run on ("cuda" or "cpu").seed
(int, default: 42): Random seed for reproducibility.split
(str, default: "train"): Data split to evaluate ("train", "dev", or "test").
lr_
(float, default: 0.1): Learning rate for the explainer.alpha
(float, default: 1.0): Weight for the alpha term in the loss function.beta
(float, default: 1.0): Weight for the beta term in the loss function.num_epochs
(int, default: 10): Number of epochs to train the explainer.num_paths
(int, default: 5): Number of important paths to extract per query.max_path_length
(int, default: 4): Maximum length of extracted paths.degree_thr
(int, default: 10): Degree threshold for path filtering.penalty
(float, default: 1.0): Penalty term for path loss.