Skip to content

Commit 2101d79

Browse files
dario-cosciaGiovanniCanaliAleDinve
authored
Message Passing Module (#516)
* add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <[email protected]> --------- Co-authored-by: giovanni <[email protected]> Co-authored-by: AleDinve <[email protected]>
1 parent facc4a0 commit 2101d79

19 files changed

+1405
-6
lines changed

docs/source/_rst/_code.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,18 @@ Blocks
122122
Continuous Convolution Block <model/block/convolution.rst>
123123
Orthogonal Block <model/block/orthogonal.rst>
124124

125+
Message Passing
126+
-------------------
127+
128+
.. toctree::
129+
:titlesonly:
130+
131+
Deep Tensor Network Block <model/block/message_passing/deep_tensor_network_block.rst>
132+
E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst>
133+
Interaction Network Block <model/block/message_passing/interaction_network_block.rst>
134+
Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst>
135+
Schnet Block <model/block/message_passing/schnet_block.rst>
136+
125137

126138
Reduction and Embeddings
127139
--------------------------
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Deep Tensor Network Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.deep_tensor_network_block
4+
5+
.. autoclass:: DeepTensorNetworkBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
E(n) Equivariant Network Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.en_equivariant_network_block
4+
5+
.. autoclass:: EnEquivariantNetworkBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Interaction Network Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.interaction_network_block
4+
5+
.. autoclass:: InteractionNetworkBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Radial Field Network Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.radial_field_network_block
4+
5+
.. autoclass:: RadialFieldNetworkBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Schnet Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.schnet_block
4+
5+
.. autoclass:: SchnetBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Module for the message passing blocks of the graph neural models."""
2+
3+
__all__ = [
4+
"InteractionNetworkBlock",
5+
"DeepTensorNetworkBlock",
6+
"EnEquivariantNetworkBlock",
7+
"RadialFieldNetworkBlock",
8+
"SchnetBlock",
9+
]
10+
11+
from .interaction_network_block import InteractionNetworkBlock
12+
from .deep_tensor_network_block import DeepTensorNetworkBlock
13+
from .en_equivariant_network_block import EnEquivariantNetworkBlock
14+
from .radial_field_network_block import RadialFieldNetworkBlock
15+
from .schnet_block import SchnetBlock
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""Module for the Deep Tensor Network block."""
2+
3+
import torch
4+
from torch_geometric.nn import MessagePassing
5+
from ....utils import check_positive_integer
6+
7+
8+
class DeepTensorNetworkBlock(MessagePassing):
9+
"""
10+
Implementation of the Deep Tensor Network block.
11+
12+
This block is used to perform message-passing between nodes and edges in a
13+
graph neural network, following the scheme proposed by Schutt et al. in
14+
2017. It serves as an inner block in a larger graph neural network
15+
architecture.
16+
17+
The message between two nodes connected by an edge is computed by applying a
18+
linear transformation to the sender node features and the edge features,
19+
followed by a non-linear activation function. Messages are then aggregated
20+
using an aggregation scheme (e.g., sum, mean, min, max, or product).
21+
22+
The update step is performed by a simple addition of the incoming messages
23+
to the node features.
24+
25+
.. seealso::
26+
27+
**Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al.
28+
(2017). *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
29+
Nature Communications 8, 13890 (2017).
30+
DOI: `<https://doi.org/10.1038/ncomms13890>`_.
31+
"""
32+
33+
def __init__(
34+
self,
35+
node_feature_dim,
36+
edge_feature_dim,
37+
activation=torch.nn.Tanh,
38+
aggr="add",
39+
node_dim=-2,
40+
flow="source_to_target",
41+
):
42+
"""
43+
Initialization of the :class:`DeepTensorNetworkBlock` class.
44+
45+
:param int node_feature_dim: The dimension of the node features.
46+
:param int edge_feature_dim: The dimension of the edge features.
47+
:param torch.nn.Module activation: The activation function.
48+
Default is :class:`torch.nn.Tanh`.
49+
:param str aggr: The aggregation scheme to use for message passing.
50+
Available options are "add", "mean", "min", "max", "mul".
51+
See :class:`torch_geometric.nn.MessagePassing` for more details.
52+
Default is "add".
53+
:param int node_dim: The axis along which to propagate. Default is -2.
54+
:param str flow: The direction of message passing. Available options
55+
are "source_to_target" and "target_to_source".
56+
The "source_to_target" flow means that messages are sent from
57+
the source node to the target node, while the "target_to_source"
58+
flow means that messages are sent from the target node to the
59+
source node. See :class:`torch_geometric.nn.MessagePassing` for more
60+
details. Default is "source_to_target".
61+
:raises AssertionError: If `node_feature_dim` is not a positive integer.
62+
:raises AssertionError: If `edge_feature_dim` is not a positive integer.
63+
"""
64+
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
65+
66+
# Check values
67+
check_positive_integer(node_feature_dim, strict=True)
68+
check_positive_integer(edge_feature_dim, strict=True)
69+
70+
# Activation function
71+
self.activation = activation()
72+
73+
# Layer for processing node features
74+
self.node_layer = torch.nn.Linear(
75+
in_features=node_feature_dim,
76+
out_features=node_feature_dim,
77+
bias=True,
78+
)
79+
80+
# Layer for processing edge features
81+
self.edge_layer = torch.nn.Linear(
82+
in_features=edge_feature_dim,
83+
out_features=node_feature_dim,
84+
bias=True,
85+
)
86+
87+
# Layer for computing the message
88+
self.message_layer = torch.nn.Linear(
89+
in_features=node_feature_dim,
90+
out_features=node_feature_dim,
91+
bias=False,
92+
)
93+
94+
def forward(self, x, edge_index, edge_attr):
95+
"""
96+
Forward pass of the block, triggering the message-passing routine.
97+
98+
:param x: The node features.
99+
:type x: torch.Tensor | LabelTensor
100+
:param torch.Tensor edge_index: The edge indeces.
101+
:param edge_attr: The edge attributes.
102+
:type edge_attr: torch.Tensor | LabelTensor
103+
:return: The updated node features.
104+
:rtype: torch.Tensor
105+
"""
106+
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
107+
108+
def message(self, x_j, edge_attr):
109+
"""
110+
Compute the message to be passed between nodes and edges.
111+
112+
:param x_j: The node features of the sender nodes.
113+
:type x_j: torch.Tensor | LabelTensor
114+
:param edge_attr: The edge attributes.
115+
:type edge_attr: torch.Tensor | LabelTensor
116+
:return: The message to be passed.
117+
:rtype: torch.Tensor
118+
"""
119+
# Process node and edge features
120+
filter_node = self.node_layer(x_j)
121+
filter_edge = self.edge_layer(edge_attr)
122+
123+
# Compute the message to be passed
124+
message = self.message_layer(filter_node * filter_edge)
125+
126+
return self.activation(message)
127+
128+
def update(self, message, x):
129+
"""
130+
Update the node features with the received messages.
131+
132+
:param torch.Tensor message: The message to be passed.
133+
:param x: The node features.
134+
:type x: torch.Tensor | LabelTensor
135+
:return: The updated node features.
136+
:rtype: torch.Tensor
137+
"""
138+
return x + message

0 commit comments

Comments
 (0)