
Commit a3e7f9f
Parent: 026ab19

fix egnn + equivariance/invariance remaining tests

4 files changed: +95 −87 lines

pina/model/block/message_passing/en_equivariant_network_block.py

Lines changed: 45 additions & 85 deletions
@@ -8,34 +8,6 @@
 
 
 class EnEquivariantNetworkBlock(MessagePassing):
-    """
-    Implementation of the E(n) Equivariant Graph Neural Network block.
-
-    This block is used to perform message-passing between nodes and edges in a
-    graph neural network, following the scheme proposed by Satorras et al. in
-    2021. It serves as an inner block in a larger graph neural network
-    architecture.
-
-    The message between two nodes connected by an edge is computed by applying a
-    linear transformation to the sender node features and the edge features,
-    together with the squared euclidean distance between the sender and
-    recipient node positions, followed by a non-linear activation function.
-    Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
-    min, max, or product).
-
-    The update step is performed by applying another MLP to the concatenation of
-    the incoming messages and the node features. Here, also the node
-    positions are updated by adding the incoming messages divided by the
-    degree of the recipient node.
-
-    .. seealso::
-
-        **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
-        (2021). *E(n) Equivariant Graph Neural Networks.*
-        In International Conference on Machine Learning.
-        DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
-    """
-
     def __init__(
         self,
         node_feature_dim,
@@ -49,50 +21,15 @@ def __init__(
         node_dim=-2,
         flow="source_to_target",
     ):
-        """
-        Initialization of the :class:`EnEquivariantNetworkBlock` class.
-
-        :param int node_feature_dim: The dimension of the node features.
-        :param int edge_feature_dim: The dimension of the edge features.
-        :param int pos_dim: The dimension of the position features.
-        :param int hidden_dim: The dimension of the hidden features.
-            Default is 64.
-        :param int n_message_layers: The number of layers in the message
-            network. Default is 2.
-        :param int n_update_layers: The number of layers in the update network.
-            Default is 2.
-        :param torch.nn.Module activation: The activation function.
-            Default is :class:`torch.nn.SiLU`.
-        :param str aggr: The aggregation scheme to use for message passing.
-            Available options are "add", "mean", "min", "max", "mul".
-            See :class:`torch_geometric.nn.MessagePassing` for more details.
-            Default is "add".
-        :param int node_dim: The axis along which to propagate. Default is -2.
-        :param str flow: The direction of message passing. Available options
-            are "source_to_target" and "target_to_source".
-            The "source_to_target" flow means that messages are sent from
-            the source node to the target node, while the "target_to_source"
-            flow means that messages are sent from the target node to the
-            source node. See :class:`torch_geometric.nn.MessagePassing` for more
-            details. Default is "source_to_target".
-        :raises AssertionError: If `node_feature_dim` is not a positive integer.
-        :raises AssertionError: If `edge_feature_dim` is a negative integer.
-        :raises AssertionError: If `pos_dim` is not a positive integer.
-        :raises AssertionError: If `hidden_dim` is not a positive integer.
-        :raises AssertionError: If `n_message_layers` is not a positive integer.
-        :raises AssertionError: If `n_update_layers` is not a positive integer.
-        """
         super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
 
-        # Check values
         check_positive_integer(node_feature_dim, strict=True)
         check_positive_integer(edge_feature_dim, strict=False)
         check_positive_integer(pos_dim, strict=True)
         check_positive_integer(hidden_dim, strict=True)
        check_positive_integer(n_message_layers, strict=True)
         check_positive_integer(n_update_layers, strict=True)
 
-        # Layer for computing the message
         self.message_net = FeedForward(
             input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
             output_dimensions=pos_dim,
@@ -101,7 +38,6 @@ def __init__(
             func=activation,
         )
 
-        # Layer for updating the node features
         self.update_feat_net = FeedForward(
             input_dimensions=node_feature_dim + pos_dim,
             output_dimensions=node_feature_dim,
@@ -110,8 +46,6 @@ def __init__(
             func=activation,
         )
 
-        # Layer for updating the node positions
-        # The output dimension is set to 1 for equivariant updates
         self.update_pos_net = FeedForward(
             input_dimensions=pos_dim,
             output_dimensions=1,
@@ -120,9 +54,6 @@ def __init__(
             func=activation,
         )
 
-        # Placeholder for the messages
-        self._m_ij = None
-
     def forward(self, x, pos, edge_index, edge_attr=None):
         """
         Forward pass of the block, triggering the message-passing routine.
@@ -158,28 +89,57 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
         :return: The message to be passed.
         :rtype: torch.Tensor
         """
-        # Compute the euclidean distance between the sender and recipient nodes
         diff = pos_i - pos_j
         dist = torch.norm(diff, dim=-1, keepdim=True) ** 2
 
-        # Compute the message input
         if edge_attr is None:
             input_ = torch.cat((x_i, x_j, dist), dim=-1)
         else:
             input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
 
-        # Compute the messages and save them for feature update
-        self._m_ij = self.message_net(input_)
+        m_ij = self.message_net(input_)  # message features
+        message = diff * self.update_pos_net(m_ij)  # equivariant message
 
-        # Rescale the message by the euclidean distance
-        return diff * self.update_pos_net(self._m_ij)
+        return message, m_ij
 
-    def update(self, message, x, pos, edge_index):
+    def aggregate(self, inputs, index, ptr=None, dim_size=None):
+        """
+        Aggregate the messages at the nodes during message passing.
+
+        This method receives a tuple of tensors corresponding to the messages
+        to be aggregated. Both messages are aggregated separately according to
+        the specified aggregation scheme.
+
+        :param tuple(torch.Tensor) inputs: Tuple containing two messages to
+            aggregate.
+        :param torch.Tensor | LabelTensor index: The indices of target nodes
+            for each message. This tensor specifies which node each message
+            is aggregated into.
+        :param torch.Tensor | LabelTensor ptr: Optional tensor to specify
+            the slices of messages for each node (used in some aggregation
+            strategies).
+        :param int dim_size: Optional size of the output dimension, i.e.,
+            number of nodes.
+        :return: Tuple of aggregated tensors corresponding to
+            (aggregated messages for position updates, aggregated messages for
+            feature updates).
+        :rtype: tuple(torch.Tensor, torch.Tensor)
+        """
+        # inputs is tuple (message, m_ij), we want to aggregate separately
+        message, m_ij = inputs
+
+        # Aggregate messages as usual using self.aggr method
+        agg_message = super().aggregate(message, index, ptr, dim_size)
+        agg_m_ij = super().aggregate(m_ij, index, ptr, dim_size)
+
+        return agg_message, agg_m_ij
+
+    def update(self, aggregated_inputs, x, pos, edge_index):
         """
         Update the node features and the node coordinates with the received
         messages.
 
-        :param torch.Tensor message: The message to be passed.
+        :param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
         :param x: The node features.
         :type x: torch.Tensor | LabelTensor
         :param pos: The euclidean coordinates of the nodes.
@@ -188,14 +148,14 @@ def update(self, message, x, pos, edge_index):
         :return: The updated node features and node positions.
         :rtype: tuple(torch.Tensor, torch.Tensor)
         """
-        # Sum the incoming messages for each node (m_i = sum_j m_ij)
-        m_sum = torch.zeros(x.size(0), self._m_ij.shape[-1], device=x.device)
-        m_sum.index_add_(0, edge_index[1], self._m_ij)
+        # aggregated_inputs is tuple (agg_message, agg_m_ij)
+        agg_message, agg_m_ij = aggregated_inputs
+
+        # Update node features with aggregated m_ij
+        x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))
 
-        # Update the node features
-        x = self.update_feat_net(torch.cat((x, m_sum), dim=-1))
+        # Degree for normalization of position updates
+        c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
+        pos = pos + agg_message / c
 
-        # Update the node positions
-        c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1)
-        pos = pos + message / c
         return x, pos
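
The substance of the fix: the old implementation cached messages in `self._m_ij` and summed them manually with `index_add_`, which hard-coded a sum for the feature update no matter which `aggr` scheme was configured. Now `message` returns the tuple `(message, m_ij)`, the overridden `aggregate` routes each element through the parent class's aggregation separately, and the degree normalization gains `.clamp(min=1)` so isolated nodes do not divide by zero. A toy sketch of the separate-aggregation idea, using `torch_geometric.utils.scatter` with made-up numbers (not the library's internals):

import torch
from torch_geometric.utils import scatter

# Three edges deliver messages into two recipient nodes (index = recipient).
message = torch.tensor([[1.0], [2.0], [4.0]])  # position-update messages
m_ij = torch.tensor([[10.0], [20.0], [40.0]])  # feature-update messages
index = torch.tensor([0, 0, 1])

# Each tuple element is aggregated on its own, under the same scheme.
agg_message = scatter(message, index, dim=0, reduce="sum")  # [[3.], [4.]]
agg_m_ij = scatter(m_ij, index, dim=0, reduce="sum")        # [[30.], [40.]]

With `aggr="mean"` the same two calls would average instead, which the old `index_add_` path could not do.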

tests/test_messagepassing/test_equivariant_network_block.py

Lines changed: 3 additions & 2 deletions
@@ -153,12 +153,13 @@ def test_equivariance():
         n_update_layers=2,
     ).eval()
 
-    _, pos1 = model(edge_index=edge_index, x=x, pos=pos)
-    _, pos2 = model(
+    h1, pos1 = model(edge_index=edge_index, x=x, pos=pos)
+    h2, pos2 = model(
         edge_index=edge_index, x=x, pos=pos @ rotation.T + translation
     )
 
     # Transform model output
     pos1_transformed = (pos1 @ rotation.T) + translation
 
     assert torch.allclose(pos2, pos1_transformed, atol=1e-5)
+    assert torch.allclose(h1, h2, atol=1e-5)
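
The test previously discarded the feature output; it now asserts both halves of the E(n) property: features are invariant while positions are equivariant. A minimal self-contained sketch of the same check, assuming the block is importable from the module path shown above; the sizes, seed, and tolerance are illustrative:

import torch
from pina.model.block.message_passing.en_equivariant_network_block import (
    EnEquivariantNetworkBlock,  # assumed import path, from the file path above
)

torch.manual_seed(0)
x = torch.rand(5, 8)    # node features
pos = torch.rand(5, 3)  # node positions

# Fully connected, undirected graph, as in the tests
edge_index = torch.combinations(torch.arange(5), r=2).T
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

block = EnEquivariantNetworkBlock(
    node_feature_dim=8, edge_feature_dim=0, pos_dim=3
).eval()

# Random proper rotation (det = +1) and translation
rotation = torch.linalg.qr(torch.rand(3, 3)).Q
if torch.det(rotation) < 0:
    rotation[:, 0] *= -1
translation = torch.rand(1, 3)

h1, pos1 = block(edge_index=edge_index, x=x, pos=pos)
h2, pos2 = block(edge_index=edge_index, x=x, pos=pos @ rotation.T + translation)

assert torch.allclose(h1, h2, atol=1e-5)  # invariant features
assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)  # equivariant positions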

tests/test_messagepassing/test_radial_field_network_block.py

Lines changed: 25 additions & 0 deletions
@@ -65,3 +65,28 @@ def test_backward():
     loss = torch.mean(output_)
     loss.backward()
     assert x.grad.shape == x.shape
+
+
+def test_equivariance():
+
+    # Graph to be fully connected and undirected
+    edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
+    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
+
+    # Random rotation (det(rotation) should be 1)
+    rotation = torch.linalg.qr(torch.rand(x.shape[-1], x.shape[-1])).Q
+    if torch.det(rotation) < 0:
+        rotation[:, 0] *= -1
+
+    # Random translation
+    translation = torch.rand(1, x.shape[-1])
+
+    model = RadialFieldNetworkBlock(node_feature_dim=x.shape[1]).eval()
+
+    pos1 = model(edge_index=edge_index, x=x)
+    pos2 = model(edge_index=edge_index, x=x @ rotation.T + translation)
+
+    # Transform model output
+    pos1_transformed = (pos1 @ rotation.T) + translation
+
+    assert torch.allclose(pos2, pos1_transformed, atol=1e-5)
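
These tests all build their random rotation the same way, and the `det(rotation) should be 1` comment is the crux: the Q factor of a QR decomposition is orthogonal but may be a reflection with determinant -1, and negating one column flips the determinant back to +1. A standalone sketch of that construction (the helper name is hypothetical):

import torch

def random_proper_rotation(dim):
    # Hypothetical helper mirroring the tests' construction: an orthogonal
    # Q may have det = -1; negating one column makes it a proper rotation.
    q = torch.linalg.qr(torch.rand(dim, dim)).Q
    if torch.det(q) < 0:
        q[:, 0] *= -1
    return q

r = random_proper_rotation(3)
assert torch.allclose(r @ r.T, torch.eye(3), atol=1e-5)  # orthogonal
assert torch.det(r) > 0                                  # proper rotation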

tests/test_messagepassing/test_schnet_block.py

Lines changed: 22 additions & 0 deletions
@@ -71,3 +71,25 @@ def test_backward():
     loss = torch.mean(output_)
     loss.backward()
     assert x.grad.shape == x.shape
+
+
+def test_invariance():
+
+    # Graph to be fully connected and undirected
+    edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
+    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
+
+    # Random rotation (det(rotation) should be 1)
+    rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
+    if torch.det(rotation) < 0:
+        rotation[:, 0] *= -1
+
+    # Random translation
+    translation = torch.rand(1, pos.shape[-1])
+
+    model = SchnetBlock(node_feature_dim=x.shape[1]).eval()
+
+    out1 = model(edge_index=edge_index, x=x, pos=pos)
+    out2 = model(edge_index=edge_index, x=x, pos=pos @ rotation.T + translation)
+
+    assert torch.allclose(out1, out2, atol=1e-5)
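
Note the contrast with the two equivariance tests above: here the assertion is plain invariance, `out1 == out2`, with no transform applied to the output. That is the right property for a SchNet-style block, which operates on pairwise distances, and rigid motions preserve distances. A quick sketch of that underlying fact:

import torch

pos = torch.rand(5, 3)
rotation = torch.linalg.qr(torch.rand(3, 3)).Q  # orthogonal, so distance-preserving
translation = torch.rand(1, 3)

d1 = torch.cdist(pos, pos)
d2 = torch.cdist(pos @ rotation.T + translation, pos @ rotation.T + translation)

# Any function of the distance matrix alone is therefore E(n)-invariant.
assert torch.allclose(d1, d2, atol=1e-5)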
