1 parent
fd71b60
commit 0e6acdd
Showing 45 changed files with 214,703 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Smart Distributed Data Factory | ||
|
||
![Static Badge](https://img.shields.io/badge/bioRxiv-10.1101%2F2024.10.22.619651-red) [![Zenodo](https://zenodo.org/badge/DOI/10.5281/zenodo.14008357.svg)](https://doi.org/10.5281/zenodo.14008357) | ||
|
||
This repository hosts the source code for the experiments in the paper **"Smart Distributed Data Factory: Volunteer Computing Platform for Active Learning-Driven Molecular Data Acquisition"**. | ||
The repository provides scripts for training the energy prediction models and running inference with them, as well as for simulating the active-learning framework. | ||
|
||
The pre-print with the detailed description of the methods and implementation is available on [bioRxiv](https://www.biorxiv.org/content/10.1101/2024.10.22.619651v1).\ | ||
The conformational energy dataset and the benchmark for machine learning models are available on [Zenodo](https://zenodo.org/records/14008357). | ||
|
||
### Conformational energy prediction | ||
|
||
We provide a script for running inference with our conformational energy prediction models. For example, you can run the GENConv model on input conformations (each in a separate `.sdf` file) as follows: | ||
``` | ||
python -m force_field_models.inference.inference --config force_field_models/model_configs/ConfigurableGNNModel_GENConv_new_normals.yaml --checkpoint GENConv.pth --data_dir dataset/molecules --output_file predictions.csv | ||
``` | ||
The predictions are in Hartree units. | ||
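
Since the predictions are reported in Hartree, comparing conformers often calls for a conversion to kcal/mol or eV. The following is a minimal sketch using rounded CODATA conversion factors; the function names are hypothetical and are not part of this repository, and the example energies are made up for illustration.

```python
# Conversion factors from Hartree (CODATA values, rounded).
HARTREE_TO_KCAL_PER_MOL = 627.509474
HARTREE_TO_EV = 27.211386


def hartree_to_kcal_per_mol(energy_hartree: float) -> float:
    """Convert an energy from Hartree to kcal/mol."""
    return energy_hartree * HARTREE_TO_KCAL_PER_MOL


def hartree_to_ev(energy_hartree: float) -> float:
    """Convert an energy from Hartree to electronvolts."""
    return energy_hartree * HARTREE_TO_EV


# Illustrative (made-up) energies of two conformers of the same molecule.
conformer_a = -154.003
conformer_b = -154.001

# Relative energies are usually what matters when ranking conformers.
relative_kcal = hartree_to_kcal_per_mol(conformer_b - conformer_a)
```

Converting the difference rather than each absolute energy keeps the result easy to read, since absolute energies in kcal/mol are very large numbers.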
|
||
Model checkpoints: | ||
- GENConv: [[Download link](https://sddf-checkpoints.s3.us-east-1.amazonaws.com/energy-v2024-Q3/GENConv.pth)][config file: `ConfigurableGNNModel_GENConv_new_normals.yaml`] | ||
- PNAConv: [[Download link](https://sddf-checkpoints.s3.us-east-1.amazonaws.com/energy-v2024-Q3/PNAConv.pth)][config file: `ConfigurableGNNModel_PNAConv_new_normals.yaml`] | ||
- ResGatedConv: [[Download link](https://sddf-checkpoints.s3.us-east-1.amazonaws.com/energy-v2024-Q3/ResGatedConv.pth)][config file: `ConfigurableGNNModel_ResGatedConv.yaml`] | ||
- GeneralConv: [[Download link](https://sddf-checkpoints.s3.us-east-1.amazonaws.com/energy-v2024-Q3/GeneralConv.pth)][config file: `ConfigurableGNNModel_GeneralConv_new_normals.yaml`] | ||
- TransformerConv: [[Download link](https://sddf-checkpoints.s3.us-east-1.amazonaws.com/energy-v2024-Q3/TransformerConv.pth)][config file: `ConfigurableGNNModel_TransformerConv_new_normals_1.yaml`] | ||
|
||
#### Installation steps | ||
|
||
To run the code, you need Python 3.11 or Python 3.12 installed on your system. | ||
Then install the remaining dependencies with: | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### Active learning-based conformation sampling | ||
|
||
We also provide a script (`force_field_models/train/cycle.py`) for simulating active learning-based dataset sampling. | ||
To run it, specify the model configs, the initial datasets, and the training-related hyperparameters. | ||
A detailed explanation and an example are given in [cycle_README.md](cycle_README.md). | ||
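
To illustrate the general idea of pool-based active learning, here is a minimal sketch of a selection loop. It is not the logic of `cycle.py`: the function `select_next_batch`, the random placeholder scores, and all sizes are assumptions made for illustration only. In a real cycle the scores would come from trained models rather than a random generator.

```python
import numpy as np


def select_next_batch(scores, labeled_mask, batch_size, use_predictions, rng):
    """Pick indices of unlabeled samples to add to the training set."""
    candidates = np.flatnonzero(~labeled_mask)
    batch_size = min(batch_size, candidates.size)
    if use_predictions:
        # Take the candidates with the highest acquisition score.
        order = np.argsort(scores[candidates])[::-1]
        return candidates[order[:batch_size]]
    # Baseline: uniform random selection without replacement.
    return rng.choice(candidates, size=batch_size, replace=False)


rng = np.random.default_rng(0)
pool_size, seed_size, step_size, step_num = 100, 10, 5, 3

# Start from a random labeled seed set.
labeled = np.zeros(pool_size, dtype=bool)
labeled[rng.choice(pool_size, size=seed_size, replace=False)] = True

for step in range(step_num):
    # Placeholder for "train the models, then score the pool".
    scores = rng.random(pool_size)
    chosen = select_next_batch(scores, labeled, step_size, True, rng)
    labeled[chosen] = True
```

Passing `use_predictions=False` gives the random-selection baseline against which score-based selection can be compared.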
|
||
### How to cite this work | ||
|
||
For citation please use: | ||
- The paper (pre-print):\ | ||
*Ghukasyan, T., Altunyan, V., Bughdaryan, A., Aghajanyan, T., Smbatyan, K., Papoian, G. A., & Petrosyan, G. (2024). SMART DATA FACTORY: VOLUNTEER COMPUTING PLATFORM FOR ACTIVE LEARNING-DRIVEN MOLECULAR DATA ACQUISITION. bioRxiv, 2024-10.* | ||
- The dataset:\ | ||
*Altunyan, V., Ghukasyan, T., Bughdaryan, A., Aghajanyan, T., Smbatyan, K., Papoian, G., & Petrosyan, G. (2024). SDDF Energy Dataset (2024-Q3) [Data set]. Zenodo. https://doi.org/10.5281/zenodo.14008357* |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# README | ||
|
||
## Running the Training Cycle | ||
|
||
To run the training cycle, ensure that the following parameters are properly configured in the configuration file: | ||
|
||
### Configuration Instructions: | ||
1. **Seed Configuration**: | ||
- Set the **seed** to ensure reproducibility. | ||
|
||
2. **Buffer Settings**: | ||
- Define the following parameters: | ||
- `buffer_size`: Size of the buffer. | ||
- `seed_size`: Initial size of the buffer. | ||
- `buffer_step_size`: Increment size for the buffer at each step. | ||
|
||
3. **Training Data Path**: | ||
- Provide the **path** to the processed training data. The data should be in `.npz` file format. | ||
|
||
4. **Use Model Predictions or Random Selection**: | ||
- Indicate whether to use model prediction scores or select randomly using the `use_buffer_predictions` parameter. | ||
|
||
5. **Step Configuration**: | ||
- Set the **step_num** to define the number of steps in the cycle. | ||
|
||
6. **Training Parameters**: | ||
- Specify the following: | ||
- `max_epoch`: Maximum number of epochs to train each model. | ||
- `gpu_device`: GPU device to be used for training. | ||
|
||
7. **Model Configurations**: | ||
- Provide **config files** for each model. | ||
- Each model's configuration file should include: | ||
- The model type and name. | ||
- Basic configurations required for the model. | ||
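
The parameters listed above can be checked before a run with a small helper. The sketch below is an assumption-laden illustration: the expected types are guesses, the values are made up, and the keys for the training data path and the model config files are omitted because their names are not given here. The helper `missing_or_mistyped` is hypothetical and not part of the repository.

```python
# Parameter names come from the list above; the expected types are assumptions.
REQUIRED_KEYS = {
    "seed": int,
    "buffer_size": int,
    "seed_size": int,
    "buffer_step_size": int,
    "use_buffer_predictions": bool,
    "step_num": int,
    "max_epoch": int,
}


def missing_or_mistyped(config: dict) -> list[str]:
    """Return a list of human-readable problems found in the config."""
    problems = []
    for key, expected_type in REQUIRED_KEYS.items():
        if key not in config:
            problems.append(f"missing: {key}")
        elif not isinstance(config[key], expected_type):
            problems.append(f"wrong type for {key}: expected {expected_type.__name__}")
    return problems


# Made-up values, for illustration only.
example_config = {
    "seed": 42,
    "buffer_size": 1000,
    "seed_size": 100,
    "buffer_step_size": 50,
    "use_buffer_predictions": True,
    "step_num": 10,
    "max_epoch": 20,
    "gpu_device": 0,  # not type-checked: its expected format is not stated
}

problems = missing_or_mistyped(example_config)
```

Collecting every problem in a list, instead of raising on the first one, lets all configuration mistakes be fixed in a single pass.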
|
||
### Running the Cycle: | ||
|
||
To execute the training cycle, run the following command with the [example cycle experiment config file](force_field_models/sdf_experiments/test_configurations/cycle_readme_example.yaml): | ||
```bash | ||
python -m force_field_models.train.cycle -tc cycle_readme_example.yaml | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .inference.inference_model import EnsembleForceField |
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import sys | ||
|
||
import torch | ||
|
||
from ..data.utils import get_energy_value | ||
|
||
|
||
if __name__ == "__main__": | ||
    # Expects two arguments: the source directory and the destination directory. | ||
    src_path = sys.argv[1] | ||
    dst_path = sys.argv[2] | ||
|
||
    for key in ['test', 'valid', 'train']: | ||
        data_list = torch.load(f'{src_path}/energy_{key}.pt') | ||
        for data in data_list: | ||
            # Replace the label `y` with the energy returned by get_energy_value for this file. | ||
            energy = get_energy_value(data.file_name) | ||
            data.initial_target = energy | ||
            del data.y | ||
|
||
        torch.save(data_list, f"{dst_path}/energy_{key}_initial_targets.pt") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
key,value | ||
H_H,0 | ||
H_C,1 | ||
H_N,2 | ||
H_O,3 | ||
H_S,4 | ||
H_Cl,5 | ||
H_Br,6 | ||
H_F,7 | ||
C_H,9 | ||
C_C,10 | ||
C_N,11 | ||
C_O,12 | ||
C_S,13 | ||
C_Cl,14 | ||
C_Br,15 | ||
C_F,16 | ||
N_H,18 | ||
N_C,19 | ||
N_N,20 | ||
N_O,21 | ||
N_S,22 | ||
N_Cl,23 | ||
N_Br,24 | ||
N_F,25 | ||
O_H,27 | ||
O_C,28 | ||
O_N,29 | ||
O_O,30 | ||
O_S,31 | ||
O_Cl,32 | ||
O_Br,33 | ||
O_F,34 | ||
S_H,36 | ||
S_C,37 | ||
S_N,38 | ||
S_O,39 | ||
S_S,40 | ||
S_Cl,41 | ||
S_Br,42 | ||
S_F,43 | ||
Cl_H,45 | ||
Cl_C,46 | ||
Cl_N,47 | ||
Cl_O,48 | ||
Cl_S,49 | ||
Cl_Cl,50 | ||
Cl_Br,51 | ||
Cl_F,52 | ||
Br_H,54 | ||
Br_C,55 | ||
Br_N,56 | ||
Br_O,57 | ||
Br_S,58 | ||
Br_Cl,59 | ||
Br_Br,60 | ||
Br_F,61 | ||
F_H,63 | ||
F_C,64 | ||
F_N,65 | ||
F_O,66 | ||
F_S,67 | ||
F_Cl,68 | ||
F_Br,69 | ||
F_F,70 | ||
H_<UNK>,8 | ||
C_<UNK>,17 | ||
N_<UNK>,26 | ||
O_<UNK>,35 | ||
S_<UNK>,44 | ||
Cl_<UNK>,53 | ||
Br_<UNK>,62 | ||
F_<UNK>,71 | ||
<UNK>_H,72 | ||
<UNK>_C,73 | ||
<UNK>_N,74 | ||
<UNK>_O,75 | ||
<UNK>_S,76 | ||
<UNK>_Cl,77 | ||
<UNK>_Br,78 | ||
<UNK>_F,79 | ||
<UNK>_<UNK>,80 |
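
The values in the table above follow a regular pattern: with the vocabulary ordered as H, C, N, O, S, Cl, Br, F, `<UNK>`, each pair maps to a row-major index over the 9 x 9 grid. The sketch below reproduces that pattern. It is an observation about the listed values, not a description of how the repository builds or reads this file, and the name `pair_index` is hypothetical.

```python
# Element order inferred from the values in the table; <UNK> is the last slot.
ELEMENTS = ["H", "C", "N", "O", "S", "Cl", "Br", "F"]
UNK = "<UNK>"
VOCAB = ELEMENTS + [UNK]


def pair_index(first: str, second: str) -> int:
    """Map an ordered element pair to its index, falling back to <UNK>."""
    i = VOCAB.index(first if first in ELEMENTS else UNK)
    j = VOCAB.index(second if second in ELEMENTS else UNK)
    # Row-major index over the 9 x 9 grid of vocabulary entries.
    return i * len(VOCAB) + j


carbon_oxygen = pair_index("C", "O")   # listed as C_O in the table
silicon_hydrogen = pair_index("Si", "H")  # Si is not in the vocabulary
```

Elements outside the vocabulary, such as Si or P, fall into the `<UNK>` row or column, which matches the `<UNK>` entries at the end of the table.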
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import os | ||
import sys | ||
|
||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
from ..data.data import EnergyDatasetH5 | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
    prefix = sys.argv[1] # e.g. 'train', 'valid', or 'test' | ||
    src_h5_path = sys.argv[2] | ||
    dst_dir = sys.argv[3] | ||
|
||
    dataset = EnergyDatasetH5(src_h5_path, prefix=prefix) | ||
|
||
    # Write one .npz file per molecule, named by split prefix and position in the dataset. | ||
    for i, mol_data in tqdm(enumerate(dataset)): | ||
        npz_file_path = os.path.join(dst_dir, f'mol_{prefix}_{i}.npz') | ||
        np.savez( | ||
            npz_file_path, | ||
            x=mol_data.x[0], | ||
            edge_index=mol_data.edge_index, | ||
            edge_attr_0=mol_data.edge_attr[0], | ||
            edge_attr_1=mol_data.edge_attr[1], | ||
            edge_attr_2=mol_data.edge_attr[2], | ||
            edge_weights=mol_data.edge_weights, | ||
            initial_target=mol_data.initial_target, | ||
            pos=mol_data.pos, | ||
            smiles=mol_data.smiles, | ||
            file_name=mol_data.file_name | ||
        ) |
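
Files written this way can be read back with `np.load`. The following round-trip sketch uses a subset of the keys from the script above; the array shapes, values, and file names are made up for illustration and do not reflect the real dataset.

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "mol_train_0.npz")

    # Write a toy record with illustrative shapes and values.
    np.savez(
        path,
        x=np.array([6, 8, 1]),
        edge_index=np.array([[0, 1], [1, 0]]),
        initial_target=np.float64(-154.0),
        pos=np.zeros((3, 3)),
        smiles="CO",
        file_name="example.sdf",
    )

    # Read every array into memory before the archive is closed.
    with np.load(path) as archive:
        record = {name: archive[name] for name in archive.files}

# Strings are stored as 0-d arrays; .item() recovers the Python str.
smiles = record["smiles"].item()
energy = float(record["initial_target"])
```

Copying the arrays into a plain dictionary inside the `with` block avoids keeping the file handle open after loading.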