1010class EnEquivariantNetworkBlock (MessagePassing ):
1111 """
1212 Implementation of the E(n) Equivariant Graph Neural Network block.
13-
1413 This block is used to perform message-passing between nodes and edges in a
1514 graph neural network, following the scheme proposed by Satorras et al. in
1615 2021. It serves as an inner block in a larger graph neural network
@@ -102,14 +101,24 @@ def __init__(
102101 )
103102
104103 # Layer for updating the node features
105- self .update_net = FeedForward (
104+ self .update_feat_net = FeedForward (
106105 input_dimensions = node_feature_dim + pos_dim ,
107106 output_dimensions = node_feature_dim ,
108107 inner_size = hidden_dim ,
109108 n_layers = n_update_layers ,
110109 func = activation ,
111110 )
112111
112+ # Layer for updating the node positions
113+ # The output dimension is set to 1 for equivariant updates
114+ self .update_pos_net = FeedForward (
115+ input_dimensions = pos_dim ,
116+ output_dimensions = 1 ,
117+ inner_size = hidden_dim ,
118+ n_layers = n_update_layers ,
119+ func = activation ,
120+ )
121+
113122 def forward (self , x , pos , edge_index , edge_attr = None ):
114123 """
115124 Forward pass of the block, triggering the message-passing routine.
@@ -143,22 +152,62 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
143152 :param edge_attr: The edge attributes.
144153 :type edge_attr: torch.Tensor | LabelTensor
145154 :return: The message to be passed.
146- :rtype: torch.Tensor
155+ :rtype: tuple( torch.Tensor, torch.Tensor)
147156 """
148- dist = torch .norm (pos_i - pos_j , dim = - 1 , keepdim = True ) ** 2
157+ # Compute the euclidean distance between the sender and recipient nodes
158+ diff = pos_i - pos_j
159+ dist = torch .norm (diff , dim = - 1 , keepdim = True ) ** 2
160+
161+ # Compute the message input
149162 if edge_attr is None :
150163 input_ = torch .cat ((x_i , x_j , dist ), dim = - 1 )
151164 else :
152165 input_ = torch .cat ((x_i , x_j , dist , edge_attr ), dim = - 1 )
153166
154- return self .message_net (input_ )
167+ # Compute the messages and their equivariant counterpart
168+ m_ij = self .message_net (input_ )
169+ message = diff * self .update_pos_net (m_ij )
170+
171+ return message , m_ij
155172
156- def update (self , message , x , pos , edge_index ):
173+ def aggregate (self , inputs , index , ptr = None , dim_size = None ):
174+ """
175+ Aggregate the messages at the nodes during message passing.
176+
177+ This method receives a tuple of tensors corresponding to the messages
178+ to be aggregated. Both messages are aggregated separately according to
179+ the specified aggregation scheme.
180+
181+ :param tuple(torch.Tensor) inputs: Tuple containing two messages to
182+ aggregate.
183+ :param index: The indices of target nodes for each message. This tensor
184+ specifies which node each message is aggregated into.
185+ :type index: torch.Tensor | LabelTensor
186+ :param ptr: Optional tensor to specify the slices of messages for each
187+ node (used in some aggregation strategies). Default is None.
188+ :type ptr: torch.Tensor | LabelTensor
189+ :param int dim_size: Optional size of the output dimension, i.e.,
190+ number of nodes. Default is None.
191+ :return: Tuple of aggregated tensors corresponding to (aggregated
192+ messages for position updates, aggregated messages for feature
193+ updates).
194+ :rtype: tuple(torch.Tensor, torch.Tensor)
195+ """
196+ # Unpack the messages from the inputs
197+ message , m_ij = inputs
198+
199+ # Aggregate messages as usual using self.aggr method
200+ agg_message = super ().aggregate (message , index , ptr , dim_size )
201+ agg_m_ij = super ().aggregate (m_ij , index , ptr , dim_size )
202+
203+ return agg_message , agg_m_ij
204+
205+ def update (self , aggregated_inputs , x , pos , edge_index ):
157206 """
158207 Update the node features and the node coordinates with the received
159208 messages.
160209
161- :param torch.Tensor message : The message to be passed.
210+ :param tuple( torch.Tensor) aggregated_inputs : The messages to be passed.
162211 :param x: The node features.
163212 :type x: torch.Tensor | LabelTensor
164213 :param pos: The euclidean coordinates of the nodes.
@@ -167,10 +216,14 @@ def update(self, message, x, pos, edge_index):
167216 :return: The updated node features and node positions.
168217 :rtype: tuple(torch.Tensor, torch.Tensor)
169218 """
170- # Update the node features
171- x = self .update_net (torch .cat ((x , message ), dim = - 1 ))
219+ # aggregated_inputs is tuple (agg_message, agg_m_ij)
220+ agg_message , agg_m_ij = aggregated_inputs
221+
222+ # Update node features with aggregated m_ij
223+ x = self .update_feat_net (torch .cat ((x , agg_m_ij ), dim = - 1 ))
224+
225+ # Degree for normalization of position updates
226+ c = degree (edge_index [1 ], pos .shape [0 ]).unsqueeze (- 1 ).clamp (min = 1 )
227+ pos = pos + agg_message / c
172228
173- # Update the node positions
174- c = degree (edge_index [0 ], pos .shape [0 ]).unsqueeze (- 1 )
175- pos = pos + message / c
176229 return x , pos
0 commit comments