Commit 1525549

global fixes
1 parent f13a5f3 commit 1525549

File tree

7 files changed: +300 −278 lines

pina/model/block/message_passing/__init__.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -3,7 +3,13 @@
 __all__ = [
     "InteractionNetworkBlock",
     "DeepTensorNetworkBlock",
+    "EnEquivariantNetworkBlock",
+    "RadialFieldNetworkBlock",
+    "SchnetBlock",
 ]
 
 from .interaction_network_block import InteractionNetworkBlock
 from .deep_tensor_network_block import DeepTensorNetworkBlock
+from .egnn_block import EnEquivariantNetworkBlock
+from .radial_field_network_block import RadialFieldNetworkBlock
+from .schnet_block import SchnetBlock
```
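
A quick sanity check on the new exports (a minimal sketch, not part of the commit; it assumes `pina` and its `torch_geometric` dependency are installed):

```python
# Minimal sketch: the three newly exported blocks should now be reachable
# from the message_passing subpackage, alongside the existing two.
import pina.model.block.message_passing as mp

assert {"EnEquivariantNetworkBlock", "RadialFieldNetworkBlock",
        "SchnetBlock"}.issubset(mp.__all__)

from pina.model.block.message_passing import (
    EnEquivariantNetworkBlock,
    RadialFieldNetworkBlock,
    SchnetBlock,
)
```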

pina/model/block/message_passing/deep_tensor_network_block.py

Lines changed: 17 additions & 31 deletions
```diff
@@ -2,16 +2,17 @@
 
 import torch
 from torch_geometric.nn import MessagePassing
-from ....utils import check_consistency
+from ....utils import check_positive_integer
 
 
 class DeepTensorNetworkBlock(MessagePassing):
     """
     Implementation of the Deep Tensor Network block.
 
     This block is used to perform message-passing between nodes and edges in a
-    graph neural network, following the scheme proposed by Schutt et al. (2017).
-    It serves as an inner block in a larger graph neural network architecture.
+    graph neural network, following the scheme proposed by Schutt et al. in
+    2017. It serves as an inner block in a larger graph neural network
+    architecture.
 
     The message between two nodes connected by an edge is computed by applying a
     linear transformation to the sender node features and the edge features,
@@ -24,7 +25,7 @@ class DeepTensorNetworkBlock(MessagePassing):
     .. seealso::
 
         **Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al.
-        *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
+        (2017). *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
         Nature Communications 8, 13890 (2017).
         DOI: `<https://doi.org/10.1038/ncomms13890>_`.
     """
@@ -57,51 +58,36 @@ def __init__(
             flow means that messages are sent from the target node to the
             source node. See :class:`torch_geometric.nn.MessagePassing` for more
             details. Default is "source_to_target".
-        :raises ValueError: If `node_feature_dim` is not a positive integer.
-        :raises ValueError: If `edge_feature_dim` is not a positive integer.
+        :raises AssertionError: If `node_feature_dim` is not a positive integer.
+        :raises AssertionError: If `edge_feature_dim` is not a positive integer.
         """
         super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
 
-        # Check consistency
-        check_consistency(node_feature_dim, int)
-        check_consistency(edge_feature_dim, int)
-
         # Check values
-        if node_feature_dim <= 0:
-            raise ValueError(
-                "`node_feature_dim` must be a positive integer,"
-                f" got {node_feature_dim}."
-            )
-
-        if edge_feature_dim <= 0:
-            raise ValueError(
-                "`edge_feature_dim` must be a positive integer,"
-                f" got {edge_feature_dim}."
-            )
-
-        # Initialize parameters
-        self.node_feature_dim = node_feature_dim
-        self.edge_feature_dim = edge_feature_dim
+        check_positive_integer(node_feature_dim, strict=True)
+        check_positive_integer(edge_feature_dim, strict=True)
+
+        # Activation function
         self.activation = activation
 
         # Layer for processing node features
         self.node_layer = torch.nn.Linear(
-            in_features=self.node_feature_dim,
-            out_features=self.node_feature_dim,
+            in_features=node_feature_dim,
+            out_features=node_feature_dim,
             bias=True,
         )
 
         # Layer for processing edge features
         self.edge_layer = torch.nn.Linear(
-            in_features=self.edge_feature_dim,
-            out_features=self.node_feature_dim,
+            in_features=edge_feature_dim,
+            out_features=node_feature_dim,
             bias=True,
         )
 
         # Layer for computing the message
         self.message_layer = torch.nn.Linear(
-            in_features=self.node_feature_dim,
-            out_features=self.node_feature_dim,
+            in_features=node_feature_dim,
+            out_features=node_feature_dim,
             bias=False,
         )
```
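
The validation change here is behavioral, not just cosmetic: invalid dimensions now fail inside `check_positive_integer(..., strict=True)` rather than via explicit `ValueError` raises. A hedged sketch of the difference, assuming `node_feature_dim` and `edge_feature_dim` are the only required constructor arguments (the truncated signature does not show the defaults):

```python
from pina.model.block.message_passing import DeepTensorNetworkBlock

# Valid dimensions: construction works as before.
block = DeepTensorNetworkBlock(node_feature_dim=16, edge_feature_dim=8)

# Invalid dimensions: previously a ValueError with a formatted message;
# after this commit the updated docstring promises an AssertionError.
try:
    DeepTensorNetworkBlock(node_feature_dim=0, edge_feature_dim=8)
except AssertionError:
    print("rejected non-positive node_feature_dim")
```

Note for downstream code: anything catching `ValueError` around these constructors will no longer trap the failure.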

pina/model/block/message_passing/egnn_block.py

Lines changed: 111 additions & 76 deletions
```diff
@@ -3,135 +3,170 @@
 import torch
 from torch_geometric.nn import MessagePassing
 from torch_geometric.utils import degree
+from ....utils import check_positive_integer
+from ....model import FeedForward
 
 
-class EnEquivariantGraphBlock(MessagePassing):
+class EnEquivariantNetworkBlock(MessagePassing):
     """
     Implementation of the E(n) Equivariant Graph Neural Network block.
 
     This block is used to perform message-passing between nodes and edges in a
-    graph neural network, following the scheme proposed by Satorras et al. (2021).
-    It serves as an inner block in a larger graph neural network architecture.
+    graph neural network, following the scheme proposed by Satorras et al. in
+    2021. It serves as an inner block in a larger graph neural network
+    architecture.
 
     The message between two nodes connected by an edge is computed by applying a
     linear transformation to the sender node features and the edge features,
-    followed by a non-linear activation function. Messages are then aggregated
-    using an aggregation scheme (e.g., sum, mean, min, max, or product).
+    together with the squared euclidean distance between the sender and
+    recipient node positions, followed by a non-linear activation function.
+    Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
+    min, max, or product).
 
-    The update step is performed by a simple addition of the incoming messages
-    to the node features.
+    The update step is performed by applying another MLP to the concatenation of
+    the incoming messages and the node features. Here, also the node
+    positions are updated by adding the incoming messages divided by the
+    degree of the recipient node.
 
     .. seealso::
 
-        **Original reference** Satorras, V. G., Hoogeboom, E., & Welling, M. (2021, July).
-        E (n) equivariant graph neural networks.
-        In International conference on machine learning (pp. 9323-9332). PMLR.
+        **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
+        (2021). *E(n) Equivariant Graph Neural Networks.*
+        In International Conference on Machine Learning.
+        DOI: `<https://doi.org/10.48550/arXiv.2102.09844>_`.
     """
 
     def __init__(
         self,
-        channels_x,
-        channels_m,
-        channels_a,
-        aggr: str = "add",
-        hidden_channels: int = 64,
-        **kwargs,
+        node_feature_dim,
+        edge_feature_dim,
+        pos_dim,
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+        activation=torch.nn.SiLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
     ):
         """
-        Initialization of the :class:`EnEquivariantGraphBlock` class.
-
-        :param int channels_x: The dimension of the node features.
-        :param int channels_m: The dimension of the Euclidean coordinates (should be =3).
-        :param int channels_a: The dimension of the edge features.
+        Initialization of the :class:`EnEquivariantNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param int pos_dim: The dimension of the position features.
+        :param int hidden_dim: The dimension of the hidden features.
+            Default is 64.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update network.
+            Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.SiLU`.
         :param str aggr: The aggregation scheme to use for message passing.
             Available options are "add", "mean", "min", "max", "mul".
             See :class:`torch_geometric.nn.MessagePassing` for more details.
             Default is "add".
-        :param int hidden_channels_dim: The hidden dimension in each MLPs initialized in the block.
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for more
+            details. Default is "source_to_target".
+        :raises AssertionError: If `node_feature_dim` is not a positive integer.
+        :raises AssertionError: If `edge_feature_dim` is not a positive integer.
+        :raises AssertionError: If `pos_dim` is not a positive integer.
+        :raises AssertionError: If `hidden_dim` is not a positive integer.
+        :raises AssertionError: If `n_message_layers` is not a positive integer.
+        :raises AssertionError: If `n_update_layers` is not a positive integer.
         """
-        super().__init__(aggr=aggr, **kwargs)
-
-        self.phi_e = torch.nn.Sequential(
-            torch.nn.Linear(2 * channels_x + 1 + channels_a, hidden_channels),
-            torch.nn.LayerNorm(hidden_channels),
-            torch.nn.SiLU(),
-            torch.nn.Linear(hidden_channels, channels_m),
-            torch.nn.LayerNorm(channels_m),
-            torch.nn.SiLU(),
-        )
-        self.phi_pos = torch.nn.Sequential(
-            torch.nn.Linear(channels_m, hidden_channels),
-            torch.nn.LayerNorm(hidden_channels),
-            torch.nn.SiLU(),
-            torch.nn.Linear(hidden_channels, 1),
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check values
+        check_positive_integer(node_feature_dim, strict=True)
+        check_positive_integer(edge_feature_dim, strict=True)
+        check_positive_integer(pos_dim, strict=True)
+        check_positive_integer(hidden_dim, strict=True)
+        check_positive_integer(n_message_layers, strict=True)
+        check_positive_integer(n_update_layers, strict=True)
+
+        # Layer for computing the message
+        self.message_net = FeedForward(
+            input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
+            output_dimensions=pos_dim,
+            inner_size=hidden_dim,
+            n_layers=n_message_layers,
+            func=activation,
         )
-        self.phi_x = torch.nn.Sequential(
-            torch.nn.Linear(channels_x + channels_m, hidden_channels),
-            torch.nn.LayerNorm(hidden_channels),
-            torch.nn.SiLU(),
-            torch.nn.Linear(hidden_channels, channels_x),
+
+        # Layer for updating the node features
+        self.update_net = FeedForward(
+            input_dimensions=node_feature_dim + pos_dim,
+            output_dimensions=node_feature_dim,
+            inner_size=hidden_dim,
+            n_layers=n_update_layers,
+            func=activation,
         )
 
-    def forward(self, x, pos, edge_attr, edge_index, c=None):
+    def forward(self, x, pos, edge_index, edge_attr):
         """
         Forward pass of the block, triggering the message-passing routine.
 
         :param x: The node features.
         :type x: torch.Tensor | LabelTensor
-        :param pos_i: 3D Euclidean coordinates.
-        :type pos_i: torch.Tensor | LabelTensor
-        :param torch.Tensor edge_index: The edge indices. In the original formulation,
-            the messages are aggregated from all nodes, not only from the neighbours.
-        :return: The updated node features.
-        :rtype: torch.Tensor
+        :param pos: The euclidean coordinates of the nodes.
+        :type pos: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :param edge_attr: The edge attributes. Default is None.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The updated node features and node positions.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
-        if c is None:
-            c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
         return self.propagate(
-            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, c=c
+            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
         )
 
     def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
         """
         Compute the message to be passed between nodes and edges.
 
-        :param x_i: Node features of the sender nodes.
+        :param x_i: The node features of the recipient nodes.
         :type x_i: torch.Tensor | LabelTensor
-        :param pos_i: 3D Euclidean coordinates of the sender nodes.
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param pos_i: The node coordinates of the recipient nodes.
         :type pos_i: torch.Tensor | LabelTensor
+        :param pos_j: The node coordinates of the sender nodes.
+        :type pos_j: torch.Tensor | LabelTensor
         :param edge_attr: The edge attributes.
         :type edge_attr: torch.Tensor | LabelTensor
         :return: The message to be passed.
         :rtype: torch.Tensor
         """
-        mpos_ij = self.phi_e(
-            torch.cat(
-                [
-                    x_i,
-                    x_j,
-                    torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2,
-                    edge_attr,
-                ],
-                dim=-1,
-            )
-        )
-        mpos_ij = (pos_i - pos_j) * self.phi_pos(mpos_ij)
-        return mpos_ij
+        dist = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2
+        input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
+        return self.message_net(input_)
 
-    def update(self, message, x, pos, c):
+    def update(self, message, x, pos, edge_index):
         """
-        Update the node features with the received messages.
+        Update the node features and the node coordinates with the received
+        messages.
 
         :param torch.Tensor message: The message to be passed.
         :param x: The node features.
         :type x: torch.Tensor | LabelTensor
-        :param pos: The 3D Euclidean coordinates of the nodes.
+        :param pos: The euclidean coordinates of the nodes.
         :type pos: torch.Tensor | LabelTensor
-        :param c: the constant that divides the aggregated message (it should be (M-1), where M is the number of nodes)
-        :type pos: torch.Tensor
-        :return: The concatenation of the update position features and the updated node features.
-        :rtype: torch.Tensor
+        :param torch.Tensor edge_index: The edge indices.
+        :return: The updated node features and node positions.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
-        x = self.phi_x(torch.cat([x, message], dim=-1))
-        pos = pos + (message / c)
+        # Update the node features
+        x = self.update_net(torch.cat((x, message), dim=-1))
+
+        # Update the node positions
+        c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
+        pos = pos + message / c
         return pos, x
```
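
End to end, the reworked block takes `(x, pos, edge_index, edge_attr)` and returns updated `(pos, x)`. Below is a toy forward pass, offered as a sketch under the assumptions visible in the diff (placeholder dimensions; constructor keywords as shown above):

```python
import torch
from pina.model.block.message_passing import EnEquivariantNetworkBlock

# Toy graph: 4 nodes on a path, edges in both directions so every node
# has nonzero in-degree (degree() divides the position update).
node_feature_dim, edge_feature_dim, pos_dim = 8, 4, 2
x = torch.rand(4, node_feature_dim)
pos = torch.rand(4, pos_dim)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
edge_attr = torch.rand(edge_index.shape[1], edge_feature_dim)

block = EnEquivariantNetworkBlock(
    node_feature_dim=node_feature_dim,
    edge_feature_dim=edge_feature_dim,
    pos_dim=pos_dim,
)

# message(): MLP over [x_i, x_j, ||pos_i - pos_j||^2, edge_attr] -> pos_dim.
# update(): features go through update_net; positions move by the
# aggregated message divided by the recipient degree.
new_pos, new_x = block(x=x, pos=pos, edge_index=edge_index,
                       edge_attr=edge_attr)
assert new_pos.shape == pos.shape and new_x.shape == x.shape
```

One design point worth flagging: the message now lives in `pos_dim` and, unlike the removed `phi_pos` path, is no longer scaled by the relative position `(pos_i - pos_j)`; the same aggregated vector drives both the position shift and, through `update_net`, the feature update.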
