

class EnEquivariantNetworkBlock(MessagePassing):
+    """
+    Implementation of the E(n) Equivariant Graph Neural Network block.
+
+    This block performs message passing between nodes and edges in a graph
+    neural network, following the scheme proposed by Satorras et al. in
+    2021. It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed by
+    passing the sender and recipient node features, the edge features, and
+    the squared Euclidean distance between the sender and recipient node
+    positions through a feed-forward network with a non-linear activation
+    function. Messages are then aggregated using an aggregation scheme
+    (e.g., sum, mean, min, max, or product).
+
+    The update step applies another MLP to the concatenation of the
+    incoming messages and the node features. The node positions are also
+    updated, by adding the incoming messages divided by the degree of the
+    recipient node.
+
+    .. seealso::
+
+        **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
+        (2021). *E(n) Equivariant Graph Neural Networks.*
+        In International Conference on Machine Learning.
+        DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
+    """
+
    def __init__(
        self,
        node_feature_dim,
@@ -21,15 +44,49 @@ def __init__(
        node_dim=-2,
        flow="source_to_target",
    ):
+        """
+        Initialization of the :class:`EnEquivariantNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param int pos_dim: The dimension of the position features.
+        :param int hidden_dim: The dimension of the hidden features.
+            Default is 64.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update
+            network. Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.SiLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source". The
+            "source_to_target" flow means that messages are sent from the
+            source node to the target node, and vice versa for
+            "target_to_source". See :class:`torch_geometric.nn.MessagePassing`
+            for more details. Default is "source_to_target".
+        :raises AssertionError: If `node_feature_dim` is not a positive
+            integer.
+        :raises AssertionError: If `edge_feature_dim` is a negative integer.
+        :raises AssertionError: If `pos_dim` is not a positive integer.
+        :raises AssertionError: If `hidden_dim` is not a positive integer.
+        :raises AssertionError: If `n_message_layers` is not a positive
+            integer.
+        :raises AssertionError: If `n_update_layers` is not a positive
+            integer.
+        """
        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)

+        # Check values
        check_positive_integer(node_feature_dim, strict=True)
        check_positive_integer(edge_feature_dim, strict=False)
        check_positive_integer(pos_dim, strict=True)
        check_positive_integer(hidden_dim, strict=True)
        check_positive_integer(n_message_layers, strict=True)
        check_positive_integer(n_update_layers, strict=True)

+        # Layer for computing the message
        self.message_net = FeedForward(
            input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
            output_dimensions=pos_dim,
@@ -38,6 +95,7 @@ def __init__(
            func=activation,
        )

+        # Layer for updating the node features
        self.update_feat_net = FeedForward(
            input_dimensions=node_feature_dim + pos_dim,
            output_dimensions=node_feature_dim,
@@ -46,6 +104,8 @@ def __init__(
            func=activation,
        )

+        # Layer for updating the node positions
+        # The output dimension is set to 1 for equivariant updates
        self.update_pos_net = FeedForward(
            input_dimensions=pos_dim,
            output_dimensions=1,
@@ -87,18 +147,21 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The message to be passed.
-        :rtype: torch.Tensor
+        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
+        # Squared Euclidean distance between sender and recipient positions
        diff = pos_i - pos_j
        dist = torch.norm(diff, dim=-1, keepdim=True) ** 2

+        # Compute the message input
        if edge_attr is None:
            input_ = torch.cat((x_i, x_j, dist), dim=-1)
        else:
            input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)

-        m_ij = self.message_net(input_)  # message features
-        message = diff * self.update_pos_net(m_ij)  # equivariant message
+        # Compute the messages and their equivariant counterpart
+        m_ij = self.message_net(input_)
+        message = diff * self.update_pos_net(m_ij)

        return message, m_ij

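For reference, a sketch of the EGNN equations that `message_net` (phi_e), `update_pos_net` (phi_x), and `update_feat_net` (phi_h) implement here, adapted from Satorras et al. (2021); the normalization constant C is taken to be the inverse degree of the recipient node, per the class docstring:

\begin{align}
    m_{ij}    &= \phi_e\left(h_i^l,\, h_j^l,\, \|x_i^l - x_j^l\|^2,\, a_{ij}\right) \\
    x_i^{l+1} &= x_i^l + C \sum_{j \neq i} \left(x_i^l - x_j^l\right) \phi_x(m_{ij}) \\
    h_i^{l+1} &= \phi_h\Big(h_i^l,\, \sum_{j \in \mathcal{N}(i)} m_{ij}\Big)
\end{align}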
@@ -112,20 +175,20 @@ def aggregate(self, inputs, index, ptr=None, dim_size=None):

        :param tuple(torch.Tensor) inputs: Tuple containing two messages to
            aggregate.
-        :param torch.Tensor | LabelTensor index: The indices of target nodes
-            for each message. This tensor specifies which node each message
-            is aggregated into.
-        :param torch.Tensor | LabelTensor ptr: Optional tensor to specify
-            the slices of messages for each node (used in some aggregation
-            strategies).
+        :param index: The indices of target nodes for each message. This
+            tensor specifies which node each message is aggregated into.
+        :type index: torch.Tensor | LabelTensor
+        :param ptr: Optional tensor to specify the slices of messages for
+            each node (used in some aggregation strategies). Default is None.
+        :type ptr: torch.Tensor | LabelTensor
        :param int dim_size: Optional size of the output dimension, i.e.,
-            number of nodes.
-        :return: Tuple of aggregated tensors corresponding to
-            (aggregated messages for position updates, aggregated messages for
-            feature updates).
+            number of nodes. Default is None.
+        :return: Tuple of aggregated tensors corresponding to (aggregated
+            messages for position updates, aggregated messages for feature
+            updates).
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
-        # inputs is tuple (message, m_ij), we want to aggregate separately
+        # Unpack the messages from the inputs
        message, m_ij = inputs

        # Aggregate messages as usual using self.aggr method
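To make the block's intended use concrete, a minimal usage sketch follows. The commit shows only `__init__`, `message`, and `aggregate`, so the import path and the forward call below are assumptions for illustration, not part of this diff:

import torch

# Hypothetical import path -- the module location is not shown in the commit.
from pina.model.block import EnEquivariantNetworkBlock

# Block for 3D positions and 16 node features; edge_feature_dim=0 is valid
# since its check is non-strict, and message() accepts edge_attr=None.
block = EnEquivariantNetworkBlock(
    node_feature_dim=16,
    edge_feature_dim=0,
    pos_dim=3,
)

# Toy graph: 4 nodes on a directed cycle.
x = torch.rand(4, 16)                     # node features
pos = torch.rand(4, 3)                    # node positions
edge_index = torch.tensor([[0, 1, 2, 3],  # sender indices
                           [1, 2, 3, 0]])  # recipient indices

# Assumed forward signature; forward() does not appear in this diff.
updated_x, updated_pos = block(x=x, pos=pos, edge_index=edge_index)

Under a rotation or translation of `pos`, only the returned positions transform accordingly; the updated node features are E(n)-invariant.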