 import torch
 from torch_geometric.nn import MessagePassing
 from torch_geometric.utils import degree
+from ....utils import check_positive_integer
+from ....model import FeedForward
 
 
-class EnEquivariantGraphBlock(MessagePassing):
+class EnEquivariantNetworkBlock(MessagePassing):
     """
     Implementation of the E(n) Equivariant Graph Neural Network block.
 
     This block is used to perform message-passing between nodes and edges in a
-    graph neural network, following the scheme proposed by Satorras et al. (2021).
-    It serves as an inner block in a larger graph neural network architecture.
+    graph neural network, following the scheme proposed by Satorras et al. in
+    2021. It serves as an inner block in a larger graph neural network
+    architecture.
 
     The message between two nodes connected by an edge is computed by applying a
     linear transformation to the sender node features and the edge features,
-    followed by a non-linear activation function. Messages are then aggregated
-    using an aggregation scheme (e.g., sum, mean, min, max, or product).
+    together with the squared Euclidean distance between the sender and
+    recipient node positions, followed by a non-linear activation function.
+    Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
+    min, max, or product).
 
-    The update step is performed by a simple addition of the incoming messages
-    to the node features.
+    The update step is performed by applying another MLP to the concatenation
+    of the incoming messages and the node features. The node positions are
+    also updated by adding the incoming messages divided by the degree of the
+    recipient node.
 
     .. seealso::
 
-        **Original reference** Satorras, V. G., Hoogeboom, E., & Welling, M. (2021, July).
-        E (n) equivariant graph neural networks.
-        In International conference on machine learning (pp. 9323-9332). PMLR.
+        **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
+        (2021). *E(n) Equivariant Graph Neural Networks.*
+        In International Conference on Machine Learning.
+        DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_
     """
 
     def __init__(
         self,
-        channels_x,
-        channels_m,
-        channels_a,
-        aggr: str = "add",
-        hidden_channels: int = 64,
-        **kwargs,
+        node_feature_dim,
+        edge_feature_dim,
+        pos_dim,
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+        activation=torch.nn.SiLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
     ):
         """
-        Initialization of the :class:`EnEquivariantGraphBlock` class.
-
-        :param int channels_x: The dimension of the node features.
-        :param int channels_m: The dimension of the Euclidean coordinates (should be =3).
-        :param int channels_a: The dimension of the edge features.
+        Initialization of the :class:`EnEquivariantNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param int pos_dim: The dimension of the position features.
+        :param int hidden_dim: The dimension of the hidden features.
+            Default is 64.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update network.
+            Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.SiLU`.
         :param str aggr: The aggregation scheme to use for message passing.
             Available options are "add", "mean", "min", "max", "mul".
             See :class:`torch_geometric.nn.MessagePassing` for more details.
             Default is "add".
-        :param int hidden_channels_dim: The hidden dimension in each MLPs initialized in the block.
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source". With
+            "source_to_target", messages are sent from the source node to
+            the target node; with "target_to_source", the direction is
+            reversed. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises AssertionError: If `node_feature_dim` is not a positive integer.
+        :raises AssertionError: If `edge_feature_dim` is not a positive integer.
+        :raises AssertionError: If `pos_dim` is not a positive integer.
+        :raises AssertionError: If `hidden_dim` is not a positive integer.
+        :raises AssertionError: If `n_message_layers` is not a positive integer.
+        :raises AssertionError: If `n_update_layers` is not a positive integer.
         """
-        super().__init__(aggr=aggr, **kwargs)
-
-        self.phi_e = torch.nn.Sequential(
-            torch.nn.Linear(2 * channels_x + 1 + channels_a, hidden_channels),
-            torch.nn.LayerNorm(hidden_channels),
-            torch.nn.SiLU(),
-            torch.nn.Linear(hidden_channels, channels_m),
-            torch.nn.LayerNorm(channels_m),
-            torch.nn.SiLU(),
-        )
-        self.phi_pos = torch.nn.Sequential(
-            torch.nn.Linear(channels_m, hidden_channels),
-            torch.nn.LayerNorm(hidden_channels),
-            torch.nn.SiLU(),
-            torch.nn.Linear(hidden_channels, 1),
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check values
+        check_positive_integer(node_feature_dim, strict=True)
+        check_positive_integer(edge_feature_dim, strict=True)
+        check_positive_integer(pos_dim, strict=True)
+        check_positive_integer(hidden_dim, strict=True)
+        check_positive_integer(n_message_layers, strict=True)
+        check_positive_integer(n_update_layers, strict=True)
+
+        # Layer for computing the message
+        self.message_net = FeedForward(
+            input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
+            output_dimensions=pos_dim,
+            inner_size=hidden_dim,
+            n_layers=n_message_layers,
+            func=activation,
         )
-        self.phi_x = torch.nn.Sequential(
-            torch.nn.Linear(channels_x + channels_m, hidden_channels),
-            torch.nn.LayerNorm(hidden_channels),
-            torch.nn.SiLU(),
-            torch.nn.Linear(hidden_channels, channels_x),
+
+        # Layer for updating the node features
+        self.update_net = FeedForward(
+            input_dimensions=node_feature_dim + pos_dim,
+            output_dimensions=node_feature_dim,
+            inner_size=hidden_dim,
+            n_layers=n_update_layers,
+            func=activation,
         )
 
-    def forward(self, x, pos, edge_attr, edge_index, c=None):
+    def forward(self, x, pos, edge_index, edge_attr):
         """
         Forward pass of the block, triggering the message-passing routine.
 
         :param x: The node features.
         :type x: torch.Tensor | LabelTensor
-        :param pos_i: 3D Euclidean coordinates.
-        :type pos_i: torch.Tensor | LabelTensor
-        :param torch.Tensor edge_index: The edge indices. In the original formulation,
-            the messages are aggregated from all nodes, not only from the neighbours.
-        :return: The updated node features.
-        :rtype: torch.Tensor
+        :param pos: The Euclidean coordinates of the nodes.
+        :type pos: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :param edge_attr: The edge attributes.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The updated node positions and node features.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
-        if c is None:
-            c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
         return self.propagate(
-            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, c=c
+            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
         )
 
     def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
         """
         Compute the message to be passed between nodes and edges.
 
-        :param x_i: Node features of the sender nodes.
+        :param x_i: The node features of the recipient nodes.
         :type x_i: torch.Tensor | LabelTensor
-        :param pos_i: 3D Euclidean coordinates of the sender nodes.
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param pos_i: The node coordinates of the recipient nodes.
         :type pos_i: torch.Tensor | LabelTensor
+        :param pos_j: The node coordinates of the sender nodes.
+        :type pos_j: torch.Tensor | LabelTensor
         :param edge_attr: The edge attributes.
         :type edge_attr: torch.Tensor | LabelTensor
         :return: The message to be passed.
         :rtype: torch.Tensor
         """
-        mpos_ij = self.phi_e(
-            torch.cat(
-                [
-                    x_i,
-                    x_j,
-                    torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2,
-                    edge_attr,
-                ],
-                dim=-1,
-            )
-        )
-        mpos_ij = (pos_i - pos_j) * self.phi_pos(mpos_ij)
-        return mpos_ij
+        dist = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2
+        input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
+        return self.message_net(input_)
 
-    def update(self, message, x, pos, c):
+    def update(self, message, x, pos, edge_index):
         """
-        Update the node features with the received messages.
+        Update the node features and the node coordinates with the received
+        messages.
 
         :param torch.Tensor message: The message to be passed.
         :param x: The node features.
         :type x: torch.Tensor | LabelTensor
-        :param pos: The 3D Euclidean coordinates of the nodes.
+        :param pos: The Euclidean coordinates of the nodes.
         :type pos: torch.Tensor | LabelTensor
-        :param c: the constant that divides the aggregated message (it should be (M-1), where M is the number of nodes)
-        :type pos: torch.Tensor
-        :return: The concatenation of the update position features and the updated node features.
-        :rtype: torch.Tensor
+        :param torch.Tensor edge_index: The edge indices.
+        :return: The updated node positions and node features.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
-        x = self.phi_x(torch.cat([x, message], dim=-1))
-        pos = pos + (message / c)
+        # Update the node features
+        x = self.update_net(torch.cat((x, message), dim=-1))
+
+        # Update the node positions
+        c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
+        pos = pos + message / c
         return pos, x
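For reference, the message-passing scheme of Satorras et al. (2021), which the docstring cites, reads as follows for layer $\ell$ with node features $\mathbf{h}$, positions $\mathbf{x}$, and edge attributes $a_{ij}$:

```latex
\begin{aligned}
\mathbf{m}_{ij} &= \phi_e\left(\mathbf{h}_i^{\ell},\, \mathbf{h}_j^{\ell},\, \lVert \mathbf{x}_i^{\ell} - \mathbf{x}_j^{\ell} \rVert^2,\, a_{ij}\right), \\
\mathbf{m}_i &= \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}, \\
\mathbf{x}_i^{\ell+1} &= \mathbf{x}_i^{\ell} + C \sum_{j \neq i} \left(\mathbf{x}_i^{\ell} - \mathbf{x}_j^{\ell}\right) \phi_x\!\left(\mathbf{m}_{ij}\right), \\
\mathbf{h}_i^{\ell+1} &= \phi_h\left(\mathbf{h}_i^{\ell},\, \mathbf{m}_i\right).
\end{aligned}
```

In this block, `message_net` plays the role of $\phi_e$ and `update_net` the role of $\phi_h$; for the coordinates, the block adds the aggregated message divided by the recipient node's degree, with the degree standing in for the constant $C$.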
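Below is a minimal usage sketch of the new block. The import path is hypothetical (the relative imports in the diff only show that the module sits inside a nested package); the tensor shapes and the final translation check follow from the code above, which feeds only node features, edge features, and the squared distance into `message_net`.

```python
import torch

# Hypothetical import path -- adjust to wherever EnEquivariantNetworkBlock
# actually lives in your package.
from mypackage.model.block import EnEquivariantNetworkBlock

# Toy graph: 4 nodes with 8 features each, 3D positions, and 5 directed
# edges carrying 2 edge features each.
x = torch.randn(4, 8)
pos = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 0], [1, 2, 3, 0, 2]])
edge_attr = torch.randn(edge_index.shape[1], 2)

block = EnEquivariantNetworkBlock(
    node_feature_dim=8,
    edge_feature_dim=2,
    pos_dim=3,
    hidden_dim=64,
    n_message_layers=2,
    n_update_layers=2,
)

# The block returns the updated positions first, then the node features.
new_pos, new_x = block(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr)
assert new_pos.shape == (4, 3) and new_x.shape == (4, 8)

# Sanity check: messages depend on positions only through the squared
# distance, so translating every node should shift the output positions by
# the same offset and leave the node features unchanged.
shift = torch.randn(1, 3)
shifted_pos, shifted_x = block(
    x=x, pos=pos + shift, edge_index=edge_index, edge_attr=edge_attr
)
assert torch.allclose(shifted_pos, new_pos + shift, atol=1e-4)
assert torch.allclose(shifted_x, new_x, atol=1e-4)
```

Note that every node in the toy graph appears at least once as a source: `degree(edge_index[0], ...)` is zero for isolated nodes, which would make the position update divide by zero.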