88
99
1010class EnEquivariantNetworkBlock (MessagePassing ):
11- """
12- Implementation of the E(n) Equivariant Graph Neural Network block.
13-
14- This block is used to perform message-passing between nodes and edges in a
15- graph neural network, following the scheme proposed by Satorras et al. in
16- 2021. It serves as an inner block in a larger graph neural network
17- architecture.
18-
19- The message between two nodes connected by an edge is computed by applying a
20- linear transformation to the sender node features and the edge features,
21- together with the squared euclidean distance between the sender and
22- recipient node positions, followed by a non-linear activation function.
23- Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
24- min, max, or product).
25-
26- The update step is performed by applying another MLP to the concatenation of
27- the incoming messages and the node features. Here, also the node
28- positions are updated by adding the incoming messages divided by the
29- degree of the recipient node.
30-
31- .. seealso::
32-
33- **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
34- (2021). *E(n) Equivariant Graph Neural Networks.*
35- In International Conference on Machine Learning.
36- DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
37- """
38-
3911 def __init__ (
4012 self ,
4113 node_feature_dim ,
@@ -49,50 +21,15 @@ def __init__(
4921 node_dim = - 2 ,
5022 flow = "source_to_target" ,
5123 ):
52- """
53- Initialization of the :class:`EnEquivariantNetworkBlock` class.
54-
55- :param int node_feature_dim: The dimension of the node features.
56- :param int edge_feature_dim: The dimension of the edge features.
57- :param int pos_dim: The dimension of the position features.
58- :param int hidden_dim: The dimension of the hidden features.
59- Default is 64.
60- :param int n_message_layers: The number of layers in the message
61- network. Default is 2.
62- :param int n_update_layers: The number of layers in the update network.
63- Default is 2.
64- :param torch.nn.Module activation: The activation function.
65- Default is :class:`torch.nn.SiLU`.
66- :param str aggr: The aggregation scheme to use for message passing.
67- Available options are "add", "mean", "min", "max", "mul".
68- See :class:`torch_geometric.nn.MessagePassing` for more details.
69- Default is "add".
70- :param int node_dim: The axis along which to propagate. Default is -2.
71- :param str flow: The direction of message passing. Available options
72- are "source_to_target" and "target_to_source".
73- The "source_to_target" flow means that messages are sent from
74- the source node to the target node, while the "target_to_source"
75- flow means that messages are sent from the target node to the
76- source node. See :class:`torch_geometric.nn.MessagePassing` for more
77- details. Default is "source_to_target".
78- :raises AssertionError: If `node_feature_dim` is not a positive integer.
79- :raises AssertionError: If `edge_feature_dim` is a negative integer.
80- :raises AssertionError: If `pos_dim` is not a positive integer.
81- :raises AssertionError: If `hidden_dim` is not a positive integer.
82- :raises AssertionError: If `n_message_layers` is not a positive integer.
83- :raises AssertionError: If `n_update_layers` is not a positive integer.
84- """
8524 super ().__init__ (aggr = aggr , node_dim = node_dim , flow = flow )
8625
87- # Check values
8826 check_positive_integer (node_feature_dim , strict = True )
8927 check_positive_integer (edge_feature_dim , strict = False )
9028 check_positive_integer (pos_dim , strict = True )
9129 check_positive_integer (hidden_dim , strict = True )
9230 check_positive_integer (n_message_layers , strict = True )
9331 check_positive_integer (n_update_layers , strict = True )
9432
95- # Layer for computing the message
9633 self .message_net = FeedForward (
9734 input_dimensions = 2 * node_feature_dim + edge_feature_dim + 1 ,
9835 output_dimensions = pos_dim ,
@@ -101,7 +38,6 @@ def __init__(
10138 func = activation ,
10239 )
10340
104- # Layer for updating the node features
10541 self .update_feat_net = FeedForward (
10642 input_dimensions = node_feature_dim + pos_dim ,
10743 output_dimensions = node_feature_dim ,
@@ -110,8 +46,6 @@ def __init__(
11046 func = activation ,
11147 )
11248
113- # Layer for updating the node positions
114- # The output dimension is set to 1 for equivariant updates
11549 self .update_pos_net = FeedForward (
11650 input_dimensions = pos_dim ,
11751 output_dimensions = 1 ,
@@ -120,9 +54,6 @@ def __init__(
12054 func = activation ,
12155 )
12256
123- # Placeholder for the messages
124- self ._m_ij = None
125-
12657 def forward (self , x , pos , edge_index , edge_attr = None ):
12758 """
12859 Forward pass of the block, triggering the message-passing routine.
@@ -158,28 +89,57 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
15889 :return: The message to be passed.
15990 :rtype: torch.Tensor
16091 """
161- # Compute the euclidean distance between the sender and recipient nodes
16292 diff = pos_i - pos_j
16393 dist = torch .norm (diff , dim = - 1 , keepdim = True ) ** 2
16494
165- # Compute the message input
16695 if edge_attr is None :
16796 input_ = torch .cat ((x_i , x_j , dist ), dim = - 1 )
16897 else :
16998 input_ = torch .cat ((x_i , x_j , dist , edge_attr ), dim = - 1 )
17099
171- # Compute the messages and save them for feature update
172- self . _m_ij = self .message_net ( input_ )
100+ m_ij = self . message_net ( input_ ) # message features
101+ message = diff * self .update_pos_net ( m_ij ) # equivariant message
173102
174- # Rescale the message by the euclidean distance
175- return diff * self .update_pos_net (self ._m_ij )
103+ return message , m_ij
176104
177- def update (self , message , x , pos , edge_index ):
def aggregate(self, inputs, index, ptr=None, dim_size=None):
    """
    Reduce the two per-edge message tensors onto their target nodes.

    The message step yields a pair of tensors per edge. Each member of the
    pair is reduced independently with the aggregation scheme configured
    on the parent class.

    :param tuple(torch.Tensor) inputs: Pair of per-edge tensors, in the
        order (messages for the position update, messages for the feature
        update).
    :param torch.Tensor | LabelTensor index: For every message, the index
        of the node it is aggregated into.
    :param torch.Tensor | LabelTensor ptr: Optional tensor forwarded
        unchanged to the parent aggregation. Default is ``None``.
    :param int dim_size: Optional size of the aggregated dimension,
        forwarded unchanged to the parent aggregation. Default is ``None``.
    :return: The aggregated pair, in the same order as ``inputs``.
    :rtype: tuple(torch.Tensor, torch.Tensor)
    """
    # Unpacking also enforces that exactly two tensors were provided.
    pos_messages, feat_messages = inputs

    # Bind the parent reduction once; both tensors share the same scheme.
    reduce = super().aggregate

    reduced_pos = reduce(pos_messages, index, ptr, dim_size)
    reduced_feat = reduce(feat_messages, index, ptr, dim_size)
    return reduced_pos, reduced_feat
136+
def update(self, aggregated_inputs, x, pos, edge_index):
    """
    Update the node features and the node coordinates with the received
    messages.

    :param tuple(torch.Tensor) aggregated_inputs: The aggregated pair
        (messages for the position update, messages for the feature
        update).
    :param x: The node features.
    :type x: torch.Tensor | LabelTensor
    :param pos: The euclidean coordinates of the nodes.
    :type pos: torch.Tensor | LabelTensor
    :param torch.Tensor edge_index: The edge indices. One of its two rows
        holds, for every edge, the node that receives the message.
    :return: The updated node features and node positions.
    :rtype: tuple(torch.Tensor, torch.Tensor)
    """
    agg_message, agg_m_ij = aggregated_inputs

    # Feature update: MLP on the node features joined with the aggregated
    # message features.
    x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))

    # Messages are aggregated at the receiving end of each edge. That is
    # row 1 of edge_index for "source_to_target" and row 0 for
    # "target_to_source", so the degree must be counted on the same row.
    receiver_row = 1 if self.flow == "source_to_target" else 0

    # Count nodes along the propagation axis rather than assuming axis 0.
    num_nodes = pos.size(self.node_dim)

    # Clamp to 1 so that nodes receiving no message do not divide by zero.
    c = degree(edge_index[receiver_row], num_nodes).unsqueeze(-1).clamp(min=1)
    pos = pos + agg_message / c

    return x, pos
0 commit comments