
MTENN


Modular Training and Evaluation of Neural Networks

Copyright

Copyright (c) 2022, Benjamin Kaminow

Minimal usage example

Models should be built using the mtenn.config API. A small example for a SchNet model is shown below; more details on SchNet and the other models can be found in their respective class definitions.

We will construct a SchNet model with default parameters and a delta G strategy for combining our complex, protein, and ligand representations. We will leave our predictions in the returned implicit kT units (i.e., no Readout block).
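To make the delta G strategy concrete, the sketch below shows the arithmetic behind it: the binding free energy is estimated as the prediction for the complex minus the sum of the predictions for the protein and the ligand. This is an illustration of the idea in plain PyTorch, not mtenn's own implementation. The helpers `delta_g_kT` and `kT_to_kcal_per_mol` are hypothetical, and the unit conversion assumes a temperature of 298.15 K.

```python
import torch

# Gas constant in kcal/(mol K)
R_KCAL = 1.987204e-3

def delta_g_kT(g_complex, g_protein, g_ligand):
    """Hypothetical helper: combine per-structure predictions (in kT) into
    dG = G(complex) - (G(protein) + G(ligand))."""
    return g_complex - (g_protein + g_ligand)

def kT_to_kcal_per_mol(value_kT, temperature=298.15):
    """Hypothetical helper: convert a value in units of kT to kcal/mol."""
    return value_kT * R_KCAL * temperature

# Dummy values standing in for model outputs, in kT
g_c = torch.tensor(-12.0)
g_p = torch.tensor(-5.0)
g_l = torch.tensor(-2.5)

dg = delta_g_kT(g_c, g_p, g_l)  # -12.0 - (-5.0 + -2.5) = -4.5 kT
print(dg.item(), kT_to_kcal_per_mol(dg).item())
```

Leaving out a Readout block, as in this example, means the returned values stay in kT; a conversion like the one above is only needed if you want physical units.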

from mtenn.config import SchNetModelConfig

# Create the config using all default parameters (which includes the delta G strategy)
model_config = SchNetModelConfig()

# Build the actual PyTorch model from the config
model = model_config.build()

The input passed to this model should be a dict with the following keys, which depend on the underlying model:

  • SchNet
    • z: Tensor of atomic numbers for each atom, shape of (n,)
    • pos: Tensor of coordinates for each atom, shape of (n, 3)
  • E3NN
    • x: Tensor of one-hot encodings of each atom's element, shape of (n, one_hot_length)
    • pos: Tensor of coordinates for each atom, shape of (n, 3)
    • z: Tensor of bool labels indicating whether each atom is a protein atom (False) or a ligand atom (True), shape of (n,)
  • GAT
    • x: Tensor of input atom (node) features, shape of (n, feats)
    • edge_index: Tensor giving source (first row) and destination (second row) atom indices, shape of (2, n_bonds)
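As an illustration of the shapes listed above, the sketch below builds random inputs for each model type in plain PyTorch. The helper names, the one-hot length of 100, the use of the atomic number as the one-hot index, the feature count of 8, and the tiny bond list are all hypothetical choices for demonstration; real inputs would come from an actual protein-ligand structure.

```python
import torch

def schnet_input(z, pos):
    # SchNet: atomic numbers (n,) and coordinates (n, 3)
    return {"z": z, "pos": pos}

def e3nn_input(z, pos, is_ligand, one_hot_length=100):
    # E3NN: one-hot element encoding (n, one_hot_length), coordinates (n, 3),
    # and a bool mask that is True for ligand atoms and False for protein atoms.
    # Assumption: the one-hot index is simply the atomic number.
    x = torch.nn.functional.one_hot(z, num_classes=one_hot_length).float()
    return {"x": x, "pos": pos, "z": is_ligand}

def gat_input(node_feats, bonds):
    # GAT: node features (n, feats) and edge_index (2, n_bonds), where row 0
    # holds source atom indices and row 1 holds destination atom indices
    edge_index = torch.tensor(bonds, dtype=torch.long).t().contiguous()
    return {"x": node_feats, "edge_index": edge_index}

n = 10
z = torch.randint(low=1, high=17, size=(n,))  # int64, as one_hot requires
pos = torch.rand((n, 3))
is_ligand = torch.zeros(n, dtype=torch.bool)
is_ligand[-3:] = True  # pretend the last 3 atoms belong to the ligand

# Hypothetical bond list as (source, destination) pairs, both directions included
bonds = [(0, 1), (1, 0), (1, 2), (2, 1)]

schnet_pose = schnet_input(z, pos)
e3nn_pose = e3nn_input(z, pos, is_ligand)
gat_pose = gat_input(torch.rand((n, 8)), bonds)
print(e3nn_pose["x"].shape, gat_pose["edge_index"].shape)
```

Listing each bond in both directions follows the usual PyTorch Geometric convention for undirected graphs; whether mtenn's GAT model expects this is an assumption worth checking against its class definition.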

The prediction can then be generated simply with:

import torch

# Using random data just for demonstration purposes
pose = {"z": torch.randint(low=1, high=17, size=(100,)), "pos": torch.rand((100, 3))}
pred = model(pose)
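For evaluating several poses, one minimal approach is to loop over pose dicts under `torch.no_grad()` and concatenate the predictions. The function `predict_poses` is hypothetical; it assumes, as in the example above, that the model takes one pose dict per call. Because the exact return type is not spelled out here, the sketch also handles the case where the model returns a tuple whose first element is the prediction (an assumption). A stand-in module is used so the sketch runs without a built mtenn model.

```python
import torch

def predict_poses(model, poses):
    """Hypothetical helper: run inference on a list of pose dicts."""
    model.eval()  # disable training-only behavior such as dropout
    preds = []
    with torch.no_grad():  # gradients are not needed for inference
        for pose in poses:
            out = model(pose)
            # Assumption: if the model returns a tuple, the first element
            # is the prediction
            if isinstance(out, tuple):
                out = out[0]
            preds.append(out.reshape(-1))
    return torch.cat(preds)

class DummyModel(torch.nn.Module):
    """Stand-in for an mtenn model: returns the mean atomic number."""
    def forward(self, pose):
        return pose["z"].float().mean()

poses = [
    {"z": torch.randint(low=1, high=17, size=(100,)), "pos": torch.rand((100, 3))}
    for _ in range(4)
]
preds = predict_poses(DummyModel(), poses)
print(preds.shape)  # torch.Size([4])
```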

Installation

mtenn is now on conda-forge! To install, simply run

mamba install -c conda-forge mtenn

Acknowledgements

Project based on the Computational Molecular Science Python Cookiecutter version 1.6.