From 16a32fff17bac9fa1847fd568ada6d0ef7ffa842 Mon Sep 17 00:00:00 2001
From: kaminow
Date: Wed, 17 Jul 2024 13:46:33 -0400
Subject: [PATCH 01/14] Add PyG-based GAT implementation.

---
 mtenn/conversion_utils/gat_pyg.py | 133 ++++++++++++++++++++++++++++++
 1 file changed, 133 insertions(+)
 create mode 100644 mtenn/conversion_utils/gat_pyg.py

diff --git a/mtenn/conversion_utils/gat_pyg.py b/mtenn/conversion_utils/gat_pyg.py
new file mode 100644
index 0000000..a27f929
--- /dev/null
+++ b/mtenn/conversion_utils/gat_pyg.py
@@ -0,0 +1,133 @@
+"""
+``Representation`` and ``Strategy`` implementations for the graph attention model
+architecture. The underlying model that we use is the implementation in
+`PyTorch Geometric <https://pytorch-geometric.readthedocs.io>`_.
+"""
+from copy import deepcopy
+import torch
+from torch_geometric.nn.models import GAT as PygGAT
+
+from mtenn.model import LigandOnlyModel
+
+
+class GAT(torch.nn.Module):
+    """
+    ``mtenn`` wrapper around the PyTorch Geometric GAT model. This class handles
+    construction of the model and the formatting into ``Representation`` and
+    ``Strategy`` blocks.
+    """
+
+    def __init__(self, *args, model=None, **kwargs):
+        """
+        Initialize the underlying ``torch_geometric.nn.models.GAT`` model. If a value is
+        passed for ``model``, that model is copied over directly (weights included)
+        rather than being rebuilt from its hyperparameters. Otherwise, all ``*args``
+        and ``**kwargs`` are passed directly to the ``torch_geometric.nn.models.GAT``
+        constructor.
+
+        Parameters
+        ----------
+        model : ``torch_geometric.nn.models.GAT``, optional
+            PyTorch Geometric model to use to construct the underlying model
+        """
+        super().__init__()
+
+        # If no model is passed, construct model based on passed args, otherwise copy
+        # all parameters and weights over
+        if model is None:
+            self.gnn = PygGAT(*args, **kwargs)
+        else:
+            self.gnn = deepcopy(model)
+
+        # Predict from mean of node features
+        self.predict = torch.nn.Linear(self.gnn.out_channels, 1)
+
+    def forward(self, data):
+        """
+        Make a prediction of the target property based on an input molecule graph.
+
+        Parameters
+        ----------
+        data : dict[str, torch.Tensor]
+            This dictionary should at minimum contain entries for:
+
+            * ``"x"``: Node (atom) features, shape of (num_atoms, num_features)
+
+            * ``"edge_index"``: All edges in the graph, shape of (2, num_edges) with the
+              first row giving the source node indices and the second row giving the
+              destination node indices for each edge
+
+        Returns
+        -------
+        torch.Tensor
+            Model prediction
+        """
+        # Run through GNN
+        graph_pred = self.gnn(x=data["x"], edge_index=data["edge_index"])
+        # Take mean of feature values across nodes
+        graph_pred = graph_pred.mean(dim=0)
+        # Make final prediction
+        return self.predict(graph_pred)
+
+    def _get_representation(self):
+        """
+        Return a copy of the GNN portion of the model.
+
+        Returns
+        -------
+        torch_geometric.nn.models.GAT
+            Copy of the underlying GNN, which serves as the representation block
+        """
+
+        # Copy model so initial model isn't affected
+        model_copy = deepcopy(self.gnn)
+
+        return model_copy
+
+    def _get_energy_func(self):
+        """
+        Return last layer of the model.
+
+        Returns
+        -------
+        torch.nn.Linear
+            Final energy prediction layer of the model
+        """
+
+        return deepcopy(self.predict)
+
+    @staticmethod
+    def get_model(
+        *args,
+        model=None,
+        fix_device=False,
+        pred_readout=None,
+        **kwargs,
+    ):
+        """
+        Exposed function to build a :py:class:`LigandOnlyModel
+        <mtenn.model.LigandOnlyModel>` from a :py:class:`GAT
+        <mtenn.conversion_utils.gat_pyg.GAT>` (or args/kwargs). If no ``model`` is given,
+        use the ``*args`` and ``**kwargs``.
+
+        Parameters
+        ----------
+        model: mtenn.conversion_utils.gat_pyg.GAT, optional
+            ``GAT`` model to use to build the ``LigandOnlyModel`` object. If not
+            provided, a model will be built using the passed ``*args`` and ``**kwargs``
+        fix_device: bool, default=False
+            If True, make sure the input is on the same device as the model,
+            copying over as necessary
+        pred_readout : mtenn.readout.Readout, optional
+            ``Readout`` object for the energy predictions
+
+        Returns
+        -------
+        mtenn.model.LigandOnlyModel
+            ``LigandOnlyModel`` object containing the model and desired ``Readout``
+        """
+        if model is None:
+            model = GAT(*args, **kwargs)
+
+        return LigandOnlyModel(model=model, readout=pred_readout, fix_device=fix_device)

From 8b6683d46e683e48c25b9df675fed5c9c2016618 Mon Sep 17 00:00:00 2001
From: kaminow
Date: Wed, 17 Jul 2024 14:23:39 -0400
Subject: [PATCH 02/14] Add PyG GAT to config.

---
 mtenn/config.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 70 insertions(+)

diff --git a/mtenn/config.py b/mtenn/config.py
index c88eff2..0acef36 100644
--- a/mtenn/config.py
+++ b/mtenn/config.py
@@ -78,6 +78,7 @@ class ModelType(StringEnum):
     """

     GAT = "GAT"
+    gat_pyg = "gat_pyg"
     schnet = "schnet"
     e3nn = "e3nn"
     visnet = "visnet"
@@ -679,6 +680,75 @@ def _update(self, config_updates={}) -> GATModelConfig:
         return GATModelConfig(**new_config)


+class PyGGATModelConfig(ModelConfigBase):
+    """
+    Class for constructing a GAT ML model. Default values here are based on the values
+    in DGL-LifeSci.
+    """
+
+    model_type: ModelType = Field(ModelType.schnet, const=True)
+
+    in_channels: int = Field(
+        -1,
+        description=(
+            "Input size. Can be left as -1 (default) to interpret based on "
+            "first forward call."
+        ),
+    )
+    hidden_channels: int = Field(32, description="Hidden embedding size.")
+    num_layers: int = Field(2, description="Number of GAT layers.")
+    dropout: float = Field(0, description="Dropout probability.")
+    heads: int = Field(4, description="Number of attention heads for each GAT layer.")
+    negative_slope: float = Field(
+        0.2, description="LeakyReLU angle of the negative slope."
+    )
+
+    def _build(self, mtenn_params={}):
+        """
+        Build an ``mtenn`` GAT ``Model`` from this config.
+
+        :meta public:
+
+        Parameters
+        ----------
+        mtenn_params : dict, optional
+            Dictionary that stores the ``Readout`` objects for the individual
+            predictions and for the combined prediction, and the ``Combination`` object
+            in the case of a multi-pose model. These are all constructed the same for all
+            ``Model`` types, so we can just handle them in the base class. Keys in the
+            dict will be:
+
+            * "combination": :py:mod:`Combination <mtenn.combination>`
+
+            * "pred_readout": :py:mod:`Readout <mtenn.readout>` for individual
+              pose predictions
+
+            * "comb_readout": :py:mod:`Readout <mtenn.readout>` for combined
+              prediction (in the case of a multi-pose model)
+
+            although the combination-related entries will be ignored because this is a
+            ligand-only model.
+
+        Returns
+        -------
+        mtenn.model.Model
+            Model constructed from the config
+        """
+        from mtenn.conversion_utils.gat_pyg import GAT
+
+        model = GAT(
+            in_channels=self.in_channels,
+            hidden_channels=self.hidden_channels,
+            num_layers=self.num_layers,
+            dropout=self.dropout,
+            heads=self.heads,
+            negative_slope=self.negative_slope,
+        )
+
+        pred_readout = mtenn_params.get("pred_readout", None)
+        return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)
+
+
 class SchNetModelConfig(ModelConfigBase):
     """
     Class for constructing a SchNet ML model. 
Default values here are the default values From ace257f67cbddf76fe52ad5d14773bc7370d7531 Mon Sep 17 00:00:00 2001 From: kaminow Date: Wed, 14 Aug 2024 16:24:59 -0400 Subject: [PATCH 03/14] Update model_type. --- mtenn/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtenn/config.py b/mtenn/config.py index 0acef36..08e42c2 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -686,7 +686,7 @@ class PyGGATModelConfig(ModelConfigBase): in DGL-LifeSci. """ - model_type: ModelType = Field(ModelType.schnet, const=True) + model_type: ModelType = Field(ModelType.gat_pyg, const=True) in_channels: int = Field( -1, From b81ec61305a659d09f52705397a14f38e6627eb6 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 15 Aug 2024 10:28:58 -0400 Subject: [PATCH 04/14] Migrate PyG-GAT -> GAT. --- mtenn/config.py | 282 +----------------------------- mtenn/conversion_utils/gat.py | 127 ++++---------- mtenn/conversion_utils/gat_pyg.py | 133 -------------- 3 files changed, 38 insertions(+), 504 deletions(-) delete mode 100644 mtenn/conversion_utils/gat_pyg.py diff --git a/mtenn/config.py b/mtenn/config.py index 08e42c2..504d3c6 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -78,7 +78,6 @@ class ModelType(StringEnum): """ GAT = "GAT" - gat_pyg = "gat_pyg" schnet = "schnet" e3nn = "e3nn" visnet = "visnet" @@ -404,289 +403,12 @@ def _check_grouped(values): class GATModelConfig(ModelConfigBase): - """ - Class for constructing a graph attention ML model. Note that there are two methods - for defining the size of the model: - - * If single values are passed for all parameters, the value of ``num_layers`` will - be used as the size of the model, and each layer will have the parameters given - - * If a list of values is passed for any parameters, all parameters must be lists of - the same size, or single values. For parameters that are single values, that same - value will be used for each layer. For parameters that are lists, those lists will - be used - - Parameters passed as strings are assumed to be comma-separated lists, and will first - be cast to lists of the appropriate type, and then processed as described above. - - If lists of multiple different (non-1) sizes are found, an error will be raised. - - Default values here are the default values given in DGL-LifeSci. - """ - - # Import as private, mainly so Sphinx doesn't autodoc it - from dgllife.utils import CanonicalAtomFeaturizer as _CanonicalAtomFeaturizer - - # Dict of model params that can be passed as a list, and the type that each will be - # cast to - LIST_PARAMS: ClassVar[dict] = { - "hidden_feats": int, - "num_heads": int, - "feat_drops": float, - "attn_drops": float, - "alphas": float, - "residuals": bool, - "agg_modes": str, - "activations": None, - "biases": bool, - } #: :meta private: - - model_type: ModelType = Field(ModelType.GAT, const=True) - - in_feats: int = Field( - _CanonicalAtomFeaturizer().feat_size(), - description=( - "Input node feature size. Defaults to size of the " - "``CanonicalAtomFeaturizer``." - ), - ) - num_layers: int = Field( - 2, - description=( - "Number of GAT layers. Ignored if a list of values is passed for any " - "other argument." - ), - ) - hidden_feats: str | int | list[int] = Field( - 32, - description=( - "Output size of each GAT layer. If an ``int`` is passed, the value for " - "``num_layers`` will be used to determine the size of the model. If a list " - "of ``int`` s is passed, the size of the model will be inferred from the " - "length of the list." 
- ), - ) - num_heads: str | int | list[int] = Field( - 4, - description=( - "Number of attention heads for each GAT layer. Passing an ``int`` or list " - "of ``int`` s functions similarly as for ``hidden_feats``." - ), - ) - feat_drops: str | float | list[float] = Field( - 0, - description=( - "Dropout of input features for each GAT layer. Passing a ``float`` or " - "list of ``float`` s functions similarly as for ``hidden_feats``." - ), - ) - attn_drops: str | float | list[float] = Field( - 0, - description=( - "Dropout of attention values for each GAT layer. Passing a ``float`` or " - "list of ``float`` s functions similarly as for ``hidden_feats``." - ), - ) - alphas: str | float | list[float] = Field( - 0.2, - description=( - "Hyperparameter for ``LeakyReLU`` gate for each GAT layer. Passing a " - "``float`` or list of ``float`` s functions similarly as for " - "``hidden_feats``." - ), - ) - residuals: str | bool | list[bool] = Field( - True, - description=( - "Whether to use residual connection for each GAT layer. Passing a ``bool`` " - "or list of ``bool`` s functions similarly as for ``hidden_feats``." - ), - ) - agg_modes: str | list[str] = Field( - "flatten", - description=( - "Which aggregation mode [flatten, mean] to use for each GAT layer. " - "Passing a ``str`` or list of ``str`` s functions similarly as for " - "``hidden_feats``." - ), - ) - activations: Callable | list[Callable] | list[None] | None = Field( - None, - description=( - "Activation function for each GAT layer. Passing a function or " - "list of functions functions similarly as for ``hidden_feats``." - ), - ) - biases: str | bool | list[bool] = Field( - True, - description=( - "Whether to use bias for each GAT layer. Passing a ``bool`` or " - "list of ``bool`` s functions similarly as for ``hidden_feats``." - ), - ) - allow_zero_in_degree: bool = Field( - False, description="Allow zero in degree nodes for all graph layers." - ) - - # Internal tracker for if the parameters were originally built from lists or using - # num_layers - _from_num_layers = False - - @root_validator(pre=False) - def massage_into_lists(cls, values) -> GATModelConfig: - """ - Validator to handle unifying all the values into the proper list forms based on - the rules described in the class docstring. - """ - # First convert string lists to actual lists - for param, param_type in cls.LIST_PARAMS.items(): - param_val = values[param] - if isinstance(param_val, str): - try: - param_val = list(map(param_type, param_val.split(","))) - except ValueError: - raise ValueError( - f"Unable to parse value {param_val} for parameter {param}. " - f"Expected type of {param_type}." - ) - values[param] = param_val - - # Get sizes of all lists - list_lens = {} - for p in cls.LIST_PARAMS: - param_val = values[p] - if not isinstance(param_val, list): - # Shouldn't be possible at this point but just in case - param_val = [param_val] - values[p] = param_val - list_lens[p] = len(param_val) - - # Check that there's only one length present - list_lens_set = set(list_lens.values()) - # This could be 0 if lists of length 1 were passed, which is valid - if len(list_lens_set - {1}) > 1: - raise ValueError( - "All passed parameter lists must be the same value. 
" - f"Instead got list lengths of: {list_lens}" - ) - elif list_lens_set == {1}: - # If all lists have only one value, we defer to the value passed to - # num_layers, as described in the class docstring - num_layers = values["num_layers"] - values["_from_num_layers"] = True - else: - num_layers = max(list_lens_set) - values["_from_num_layers"] = False - - values["num_layers"] = num_layers - # If we just want a model with one layer, can return early since we've already - # converted everything into lists - if num_layers == 1: - return values - - # Adjust any length 1 list to be the right length - for p, list_len in list_lens.items(): - if list_len == 1: - values[p] = values[p] * num_layers - - return values - - def _build(self, mtenn_params={}): - """ - Build an ``mtenn`` GAT ``Model`` from this config. - - :meta public: - - Parameters - ---------- - mtenn_params : dict, optional - Dictionary that stores the ``Readout`` objects for the individual - predictions and for the combined prediction, and the ``Combination`` object - in the case of a multi-pose model. These are all constructed the same for all - ``Model`` types, so we can just handle them in the base class. Keys in the - dict will be: - - * "combination": :py:mod:`Combination ` - - * "pred_readout": :py:mod:`Readout ` for individual - pose predictions - - * "comb_readout": :py:mod:`Readout ` for combined - prediction (in the case of a multi-pose model) - - although the combination-related entries will be ignore because this is a - ligand-only model. - - Returns - ------- - mtenn.model.Model - Model constructed from the config - """ - from mtenn.conversion_utils.gat import GAT - - model = GAT( - in_feats=self.in_feats, - hidden_feats=self.hidden_feats, - num_heads=self.num_heads, - feat_drops=self.feat_drops, - attn_drops=self.attn_drops, - alphas=self.alphas, - residuals=self.residuals, - agg_modes=self.agg_modes, - activations=self.activations, - biases=self.biases, - allow_zero_in_degree=self.allow_zero_in_degree, - ) - - pred_readout = mtenn_params.get("pred_readout", None) - return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True) - - def _update(self, config_updates={}) -> GATModelConfig: - """ - GAT-specific implementation of updating logic. Need to handle stuff specially - to make sure that the original method of specifying parameters (either from a - passed value of ``num_layers`` or inferred from each parameter being a list) is - maintained. - - :meta public: - - Parameters - ---------- - config_updates : dict - Dictionary mapping from field names to new values - - Returns - ------- - GATModelConfig - New ``GATModelConfig`` object - """ - orig_config = self.dict() - if self._from_num_layers or ("num_layers" in config_updates): - # If originally generated from num_layers, want to pull out the first entry - # in each list param so it can be re-broadcast with (potentially) new - # num_layers - for param_name in GATModelConfig.LIST_PARAMS.keys(): - orig_config[param_name] = orig_config[param_name][0] - - # Get new config by overwriting old stuff with any new stuff - new_config = orig_config | config_updates - - # A bit hacky, maybe try and change? - if isinstance(new_config["activations"], list) and ( - new_config["activations"][0] is None - ): - new_config["activations"] = None - - return GATModelConfig(**new_config) - - -class PyGGATModelConfig(ModelConfigBase): """ Class for constructing a GAT ML model. Default values here are based on the values in DGL-LifeSci. 
""" - model_type: ModelType = Field(ModelType.gat_pyg, const=True) + model_type: ModelType = Field(ModelType.gat, const=True) in_channels: int = Field( -1, @@ -734,7 +456,7 @@ def _build(self, mtenn_params={}): mtenn.model.Model Model constructed from the config """ - from mtenn.conversion_utils.gat_pyg import GAT + from mtenn.conversion_utils.gat import GAT model = GAT( in_channels=self.in_channels, diff --git a/mtenn/conversion_utils/gat.py b/mtenn/conversion_utils/gat.py index eca5791..a27f929 100644 --- a/mtenn/conversion_utils/gat.py +++ b/mtenn/conversion_utils/gat.py @@ -1,108 +1,47 @@ """ ``Representation`` and ``Strategy`` implementations for the graph attention model -architecture. The underlying model that we use is the implementation in the -`DGL-LifeSCi `_ -package. +architecture. The underlying model that we use is the implementation in +`PyTorch Geometric `_. """ from copy import deepcopy import torch -from dgllife.model import GAT as GAT_dgl -from dgllife.model import WeightedSumAndMax +from torch_geometric.nn.models import GAT as PygGAT from mtenn.model import LigandOnlyModel class GAT(torch.nn.Module): """ - ``mtenn`` wrapper around the DGL-LifeSci GAT model. This class handles construction - of the model and the formatting into ``Representation`` and ``Strategy`` blocks. + ``mtenn`` wrapper around the PyTorch Geometric GAT model. This class handles + construction of the model and the formatting into ``Representation`` and + ``Strategy`` blocks. """ def __init__(self, *args, model=None, **kwargs): """ - Initialize the underlying ``dgllife.model.GAT`` model, as well as the ``mtenn`` - -specific code on top. If a value is passed for ``model``, builds a new - ``dgllife.model.GAT`` model based on those hyperparameters, and copies over the - weights. Otherwise, all ``*args`` and ``**kwargs`` are passed directly to the - ``dgllife.model.GAT`` constructor. + Initialize the underlying ``torch_geometric.nn.models.GAT`` model. If a value is + passed for ``model``, builds a new ``torch_geometric.nn.models.GAT`` model based + on those hyperparameters, and copies over the weights. Otherwise, all ``*args`` + and ``**kwargs`` are passed directly to the ``torch_geometric.nn.models.GAT`` + constructor. 
Parameters ---------- - model : ``dgllife.model.GAT``, optional - DGL-LifeSci model to use to construct the underlying model + model : ``torch_geometric.nn.models.GAT``, optional + PyTorch Geometric model to use to construct the underlying model """ super().__init__() - # First check for predictor_hidden_feats so it doesn't get passed to DGL GAT - # constructor - predictor_hidden_feats = kwargs.pop("predictor_hidden_feats", None) - # If no model is passed, construct model based on passed args, otherwise copy # all parameters and weights over if model is None: - self.gnn = GAT_dgl(*args, **kwargs) - else: - # Parameters that are conveniently accessible from the top level - in_feats = model.gnn_layers[0].gat_conv.fc.in_features - hidden_feats = model.hidden_feats - num_heads = model.num_heads - agg_modes = model.agg_modes - # Parameters that can only be adcessed layer-wise - layer_params = [] - for l in model.gnn_layers: - gc = l.gat_conv - new_params = ( - gc.feat_drop.p, - gc.attn_drop.p, - gc.leaky_relu.negative_slope, - gc.activation, - bool(gc.res_fc), - (gc.res_fc.bias is not None) - if gc.has_linear_res - else gc.has_explicit_bias, - ) - layer_params += [new_params] - - ( - feat_drops, - attn_drops, - alphas, - activations, - residuals, - biases, - ) = zip(*layer_params) - self.gnn = GAT_dgl( - in_feats=in_feats, - hidden_feats=hidden_feats, - num_heads=num_heads, - feat_drops=feat_drops, - attn_drops=attn_drops, - alphas=alphas, - residuals=residuals, - agg_modes=agg_modes, - activations=activations, - biases=biases, - ) - self.gnn.load_state_dict(model.state_dict()) - - # Copied from GATPredictor class, figure out how many features the last - # layer of the GNN will have - if self.gnn.agg_modes[-1] == "flatten": - gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1] + self.gnn = PygGAT(*args, **kwargs) else: - gnn_out_feats = self.gnn.hidden_feats[-1] - self.readout = WeightedSumAndMax(gnn_out_feats) - - # Use given hidden feats if supplied, otherwise use 1/2 gnn_out_feats - if predictor_hidden_feats is None: - predictor_hidden_feats = gnn_out_feats // 2 + self.gnn = deepcopy(model) - # 2 layer MLP with ReLU activation (borrowed from GATPredictor) - self.predict = torch.nn.Sequential( - torch.nn.Linear(2 * gnn_out_feats, predictor_hidden_feats), - torch.nn.ReLU(), - torch.nn.Linear(predictor_hidden_feats, 1), - ) + # Predict from mean of node features + self.predict = torch.nn.Linear(self.gnn.out_channels, 1) def forward(self, data): """ @@ -110,20 +49,26 @@ def forward(self, data): Parameters ---------- - data : dict - This dictionary should at minimum contain an entry for ``"g"``, which should - be the molecule graph representation and will be passed to the underlying - ``dgllife.model.GAT`` object + data : dict[str, torch.Tensor] + This dictionary should at minimum contain entries for: + + * ``"x"``: Atom coordinates, shape of (num_atoms, num_features) + + * ``"edge_index"``: All edges in the graph, shape of (2, num_edges) with the + first row giving the source node indices and the second row giving the + destination node indices for each edge Returns ------- torch.Tensor Model prediction """ - g = data["g"] - node_feats = self.gnn(g, g.ndata["h"]) - graph_feats = self.readout(g, node_feats) - return self.predict(graph_feats) + # Run through GNN + graph_gred = self.gnn(x=data["x"], edge_index=data["edge_index"]) + # Take mean of feature values across nodes + graph_gred = graph_gred.mean(dim=0) + # Make final prediction + return self.predict(graph_gred) def 
_get_representation(self):
         """
@@ -142,15 +87,15 @@ def _get_representation(self):

     def _get_energy_func(self):
         """
-        Return last two layer of the model.
+        Return last layer of the model.

         Returns
         -------
-        torch.nn.Sequential
-            Sequential module calling copy of `model`'s last two layers
+        torch.nn.Linear
+            Final energy prediction layer of the model
         """

-        return torch.nn.Sequential(deepcopy(self.readout), deepcopy(self.predict))
+        return deepcopy(self.predict)

     @staticmethod
     def get_model(
diff --git a/mtenn/conversion_utils/gat_pyg.py b/mtenn/conversion_utils/gat_pyg.py
deleted file mode 100644
index a27f929..0000000
--- a/mtenn/conversion_utils/gat_pyg.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""
-``Representation`` and ``Strategy`` implementations for the graph attention model
-architecture. The underlying model that we use is the implementation in
-`PyTorch Geometric <https://pytorch-geometric.readthedocs.io>`_.
-"""
-from copy import deepcopy
-import torch
-from torch_geometric.nn.models import GAT as PygGAT
-
-from mtenn.model import LigandOnlyModel
-
-
-class GAT(torch.nn.Module):
-    """
-    ``mtenn`` wrapper around the PyTorch Geometric GAT model. This class handles
-    construction of the model and the formatting into ``Representation`` and
-    ``Strategy`` blocks.
-    """
-
-    def __init__(self, *args, model=None, **kwargs):
-        """
-        Initialize the underlying ``torch_geometric.nn.models.GAT`` model. If a value is
-        passed for ``model``, that model is copied over directly (weights included)
-        rather than being rebuilt from its hyperparameters. Otherwise, all ``*args``
-        and ``**kwargs`` are passed directly to the ``torch_geometric.nn.models.GAT``
-        constructor.
-
-        Parameters
-        ----------
-        model : ``torch_geometric.nn.models.GAT``, optional
-            PyTorch Geometric model to use to construct the underlying model
-        """
-        super().__init__()
-
-        # If no model is passed, construct model based on passed args, otherwise copy
-        # all parameters and weights over
-        if model is None:
-            self.gnn = PygGAT(*args, **kwargs)
-        else:
-            self.gnn = deepcopy(model)
-
-        # Predict from mean of node features
-        self.predict = torch.nn.Linear(self.gnn.out_channels, 1)
-
-    def forward(self, data):
-        """
-        Make a prediction of the target property based on an input molecule graph.
-
-        Parameters
-        ----------
-        data : dict[str, torch.Tensor]
-            This dictionary should at minimum contain entries for:
-
-            * ``"x"``: Node (atom) features, shape of (num_atoms, num_features)
-
-            * ``"edge_index"``: All edges in the graph, shape of (2, num_edges) with the
-              first row giving the source node indices and the second row giving the
-              destination node indices for each edge
-
-        Returns
-        -------
-        torch.Tensor
-            Model prediction
-        """
-        # Run through GNN
-        graph_pred = self.gnn(x=data["x"], edge_index=data["edge_index"])
-        # Take mean of feature values across nodes
-        graph_pred = graph_pred.mean(dim=0)
-        # Make final prediction
-        return self.predict(graph_pred)
-
-    def _get_representation(self):
-        """
-        Return a copy of the GNN portion of the model.
-
-        Returns
-        -------
-        torch_geometric.nn.models.GAT
-            Copy of the underlying GNN, which serves as the representation block
-        """
-
-        # Copy model so initial model isn't affected
-        model_copy = deepcopy(self.gnn)
-
-        return model_copy
-
-    def _get_energy_func(self):
-        """
-        Return last layer of the model.
-
-        Returns
-        -------
-        torch.nn.Linear
-            Final energy prediction layer of the model
-        """
-
-        return deepcopy(self.predict)
-
-    @staticmethod
-    def get_model(
-        *args,
-        model=None,
-        fix_device=False,
-        pred_readout=None,
-        **kwargs,
-    ):
-        """
-        Exposed function to build a :py:class:`LigandOnlyModel
-        <mtenn.model.LigandOnlyModel>` from a :py:class:`GAT
-        <mtenn.conversion_utils.gat_pyg.GAT>` (or args/kwargs). If no ``model`` is given,
-        use the ``*args`` and ``**kwargs``.
-
-        Parameters
-        ----------
-        model: mtenn.conversion_utils.gat_pyg.GAT, optional
-            ``GAT`` model to use to build the ``LigandOnlyModel`` object. If not
-            provided, a model will be built using the passed ``*args`` and ``**kwargs``
-        fix_device: bool, default=False
-            If True, make sure the input is on the same device as the model,
-            copying over as necessary
-        pred_readout : mtenn.readout.Readout, optional
-            ``Readout`` object for the energy predictions
-
-        Returns
-        -------
-        mtenn.model.LigandOnlyModel
-            ``LigandOnlyModel`` object containing the model and desired ``Readout``
-        """
-        if model is None:
-            model = GAT(*args, **kwargs)
-
-        return LigandOnlyModel(model=model, readout=pred_readout, fix_device=fix_device)

From 6ab110a4e0ceb8c5726d8a32fd6640b2c744348e Mon Sep 17 00:00:00 2001
From: kaminow
Date: Thu, 15 Aug 2024 11:45:10 -0400
Subject: [PATCH 05/14] Update GAT tests.

---
 mtenn/tests/test_gat.py | 63 +++++++++++++++--------------------------
 1 file changed, 23 insertions(+), 40 deletions(-)

diff --git a/mtenn/tests/test_gat.py b/mtenn/tests/test_gat.py
index 4663da2..83ac7db 100644
--- a/mtenn/tests/test_gat.py
+++ b/mtenn/tests/test_gat.py
@@ -1,73 +1,56 @@
 import pytest
 import torch
-from dgllife.model import GAT as GAT_dgl
-from dgllife.utils import CanonicalAtomFeaturizer, SMILESToBigraph
+from asapdiscovery.data.backend.openeye import featurize_smiles
+from torch_geometric.nn.models import GAT as PygGAT

 from mtenn.conversion_utils.gat import GAT


 @pytest.fixture
 def model_input():
     smiles = "CCCC"
-    g = SMILESToBigraph(add_self_loop=True, node_featurizer=CanonicalAtomFeaturizer())(
-        smiles
-    )
+    feature_tensor, bond_list_tensor = featurize_smiles(smiles)

-    return {"g": g, "smiles": smiles}
+    return {"x": feature_tensor, "edge_index": bond_list_tensor, "smiles": smiles}


 def test_build_gat_directly_kwargs():
-    model = GAT(in_feats=10, hidden_feats=[1, 2, 3])
-    assert len(model.gnn.gnn_layers) == 3
+    model = GAT(in_channels=-1, hidden_channels=32, num_layers=2)
+    assert model.gnn.num_layers == 2

-    assert model.gnn.gnn_layers[0].gat_conv._in_src_feats == 10
-    assert model.gnn.gnn_layers[0].gat_conv._out_feats == 1
+    assert model.gnn.convs[0].in_channels == -1
+    assert model.gnn.convs[0].out_channels == 32

-    # hidden_feats * num_heads = 1 * 4
-    assert model.gnn.gnn_layers[1].gat_conv._in_src_feats == 4
-    assert model.gnn.gnn_layers[1].gat_conv._out_feats == 2
+    assert model.gnn.convs[1].in_channels == 32
+    assert model.gnn.convs[1].out_channels == 32

-    # hidden_feats * num_heads = 2 * 4
-    assert model.gnn.gnn_layers[2].gat_conv._in_src_feats == 8
-    assert model.gnn.gnn_layers[2].gat_conv._out_feats == 3

-
-def test_build_gat_from_dgl_gat():
-    dgl_model = GAT_dgl(in_feats=10, hidden_feats=[1, 2, 3])
-    model = GAT(model=dgl_model)
+def test_build_gat_from_pyg_gat():
+    pyg_model = PygGAT(in_channels=10, hidden_channels=32, num_layers=2)
+    model = GAT(model=pyg_model)

     # Check set up as before
-    assert len(model.gnn.gnn_layers) == 3
-
-    assert model.gnn.gnn_layers[0].gat_conv._in_src_feats == 10
-    assert model.gnn.gnn_layers[0].gat_conv._out_feats == 1
+    assert model.gnn.num_layers 
== 2 - # hidden_feats * num_heads = 1 * 4 - assert model.gnn.gnn_layers[1].gat_conv._in_src_feats == 4 - assert model.gnn.gnn_layers[1].gat_conv._out_feats == 2 + assert model.gnn.convs[0].in_channels == 10 + assert model.gnn.convs[0].out_channels == 32 - # hidden_feats * num_heads = 2 * 4 - assert model.gnn.gnn_layers[2].gat_conv._in_src_feats == 8 - assert model.gnn.gnn_layers[2].gat_conv._out_feats == 3 + assert model.gnn.convs[1].in_channels == 32 + assert model.gnn.convs[1].out_channels == 32 # Check that model weights got copied - ref_params = dict(dgl_model.state_dict()) + ref_params = dict(pyg_model.state_dict()) for n, model_param in model.gnn.named_parameters(): assert (model_param == ref_params[n]).all() -def test_set_predictor_hidden_feats(): - model = GAT(in_feats=10, predictor_hidden_feats=10) - assert model.predict[0].out_features == 10 - - def test_gat_can_predict(model_input): - model = GAT(in_feats=CanonicalAtomFeaturizer().feat_size()) + model = GAT(in_channels=-1, hidden_channels=32, num_layers=2) _ = model(model_input) def test_representation_is_correct(): - model = GAT(in_feats=10) + model = GAT(in_channels=10, hidden_channels=32, num_layers=2) rep = model._get_representation() model_params = dict(model.gnn.named_parameters()) @@ -76,14 +59,14 @@ def test_representation_is_correct(): def test_get_model_no_ref(): - model = GAT.get_model(in_feats=10) + model = GAT.get_model(in_channels=10, hidden_channels=32, num_layers=2) assert isinstance(model.representation, GAT) assert model.readout is None def test_get_model_ref(): - ref_model = GAT(in_feats=10) + ref_model = GAT(in_channels=10, hidden_channels=32, num_layers=2) model = GAT.get_model(model=ref_model) assert isinstance(model.representation, GAT) From 30a16f2b28b8c8613828a4966591ebf37419b063 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 15 Aug 2024 11:49:06 -0400 Subject: [PATCH 06/14] Fix README usage. --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ac2da6f..7f59de4 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,8 @@ The input passed to this model should be a `dict` with the following keys (based * `pos`: Tensor of coordinates for each atom, shape of `(n,3)` * `z`: Tensor of bool labels of whether each atom is a protein atom (`False`) or ligand atom (`True`), shape of `(n,)` * `GAT` - * `g`: DGL graph object + * `x`: Tensor of input atom (node) features, shape of `(n,feats)` + * `edge_index`: Tensor giving source (first row) and dest (second row) atom indices, shape of `(2,n_bonds)` The prediction can then be generated simply with: ```python From a98c0a06b5d7e2963e8d17f41b40d5f324a89585 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 15 Aug 2024 13:26:29 -0400 Subject: [PATCH 07/14] Typo --- mtenn/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtenn/config.py b/mtenn/config.py index 504d3c6..b5cc904 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -408,7 +408,7 @@ class GATModelConfig(ModelConfigBase): in DGL-LifeSci. """ - model_type: ModelType = Field(ModelType.gat, const=True) + model_type: ModelType = Field(ModelType.GAT, const=True) in_channels: int = Field( -1, From 7c8cddfa5dabff35210ce6ea88d0381cf31f1b9a Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 15 Aug 2024 13:26:47 -0400 Subject: [PATCH 08/14] Update to PyG example. 
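The docs example now builds the ``x``/``edge_index`` tensors directly with RDKit
instead of going through DGL-LifeSci. As a quick sanity check of the shapes the
new example should produce for butane (assuming the 100-class one-hot encoding
used in the example), 4 atoms and 3 bonds give:

    node_feats.shape  # torch.Size([4, 100])
    edge_index.shape  # 3 bonds * 2 directions + 4 self loops -> torch.Size([2, 10])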
--- docs/docs/basic_usage.rst | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/docs/docs/basic_usage.rst b/docs/docs/basic_usage.rst index 00f769f..e6d6afe 100644 --- a/docs/docs/basic_usage.rst +++ b/docs/docs/basic_usage.rst @@ -7,19 +7,37 @@ Below, we detail a basic example of building a default Graph Attention model and .. code-block:: python - from dgllife.utils import CanonicalAtomFeaturizer, SMILESToBigraph from mtenn.config import GATModelConfig + import rdkit.Chem as Chem + import torch # Build model with GAT defaults model = GATModelConfig().build() - # Build graph from SMILES + # Build mol smiles = "CCCC" - g = SMILESToBigraph( - add_self_loop=True, - node_featurizer=CanonicalAtomFeaturizer(), - )(smiles) + mol = Chem.MolFromSmiles(smiles) + + # Get atomic numbers and bond indices (both directions) + atomic_nums = [a.GetAtomicNum() for a in mol.GetAtoms()] + bond_idxs = [ + atom_pair + for bond in mol.GetBonds() + for atom_pair in ( + (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()), + (bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()), + ) + ] + # Add self bonds + bond_idxs += [(a.GetIdx(), a.GetIdx()) for a in mol.GetAtoms()] + + # Encode atomic numbers as one-hot, assume max num of 100 + node_feats = torch.nn.functional.one_hot( + torch.tensor(atomic_nums), num_classes=100 + ).to(dtype=torch.float) + # Format bonds in correct shape + edge_index = torch.tensor(bond_idxs).t() # Make a prediction - pred, _ = model({"g": g}) + pred, _ = model({"x": node_feats, "edge_index": edge_index}) From 71d1d8ec8c870ab8de82f72576af81b5d8b9f203 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 15 Aug 2024 13:29:16 -0400 Subject: [PATCH 09/14] Remove dgl deps. --- devtools/conda-envs/mtenn.yaml | 2 -- devtools/conda-envs/test_env.yaml | 2 -- docs/requirements.yaml | 2 -- environment-gpu.yml | 4 +--- environment.yml | 4 +--- 5 files changed, 2 insertions(+), 12 deletions(-) diff --git a/devtools/conda-envs/mtenn.yaml b/devtools/conda-envs/mtenn.yaml index 8139a84..208c817 100644 --- a/devtools/conda-envs/mtenn.yaml +++ b/devtools/conda-envs/mtenn.yaml @@ -10,7 +10,5 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - rdkit - ase diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 9fce2b8..43c95b0 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -10,8 +10,6 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - rdkit - ase - fsspec diff --git a/docs/requirements.yaml b/docs/requirements.yaml index 5382ba3..b822114 100644 --- a/docs/requirements.yaml +++ b/docs/requirements.yaml @@ -10,8 +10,6 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - rdkit - ase - pydantic >=1.10.8,<2.0.0a0 diff --git a/environment-gpu.yml b/environment-gpu.yml index de93784..811edf3 100644 --- a/environment-gpu.yml +++ b/environment-gpu.yml @@ -11,7 +11,5 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - - rdkit + - rdkit - ase diff --git a/environment.yml b/environment.yml index 7a679c7..208c817 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,5 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - rdkit - - ase \ No newline at end of file + - ase From 01d12ca6dd727573ed406eaaaf86ca0040420553 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 15 Aug 2024 13:33:59 -0400 Subject: [PATCH 10/14] Featurize SMILES with rdkit. 
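This drops the asapdiscovery test dependency by building the same one-hot
features inline with RDKit, mirroring the docs example. A hypothetical shape
check against the fixture (names as in the test below) would look like:

    def test_model_input_shapes(model_input):
        # 4 atoms, one-hot encoded over 100 classes
        assert model_input["x"].shape == (4, 100)
        # 3 bonds * 2 directions + 4 self loops
        assert model_input["edge_index"].shape == (2, 10)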
--- mtenn/tests/test_gat.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/mtenn/tests/test_gat.py b/mtenn/tests/test_gat.py index 83ac7db..2491028 100644 --- a/mtenn/tests/test_gat.py +++ b/mtenn/tests/test_gat.py @@ -1,15 +1,36 @@ import pytest import torch -from asapdiscovery.data.backend.openeye import featurize_smiles -from torch_geometric.nn.models import GAT as PygGAT from mtenn.conversion_utils.gat import GAT +import rdkit.Chem as Chem +from torch_geometric.nn.models import GAT as PygGAT @pytest.fixture def model_input(): + # Build mol smiles = "CCCC" - feature_tensor, bond_list_tensor = featurize_smiles(smiles) + mol = Chem.MolFromSmiles(smiles) + + # Get atomic numbers and bond indices (both directions) + atomic_nums = [a.GetAtomicNum() for a in mol.GetAtoms()] + bond_idxs = [ + atom_pair + for bond in mol.GetBonds() + for atom_pair in ( + (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()), + (bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()), + ) + ] + # Add self bonds + bond_idxs += [(a.GetIdx(), a.GetIdx()) for a in mol.GetAtoms()] + + # Encode atomic numbers as one-hot, assume max num of 100 + feature_tensor = torch.nn.functional.one_hot( + torch.tensor(atomic_nums), num_classes=100 + ).to(dtype=torch.float) + # Format bonds in correct shape + bond_list_tensor = torch.tensor(bond_idxs).t() return {"x": feature_tensor, "edge_index": bond_list_tensor, "smiles": smiles} From 97eb662e52083bf7f83d36b1e8d9da5e683151a2 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 15 Aug 2024 13:41:02 -0400 Subject: [PATCH 11/14] Remove irrelevant tests and fix some config args. --- mtenn/tests/test_model_config.py | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index 3fa0d5b..5e4eb9f 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -12,8 +12,8 @@ def test_random_seed_gat(): - rand_config = GATModelConfig() - set_config = GATModelConfig(rand_seed=10) + rand_config = GATModelConfig(in_channels=10) + set_config = GATModelConfig(in_channels=10, rand_seed=10) rand_model1 = rand_config.build() rand_model2 = rand_config.build() @@ -44,6 +44,7 @@ def test_random_seed_gat(): ) def test_readout_gat(pred_r, pred_r_class, pred_r_args): model = GATModelConfig( + in_channels=10, pred_readout=pred_r, pred_substrate=pred_r_args[0], pred_km=pred_r_args[1], @@ -60,34 +61,14 @@ def test_readout_gat(pred_r, pred_r_class, pred_r_args): def test_model_weights_gat(): - model1 = GATModelConfig().build() - model2 = GATModelConfig(model_weights=model1.state_dict()).build() + model1 = GATModelConfig(in_channels=10).build() + model2 = GATModelConfig(in_channels=10, model_weights=model1.state_dict()).build() test_model_params = dict(model2.named_parameters()) for n, ref_param in model1.named_parameters(): assert (ref_param == test_model_params[n]).all() -def test_no_diff_list_lengths_gat(): - with pytest.raises(ValueError): - # Different length lists should raise error - _ = GATModelConfig(hidden_feats=[1, 2, 3], num_heads=[4, 5]) - - -def test_bad_param_mapping_gat(): - with pytest.raises(ValueError): - # Can't convert string to int - _ = GATModelConfig(hidden_feats="sdf") - - -def test_can_pass_lists_gat(): - model_config = GATModelConfig(hidden_feats=[1, 2, 3]) - model = model_config.build() - - assert len(model.representation.gnn.gnn_layers) == 3 - assert not model_config._from_num_layers - - def test_random_seed_e3nn(): 
    rand_config = E3NNModelConfig()
     set_config = E3NNModelConfig(rand_seed=10)

From a9352e4119cee1a6203df829370902f095595092 Mon Sep 17 00:00:00 2001
From: kaminow
Date: Mon, 9 Sep 2024 14:44:42 -0400
Subject: [PATCH 12/14] Add side-by-side PyG and DGL versions.

---
 mtenn/conversion_utils/dgl_gat.py | 188 ++++++++++++++++++++++++++++++
 mtenn/conversion_utils/pyg_gat.py |   1 +
 2 files changed, 189 insertions(+)
 create mode 100644 mtenn/conversion_utils/dgl_gat.py
 create mode 100644 mtenn/conversion_utils/pyg_gat.py

diff --git a/mtenn/conversion_utils/dgl_gat.py b/mtenn/conversion_utils/dgl_gat.py
new file mode 100644
index 0000000..0b71e4b
--- /dev/null
+++ b/mtenn/conversion_utils/dgl_gat.py
@@ -0,0 +1,188 @@
+"""
+``Representation`` and ``Strategy`` implementations for the graph attention model
+architecture. The underlying model that we use is the implementation in the
+`DGL-LifeSci <https://lifesci.dgl.ai/api/model.gnn.html#module-dgllife.model.gnn.gat>`_
+package.
+"""
+from copy import deepcopy
+import torch
+from dgllife.model import GAT as GAT_dgl
+from dgllife.model import WeightedSumAndMax
+
+from mtenn.model import LigandOnlyModel
+
+
+class DGLGAT(torch.nn.Module):
+    """
+    ``mtenn`` wrapper around the DGL-LifeSci GAT model. This class handles construction
+    of the model and the formatting into ``Representation`` and ``Strategy`` blocks.
+    """
+
+    def __init__(self, *args, model=None, **kwargs):
+        """
+        Initialize the underlying ``dgllife.model.GAT`` model, as well as the ``mtenn``
+        -specific code on top. If a value is passed for ``model``, builds a new
+        ``dgllife.model.GAT`` model based on those hyperparameters, and copies over the
+        weights. Otherwise, all ``*args`` and ``**kwargs`` are passed directly to the
+        ``dgllife.model.GAT`` constructor.
+
+        Parameters
+        ----------
+        model : ``dgllife.model.GAT``, optional
+            DGL-LifeSci model to use to construct the underlying model
+        """
+        super().__init__()
+
+        # First check for predictor_hidden_feats so it doesn't get passed to DGL GAT
+        # constructor
+        predictor_hidden_feats = kwargs.pop("predictor_hidden_feats", None)
+
+        # If no model is passed, construct model based on passed args, otherwise copy
+        # all parameters and weights over
+        if model is None:
+            self.gnn = GAT_dgl(*args, **kwargs)
+        else:
+            # Parameters that are conveniently accessible from the top level
+            in_feats = model.gnn_layers[0].gat_conv.fc.in_features
+            hidden_feats = model.hidden_feats
+            num_heads = model.num_heads
+            agg_modes = model.agg_modes
+            # Parameters that can only be accessed layer-wise
+            layer_params = []
+            for l in model.gnn_layers:
+                gc = l.gat_conv
+                new_params = (
+                    gc.feat_drop.p,
+                    gc.attn_drop.p,
+                    gc.leaky_relu.negative_slope,
+                    gc.activation,
+                    bool(gc.res_fc),
+                    (gc.res_fc.bias is not None)
+                    if gc.has_linear_res
+                    else gc.has_explicit_bias,
+                )
+                layer_params += [new_params]
+
+            (
+                feat_drops,
+                attn_drops,
+                alphas,
+                activations,
+                residuals,
+                biases,
+            ) = zip(*layer_params)
+            self.gnn = GAT_dgl(
+                in_feats=in_feats,
+                hidden_feats=hidden_feats,
+                num_heads=num_heads,
+                feat_drops=feat_drops,
+                attn_drops=attn_drops,
+                alphas=alphas,
+                residuals=residuals,
+                agg_modes=agg_modes,
+                activations=activations,
+                biases=biases,
+            )
+            self.gnn.load_state_dict(model.state_dict())
+
+        # Copied from GATPredictor class, figure out how many features the last
+        # layer of the GNN will have
+        if self.gnn.agg_modes[-1] == "flatten":
+            gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1]
+        else:
+            gnn_out_feats = self.gnn.hidden_feats[-1]
+        self.readout = WeightedSumAndMax(gnn_out_feats)
+
+        # Use given hidden 
feats if supplied, otherwise use 1/2 gnn_out_feats
+        if predictor_hidden_feats is None:
+            predictor_hidden_feats = gnn_out_feats // 2
+
+        # 2 layer MLP with ReLU activation (borrowed from GATPredictor)
+        self.predict = torch.nn.Sequential(
+            torch.nn.Linear(2 * gnn_out_feats, predictor_hidden_feats),
+            torch.nn.ReLU(),
+            torch.nn.Linear(predictor_hidden_feats, 1),
+        )
+
+    def forward(self, data):
+        """
+        Make a prediction of the target property based on an input molecule graph.
+
+        Parameters
+        ----------
+        data : dict
+            This dictionary should at minimum contain an entry for ``"g"``, which should
+            be the molecule graph representation and will be passed to the underlying
+            ``dgllife.model.GAT`` object
+
+        Returns
+        -------
+        torch.Tensor
+            Model prediction
+        """
+        g = data["g"]
+        node_feats = self.gnn(g, g.ndata["h"])
+        graph_feats = self.readout(g, node_feats)
+        return self.predict(graph_feats)
+
+    def _get_representation(self):
+        """
+        Return a copy of the GNN portion of the model.
+
+        Returns
+        -------
+        dgllife.model.GAT
+            Copy of the underlying GNN, which serves as the representation block
+        """
+
+        # Copy model so initial model isn't affected
+        model_copy = deepcopy(self.gnn)
+
+        return model_copy
+
+    def _get_energy_func(self):
+        """
+        Return last two layers of the model.
+
+        Returns
+        -------
+        torch.nn.Sequential
+            Sequential module calling copy of `model`'s last two layers
+        """
+
+        return torch.nn.Sequential(deepcopy(self.readout), deepcopy(self.predict))
+
+    @staticmethod
+    def get_model(
+        *args,
+        model=None,
+        fix_device=False,
+        pred_readout=None,
+        **kwargs,
+    ):
+        """
+        Exposed function to build a :py:class:`LigandOnlyModel
+        <mtenn.model.LigandOnlyModel>` from a :py:class:`DGLGAT
+        <mtenn.conversion_utils.dgl_gat.DGLGAT>` (or args/kwargs). If no ``model`` is
+        given, use the ``*args`` and ``**kwargs``.
+
+        Parameters
+        ----------
+        model: mtenn.conversion_utils.dgl_gat.DGLGAT, optional
+            ``DGLGAT`` model to use to build the ``LigandOnlyModel`` object. If not
+            provided, a model will be built using the passed ``*args`` and ``**kwargs``
+        fix_device: bool, default=False
+            If True, make sure the input is on the same device as the model,
+            copying over as necessary
+        pred_readout : mtenn.readout.Readout, optional
+            ``Readout`` object for the energy predictions
+
+        Returns
+        -------
+        mtenn.model.LigandOnlyModel
+            ``LigandOnlyModel`` object containing the model and desired ``Readout``
+        """
+        if model is None:
+            model = DGLGAT(*args, **kwargs)
+
+        return LigandOnlyModel(model=model, readout=pred_readout, fix_device=fix_device)
diff --git a/mtenn/conversion_utils/pyg_gat.py b/mtenn/conversion_utils/pyg_gat.py
new file mode 100644
index 0000000..62bf473
--- /dev/null
+++ b/mtenn/conversion_utils/pyg_gat.py
@@ -0,0 +1 @@
+from mtenn.conversion_utils.gat import GAT as PyGGAT  # noqa: F401

From 9a9f1f67038e7e604148a5673036721b6f7c068b Mon Sep 17 00:00:00 2001
From: kaminow
Date: Mon, 9 Sep 2024 14:45:10 -0400
Subject: [PATCH 13/14] Add configs for both GAT versions.
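With both configs exposed, the two backends can be built side by side. A rough
usage sketch (using the defaults defined below):

    from mtenn.config import DGLGATModelConfig, PyGGATModelConfig

    # PyG-backed model; hyperparameters are inherited from GATModelConfig
    pyg_model = PyGGATModelConfig().build()
    # DGL-backed model, equivalent to the pre-PyG GATModelConfig
    dgl_model = DGLGATModelConfig().build()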
---
 mtenn/config.py | 328 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 328 insertions(+)

diff --git a/mtenn/config.py b/mtenn/config.py
index b5cc904..29eae63 100644
--- a/mtenn/config.py
+++ b/mtenn/config.py
@@ -78,6 +78,8 @@ class ModelType(StringEnum):
     """

     GAT = "GAT"
+    pyg_gat = "pyg_gat"
+    dgl_gat = "dgl_gat"
     schnet = "schnet"
     e3nn = "e3nn"
     visnet = "visnet"
@@ -471,6 +473,332 @@ def _build(self, mtenn_params={}):
         return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)


+class PyGGATModelConfig(GATModelConfig):
+    model_type: ModelType = Field(ModelType.pyg_gat, const=True)
+
+    def _build(self, mtenn_params={}):
+        """
+        Build an ``mtenn`` PyGGAT ``Model`` from this config.
+
+        :meta public:
+
+        Parameters
+        ----------
+        mtenn_params : dict, optional
+            Dictionary that stores the ``Readout`` objects for the individual
+            predictions and for the combined prediction, and the ``Combination`` object
+            in the case of a multi-pose model. These are all constructed the same for all
+            ``Model`` types, so we can just handle them in the base class. Keys in the
+            dict will be:
+
+            * "combination": :py:mod:`Combination <mtenn.combination>`
+
+            * "pred_readout": :py:mod:`Readout <mtenn.readout>` for individual
+              pose predictions
+
+            * "comb_readout": :py:mod:`Readout <mtenn.readout>` for combined
+              prediction (in the case of a multi-pose model)
+
+            although the combination-related entries will be ignored because this is a
+            ligand-only model.
+
+        Returns
+        -------
+        mtenn.model.Model
+            Model constructed from the config
+        """
+        from mtenn.conversion_utils.pyg_gat import PyGGAT
+
+        model = PyGGAT(
+            in_channels=self.in_channels,
+            hidden_channels=self.hidden_channels,
+            num_layers=self.num_layers,
+            dropout=self.dropout,
+            heads=self.heads,
+            negative_slope=self.negative_slope,
+        )
+
+        pred_readout = mtenn_params.get("pred_readout", None)
+        return PyGGAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)
+
+
+class DGLGATModelConfig(ModelConfigBase):
+    """
+    Class for constructing a graph attention ML model. Note that there are two methods
+    for defining the size of the model:
+
+    * If single values are passed for all parameters, the value of ``num_layers`` will
+      be used as the size of the model, and each layer will have the parameters given
+
+    * If a list of values is passed for any parameters, all parameters must be lists of
+      the same size, or single values. For parameters that are single values, that same
+      value will be used for each layer. For parameters that are lists, those lists will
+      be used
+
+    Parameters passed as strings are assumed to be comma-separated lists, and will first
+    be cast to lists of the appropriate type, and then processed as described above.
+
+    If lists of multiple different (non-1) sizes are found, an error will be raised.
+
+    Default values here are the default values given in DGL-LifeSci.
+    """
+
+    # Import as private, mainly so Sphinx doesn't autodoc it
+    from dgllife.utils import CanonicalAtomFeaturizer as _CanonicalAtomFeaturizer
+
+    # Dict of model params that can be passed as a list, and the type that each will be
+    # cast to
+    LIST_PARAMS: ClassVar[dict] = {
+        "hidden_feats": int,
+        "num_heads": int,
+        "feat_drops": float,
+        "attn_drops": float,
+        "alphas": float,
+        "residuals": bool,
+        "agg_modes": str,
+        "activations": None,
+        "biases": bool,
+    }  #: :meta private:
+
+    model_type: ModelType = Field(ModelType.dgl_gat, const=True)
+
+    in_feats: int = Field(
+        _CanonicalAtomFeaturizer().feat_size(),
+        description=(
+            "Input node feature size. 
Defaults to size of the " + "``CanonicalAtomFeaturizer``." + ), + ) + num_layers: int = Field( + 2, + description=( + "Number of GAT layers. Ignored if a list of values is passed for any " + "other argument." + ), + ) + hidden_feats: str | int | list[int] = Field( + 32, + description=( + "Output size of each GAT layer. If an ``int`` is passed, the value for " + "``num_layers`` will be used to determine the size of the model. If a list " + "of ``int`` s is passed, the size of the model will be inferred from the " + "length of the list." + ), + ) + num_heads: str | int | list[int] = Field( + 4, + description=( + "Number of attention heads for each GAT layer. Passing an ``int`` or list " + "of ``int`` s functions similarly as for ``hidden_feats``." + ), + ) + feat_drops: str | float | list[float] = Field( + 0, + description=( + "Dropout of input features for each GAT layer. Passing a ``float`` or " + "list of ``float`` s functions similarly as for ``hidden_feats``." + ), + ) + attn_drops: str | float | list[float] = Field( + 0, + description=( + "Dropout of attention values for each GAT layer. Passing a ``float`` or " + "list of ``float`` s functions similarly as for ``hidden_feats``." + ), + ) + alphas: str | float | list[float] = Field( + 0.2, + description=( + "Hyperparameter for ``LeakyReLU`` gate for each GAT layer. Passing a " + "``float`` or list of ``float`` s functions similarly as for " + "``hidden_feats``." + ), + ) + residuals: str | bool | list[bool] = Field( + True, + description=( + "Whether to use residual connection for each GAT layer. Passing a ``bool`` " + "or list of ``bool`` s functions similarly as for ``hidden_feats``." + ), + ) + agg_modes: str | list[str] = Field( + "flatten", + description=( + "Which aggregation mode [flatten, mean] to use for each GAT layer. " + "Passing a ``str`` or list of ``str`` s functions similarly as for " + "``hidden_feats``." + ), + ) + activations: Callable | list[Callable] | list[None] | None = Field( + None, + description=( + "Activation function for each GAT layer. Passing a function or " + "list of functions functions similarly as for ``hidden_feats``." + ), + ) + biases: str | bool | list[bool] = Field( + True, + description=( + "Whether to use bias for each GAT layer. Passing a ``bool`` or " + "list of ``bool`` s functions similarly as for ``hidden_feats``." + ), + ) + allow_zero_in_degree: bool = Field( + False, description="Allow zero in degree nodes for all graph layers." + ) + + # Internal tracker for if the parameters were originally built from lists or using + # num_layers + _from_num_layers = False + + @root_validator(pre=False) + def massage_into_lists(cls, values) -> DGLGATModelConfig: + """ + Validator to handle unifying all the values into the proper list forms based on + the rules described in the class docstring. + """ + # First convert string lists to actual lists + for param, param_type in cls.LIST_PARAMS.items(): + param_val = values[param] + if isinstance(param_val, str): + try: + param_val = list(map(param_type, param_val.split(","))) + except ValueError: + raise ValueError( + f"Unable to parse value {param_val} for parameter {param}. " + f"Expected type of {param_type}." 
+                    )
+                values[param] = param_val
+
+        # Get sizes of all lists
+        list_lens = {}
+        for p in cls.LIST_PARAMS:
+            param_val = values[p]
+            if not isinstance(param_val, list):
+                # Shouldn't be possible at this point but just in case
+                param_val = [param_val]
+                values[p] = param_val
+            list_lens[p] = len(param_val)
+
+        # Check that there's only one length present
+        list_lens_set = set(list_lens.values())
+        # This could be 0 if lists of length 1 were passed, which is valid
+        if len(list_lens_set - {1}) > 1:
+            raise ValueError(
+                "All passed parameter lists must be the same length. "
+                f"Instead got list lengths of: {list_lens}"
+            )
+        elif list_lens_set == {1}:
+            # If all lists have only one value, we defer to the value passed to
+            # num_layers, as described in the class docstring
+            num_layers = values["num_layers"]
+            values["_from_num_layers"] = True
+        else:
+            num_layers = max(list_lens_set)
+            values["_from_num_layers"] = False
+
+        values["num_layers"] = num_layers
+        # If we just want a model with one layer, can return early since we've already
+        # converted everything into lists
+        if num_layers == 1:
+            return values
+
+        # Adjust any length 1 list to be the right length
+        for p, list_len in list_lens.items():
+            if list_len == 1:
+                values[p] = values[p] * num_layers
+
+        return values
+
+    def _build(self, mtenn_params={}):
+        """
+        Build an ``mtenn`` GAT ``Model`` from this config.
+
+        :meta public:
+
+        Parameters
+        ----------
+        mtenn_params : dict, optional
+            Dictionary that stores the ``Readout`` objects for the individual
+            predictions and for the combined prediction, and the ``Combination`` object
+            in the case of a multi-pose model. These are all constructed the same for all
+            ``Model`` types, so we can just handle them in the base class. Keys in the
+            dict will be:
+
+            * "combination": :py:mod:`Combination <mtenn.combination>`
+
+            * "pred_readout": :py:mod:`Readout <mtenn.readout>` for individual
+              pose predictions
+
+            * "comb_readout": :py:mod:`Readout <mtenn.readout>` for combined
+              prediction (in the case of a multi-pose model)
+
+            although the combination-related entries will be ignored because this is a
+            ligand-only model.
+
+        Returns
+        -------
+        mtenn.model.Model
+            Model constructed from the config
+        """
+        from mtenn.conversion_utils.dgl_gat import DGLGAT
+
+        model = DGLGAT(
+            in_feats=self.in_feats,
+            hidden_feats=self.hidden_feats,
+            num_heads=self.num_heads,
+            feat_drops=self.feat_drops,
+            attn_drops=self.attn_drops,
+            alphas=self.alphas,
+            residuals=self.residuals,
+            agg_modes=self.agg_modes,
+            activations=self.activations,
+            biases=self.biases,
+            allow_zero_in_degree=self.allow_zero_in_degree,
+        )
+
+        pred_readout = mtenn_params.get("pred_readout", None)
+        return DGLGAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)
+
+    def _update(self, config_updates={}) -> DGLGATModelConfig:
+        """
+        GAT-specific implementation of updating logic. Need to handle stuff specially
+        to make sure that the original method of specifying parameters (either from a
+        passed value of ``num_layers`` or inferred from each parameter being a list) is
+        maintained.
+ + :meta public: + + Parameters + ---------- + config_updates : dict + Dictionary mapping from field names to new values + + Returns + ------- + DGLGATModelConfig + New ``DGLGATModelConfig`` object + """ + orig_config = self.dict() + if self._from_num_layers or ("num_layers" in config_updates): + # If originally generated from num_layers, want to pull out the first entry + # in each list param so it can be re-broadcast with (potentially) new + # num_layers + for param_name in DGLGATModelConfig.LIST_PARAMS.keys(): + orig_config[param_name] = orig_config[param_name][0] + + # Get new config by overwriting old stuff with any new stuff + new_config = orig_config | config_updates + + # A bit hacky, maybe try and change? + if isinstance(new_config["activations"], list) and ( + new_config["activations"][0] is None + ): + new_config["activations"] = None + + return DGLGATModelConfig(**new_config) + + class SchNetModelConfig(ModelConfigBase): """ Class for constructing a SchNet ML model. Default values here are the default values From cfc66dc475f4d4a40e40b22fe4326056d810a6ef Mon Sep 17 00:00:00 2001 From: kaminow Date: Wed, 11 Sep 2024 15:26:08 -0400 Subject: [PATCH 14/14] Add option to use GATv2Conv. --- mtenn/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mtenn/config.py b/mtenn/config.py index 29eae63..3b6fafa 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -421,6 +421,7 @@ class GATModelConfig(ModelConfigBase): ) hidden_channels: int = Field(32, description="Hidden embedding size.") num_layers: int = Field(2, description="Number of GAT layers.") + v2: bool = Field(False, description="Use GATv2Conv layer instead of GATConv.") dropout: float = Field(0, description="Dropout probability.") heads: int = Field(4, description="Number of attention heads for each GAT layer.") negative_slope: float = Field( @@ -464,6 +465,7 @@ def _build(self, mtenn_params={}): in_channels=self.in_channels, hidden_channels=self.hidden_channels, num_layers=self.num_layers, + v2=self.v2, dropout=self.dropout, heads=self.heads, negative_slope=self.negative_slope, @@ -513,6 +515,7 @@ def _build(self, mtenn_params={}): in_channels=self.in_channels, hidden_channels=self.hidden_channels, num_layers=self.num_layers, + v2=self.v2, dropout=self.dropout, heads=self.heads, negative_slope=self.negative_slope,