diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index ba059ddbc..957eb6e17 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -122,6 +122,18 @@ Blocks Continuous Convolution Block Orthogonal Block +Message Passing +------------------- + +.. toctree:: + :titlesonly: + + Deep Tensor Network Block + E(n) Equivariant Network Block + Interaction Network Block + Radial Field Network Block + Schnet Block + Reduction and Embeddings -------------------------- diff --git a/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst b/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst new file mode 100644 index 000000000..30121e5a6 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst @@ -0,0 +1,8 @@ +Deep Tensor Network Block +================================== +.. currentmodule:: pina.model.block.message_passing.deep_tensor_network_block + +.. autoclass:: DeepTensorNetworkBlock + :members: + :show-inheritance: + :noindex: diff --git a/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst b/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst new file mode 100644 index 000000000..e2755c665 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst @@ -0,0 +1,8 @@ +E(n) Equivariant Network Block +================================== +.. currentmodule:: pina.model.block.message_passing.en_equivariant_network_block + +.. autoclass:: EnEquivariantNetworkBlock + :members: + :show-inheritance: + :noindex: \ No newline at end of file diff --git a/docs/source/_rst/model/block/message_passing/interaction_network_block.rst b/docs/source/_rst/model/block/message_passing/interaction_network_block.rst new file mode 100644 index 000000000..ffac307e2 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/interaction_network_block.rst @@ -0,0 +1,8 @@ +Interaction Network Block +================================== +.. currentmodule:: pina.model.block.message_passing.interaction_network_block + +.. autoclass:: InteractionNetworkBlock + :members: + :show-inheritance: + :noindex: \ No newline at end of file diff --git a/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst b/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst new file mode 100644 index 000000000..e05203f33 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst @@ -0,0 +1,8 @@ +Radial Field Network Block +================================== +.. currentmodule:: pina.model.block.message_passing.radial_field_network_block + +.. autoclass:: RadialFieldNetworkBlock + :members: + :show-inheritance: + :noindex: \ No newline at end of file diff --git a/docs/source/_rst/model/block/message_passing/schnet_block.rst b/docs/source/_rst/model/block/message_passing/schnet_block.rst new file mode 100644 index 000000000..c5baa2730 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/schnet_block.rst @@ -0,0 +1,8 @@ +Schnet Block +================================== +.. currentmodule:: pina.model.block.message_passing.schnet_block + +.. 
autoclass:: SchnetBlock
+   :members:
+   :show-inheritance:
+   :noindex:
\ No newline at end of file
diff --git a/pina/model/block/message_passing/__init__.py b/pina/model/block/message_passing/__init__.py
new file mode 100644
index 000000000..4eed0a611
--- /dev/null
+++ b/pina/model/block/message_passing/__init__.py
@@ -0,0 +1,15 @@
+"""Module for the message passing blocks of the graph neural models."""
+
+__all__ = [
+    "InteractionNetworkBlock",
+    "DeepTensorNetworkBlock",
+    "EnEquivariantNetworkBlock",
+    "RadialFieldNetworkBlock",
+    "SchnetBlock",
+]
+
+from .interaction_network_block import InteractionNetworkBlock
+from .deep_tensor_network_block import DeepTensorNetworkBlock
+from .en_equivariant_network_block import EnEquivariantNetworkBlock
+from .radial_field_network_block import RadialFieldNetworkBlock
+from .schnet_block import SchnetBlock
diff --git a/pina/model/block/message_passing/deep_tensor_network_block.py b/pina/model/block/message_passing/deep_tensor_network_block.py
new file mode 100644
index 000000000..a2de3097a
--- /dev/null
+++ b/pina/model/block/message_passing/deep_tensor_network_block.py
@@ -0,0 +1,138 @@
+"""Module for the Deep Tensor Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from ....utils import check_positive_integer
+
+
+class DeepTensorNetworkBlock(MessagePassing):
+    """
+    Implementation of the Deep Tensor Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Schütt et al. in
+    2017. It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed by applying
+    linear transformations to the sender node features and to the edge
+    features, taking their element-wise product, applying a further linear
+    layer, and finally a non-linear activation function. Messages are then
+    aggregated using an aggregation scheme (e.g., sum, mean, min, max, or
+    product).
+
+    The update step is performed by a simple addition of the incoming messages
+    to the node features.
+
+    .. seealso::
+
+        **Original reference**: Schütt, K. T., Arbabzadah, F., Chmiela, S.
+        et al. (2017). *Quantum-Chemical Insights from Deep Tensor Neural
+        Networks*. Nature Communications 8, 13890.
+        DOI: `<https://doi.org/10.1038/ncomms13890>`_.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        edge_feature_dim,
+        activation=torch.nn.Tanh,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`DeepTensorNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.Tanh`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises AssertionError: If `node_feature_dim` is not a positive
+            integer.
+        :raises AssertionError: If `edge_feature_dim` is not a positive
+            integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check values
+        check_positive_integer(node_feature_dim, strict=True)
+        check_positive_integer(edge_feature_dim, strict=True)
+
+        # Activation function
+        self.activation = activation()
+
+        # Layer for processing node features
+        self.node_layer = torch.nn.Linear(
+            in_features=node_feature_dim,
+            out_features=node_feature_dim,
+            bias=True,
+        )
+
+        # Layer for processing edge features
+        self.edge_layer = torch.nn.Linear(
+            in_features=edge_feature_dim,
+            out_features=node_feature_dim,
+            bias=True,
+        )
+
+        # Layer for computing the message
+        self.message_layer = torch.nn.Linear(
+            in_features=node_feature_dim,
+            out_features=node_feature_dim,
+            bias=False,
+        )
+
+    def forward(self, x, edge_index, edge_attr):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :param edge_attr: The edge attributes.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
+
+    def message(self, x_j, edge_attr):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param edge_attr: The edge attributes.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        # Process node and edge features
+        filter_node = self.node_layer(x_j)
+        filter_edge = self.edge_layer(edge_attr)
+
+        # Compute the message to be passed
+        message = self.message_layer(filter_node * filter_edge)
+
+        return self.activation(message)
+
+    def update(self, message, x):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The message to be passed.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return x + message
diff --git a/pina/model/block/message_passing/en_equivariant_network_block.py b/pina/model/block/message_passing/en_equivariant_network_block.py
new file mode 100644
index 000000000..904c1c6c9
--- /dev/null
+++ b/pina/model/block/message_passing/en_equivariant_network_block.py
@@ -0,0 +1,229 @@
+"""Module for the E(n) Equivariant Graph Neural Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from torch_geometric.utils import degree
+from ....utils import check_positive_integer
+from ....model import FeedForward
+
+
+class EnEquivariantNetworkBlock(MessagePassing):
+    """
+    Implementation of the E(n) Equivariant Graph Neural Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Satorras et al. in
+    2021. It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed by applying
+    a multi-layer perceptron (MLP) to the concatenation of the sender and
+    recipient node features, the edge features, and the squared euclidean
+    distance between the sender and recipient node positions. Messages are
+    then aggregated using an aggregation scheme (e.g., sum, mean, min, max,
+    or product).
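+
+    In formulas, the scheme implemented by this block can be sketched as
+    follows (a sketch of the computation performed in the methods below, with
+    :math:`\phi_e`, :math:`\phi_x`, and :math:`\phi_h` denoting the message,
+    position-update, and feature-update networks, and :math:`a_{ij}` the edge
+    features):
+
+    .. math::
+        m_{ij} = \phi_e\left(h_i, h_j, \|x_i - x_j\|^2, a_{ij}\right),
+        \quad
+        x_i' = x_i + \frac{1}{\deg(i)} \sum_{j} (x_i - x_j)\,\phi_x(m_{ij}),
+        \quad
+        h_i' = \phi_h\left(h_i, \sum_{j} m_{ij}\right).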
+
+    The update step is performed by applying another MLP to the concatenation
+    of the incoming messages and the node features. The node positions are
+    also updated, by adding the aggregated positional messages divided by the
+    degree of the recipient node.
+
+    .. seealso::
+
+        **Original reference**: Satorras, V. G., Hoogeboom, E., Welling, M.
+        (2021). *E(n) Equivariant Graph Neural Networks.*
+        In International Conference on Machine Learning.
+        DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        edge_feature_dim,
+        pos_dim,
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+        activation=torch.nn.SiLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`EnEquivariantNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param int pos_dim: The dimension of the position features.
+        :param int hidden_dim: The dimension of the hidden features.
+            Default is 64.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update network.
+            Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.SiLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises AssertionError: If `node_feature_dim` is not a positive
+            integer.
+        :raises AssertionError: If `edge_feature_dim` is not a non-negative
+            integer.
+        :raises AssertionError: If `pos_dim` is not a positive integer.
+        :raises AssertionError: If `hidden_dim` is not a positive integer.
+        :raises AssertionError: If `n_message_layers` is not a positive
+            integer.
+        :raises AssertionError: If `n_update_layers` is not a positive
+            integer.
+ """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check values + check_positive_integer(node_feature_dim, strict=True) + check_positive_integer(edge_feature_dim, strict=False) + check_positive_integer(pos_dim, strict=True) + check_positive_integer(hidden_dim, strict=True) + check_positive_integer(n_message_layers, strict=True) + check_positive_integer(n_update_layers, strict=True) + + # Layer for computing the message + self.message_net = FeedForward( + input_dimensions=2 * node_feature_dim + edge_feature_dim + 1, + output_dimensions=pos_dim, + inner_size=hidden_dim, + n_layers=n_message_layers, + func=activation, + ) + + # Layer for updating the node features + self.update_feat_net = FeedForward( + input_dimensions=node_feature_dim + pos_dim, + output_dimensions=node_feature_dim, + inner_size=hidden_dim, + n_layers=n_update_layers, + func=activation, + ) + + # Layer for updating the node positions + # The output dimension is set to 1 for equivariant updates + self.update_pos_net = FeedForward( + input_dimensions=pos_dim, + output_dimensions=1, + inner_size=hidden_dim, + n_layers=n_update_layers, + func=activation, + ) + + def forward(self, x, pos, edge_index, edge_attr=None): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param pos: The euclidean coordinates of the nodes. + :type pos: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indices. + :param edge_attr: The edge attributes. Default is None. + :type edge_attr: torch.Tensor | LabelTensor + :return: The updated node features and node positions. + :rtype: tuple(torch.Tensor, torch.Tensor) + """ + return self.propagate( + edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr + ) + + def message(self, x_i, x_j, pos_i, pos_j, edge_attr): + """ + Compute the message to be passed between nodes and edges. + + :param x_i: The node features of the recipient nodes. + :type x_i: torch.Tensor | LabelTensor + :param x_j: The node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :param pos_i: The node coordinates of the recipient nodes. + :type pos_i: torch.Tensor | LabelTensor + :param pos_j: The node coordinates of the sender nodes. + :type pos_j: torch.Tensor | LabelTensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: tuple(torch.Tensor, torch.Tensor) + """ + # Compute the euclidean distance between the sender and recipient nodes + diff = pos_i - pos_j + dist = torch.norm(diff, dim=-1, keepdim=True) ** 2 + + # Compute the message input + if edge_attr is None: + input_ = torch.cat((x_i, x_j, dist), dim=-1) + else: + input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1) + + # Compute the messages and their equivariant counterpart + m_ij = self.message_net(input_) + message = diff * self.update_pos_net(m_ij) + + return message, m_ij + + def aggregate(self, inputs, index, ptr=None, dim_size=None): + """ + Aggregate the messages at the nodes during message passing. + + This method receives a tuple of tensors corresponding to the messages + to be aggregated. Both messages are aggregated separately according to + the specified aggregation scheme. + + :param tuple(torch.Tensor) inputs: Tuple containing two messages to + aggregate. + :param index: The indices of target nodes for each message. This tensor + specifies which node each message is aggregated into. 
+        :type index: torch.Tensor | LabelTensor
+        :param ptr: Optional tensor to specify the slices of messages for each
+            node (used in some aggregation strategies). Default is None.
+        :type ptr: torch.Tensor | LabelTensor
+        :param int dim_size: Optional size of the output dimension, i.e.,
+            number of nodes. Default is None.
+        :return: Tuple of aggregated tensors corresponding to (aggregated
+            messages for position updates, aggregated messages for feature
+            updates).
+        :rtype: tuple(torch.Tensor, torch.Tensor)
+        """
+        # Unpack the messages from the inputs
+        message, m_ij = inputs
+
+        # Aggregate both messages according to the self.aggr scheme
+        agg_message = super().aggregate(message, index, ptr, dim_size)
+        agg_m_ij = super().aggregate(m_ij, index, ptr, dim_size)
+
+        return agg_message, agg_m_ij
+
+    def update(self, aggregated_inputs, x, pos, edge_index):
+        """
+        Update the node features and the node coordinates with the received
+        messages.
+
+        :param tuple(torch.Tensor) aggregated_inputs: The aggregated messages
+            for the position and feature updates.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param pos: The euclidean coordinates of the nodes.
+        :type pos: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :return: The updated node features and node positions.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
+        """
+        # aggregated_inputs is the tuple (agg_message, agg_m_ij)
+        agg_message, agg_m_ij = aggregated_inputs
+
+        # Update node features with the aggregated m_ij
+        x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))
+
+        # Degree for normalization of position updates
+        c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
+        pos = pos + agg_message / c
+
+        return x, pos
diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/model/block/message_passing/interaction_network_block.py
new file mode 100644
index 000000000..7c6eb03f6
--- /dev/null
+++ b/pina/model/block/message_passing/interaction_network_block.py
@@ -0,0 +1,149 @@
+"""Module for the Interaction Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from ....utils import check_positive_integer
+from ....model import FeedForward
+
+
+class InteractionNetworkBlock(MessagePassing):
+    """
+    Implementation of the Interaction Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Battaglia et al. in
+    2016. It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed by applying
+    a multi-layer perceptron (MLP) to the concatenation of the sender and
+    recipient node features and, if provided, the edge features. Messages are
+    then aggregated using an aggregation scheme (e.g., sum, mean, min, max, or
+    product).
+
+    The update step is performed by applying another MLP to the concatenation
+    of the incoming messages and the node features.
+
+    .. seealso::
+
+        **Original reference**: Battaglia, P. W., et al. (2016).
+        *Interaction Networks for Learning about Objects, Relations and
+        Physics*.
+        In Advances in Neural Information Processing Systems (NeurIPS 2016).
+        DOI: `<https://doi.org/10.48550/arXiv.1612.00222>`_.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        edge_feature_dim=0,
+        hidden_dim=64,
+        n_message_layers=2,
+        n_update_layers=2,
+        activation=torch.nn.SiLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`InteractionNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+            If `edge_attr` is not provided, this is assumed to be 0.
+            Default is 0.
+        :param int hidden_dim: The dimension of the hidden features.
+            Default is 64.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update network.
+            Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.SiLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises AssertionError: If `node_feature_dim` is not a positive
+            integer.
+        :raises AssertionError: If `hidden_dim` is not a positive integer.
+        :raises AssertionError: If `n_message_layers` is not a positive
+            integer.
+        :raises AssertionError: If `n_update_layers` is not a positive
+            integer.
+        :raises AssertionError: If `edge_feature_dim` is not a non-negative
+            integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check values
+        check_positive_integer(node_feature_dim, strict=True)
+        check_positive_integer(hidden_dim, strict=True)
+        check_positive_integer(n_message_layers, strict=True)
+        check_positive_integer(n_update_layers, strict=True)
+        check_positive_integer(edge_feature_dim, strict=False)
+
+        # Message network
+        self.message_net = FeedForward(
+            input_dimensions=2 * node_feature_dim + edge_feature_dim,
+            output_dimensions=hidden_dim,
+            inner_size=hidden_dim,
+            n_layers=n_message_layers,
+            func=activation,
+        )
+
+        # Update network
+        self.update_net = FeedForward(
+            input_dimensions=node_feature_dim + hidden_dim,
+            output_dimensions=node_feature_dim,
+            inner_size=hidden_dim,
+            n_layers=n_update_layers,
+            func=activation,
+        )
+
+    def forward(self, x, edge_index, edge_attr=None):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :param edge_attr: The edge attributes. Default is None.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
+
+    def message(self, x_i, x_j, edge_attr):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_i: The node features of the recipient nodes.
+        :type x_i: torch.Tensor | LabelTensor
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param edge_attr: The edge attributes.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        if edge_attr is None:
+            input_ = torch.cat((x_i, x_j), dim=-1)
+        else:
+            input_ = torch.cat((x_i, x_j, edge_attr), dim=-1)
+        return self.message_net(input_)
+
+    def update(self, message, x):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The message to be passed.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.update_net(torch.cat((x, message), dim=-1))
diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py
new file mode 100644
index 000000000..ef621b10e
--- /dev/null
+++ b/pina/model/block/message_passing/radial_field_network_block.py
@@ -0,0 +1,126 @@
+"""Module for the Radial Field Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from torch_geometric.utils import remove_self_loops
+from ....utils import check_positive_integer
+from ....model import FeedForward
+
+
+class RadialFieldNetworkBlock(MessagePassing):
+    """
+    Implementation of the Radial Field Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Köhler et al. in
+    2020. It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed by applying
+    an MLP to the radial distance, i.e. the norm of the difference between the
+    sender and recipient node features, and scaling the difference vector
+    itself by the result. Messages are then aggregated using an aggregation
+    scheme (e.g., sum, mean, min, max, or product).
+
+    The update step is performed by a simple addition of the incoming messages
+    to the node features.
+
+    .. seealso::
+
+        **Original reference**: Köhler, J., Klein, L., Noé, F. (2020).
+        *Equivariant Flows: Exact Likelihood Generative Learning for Symmetric
+        Densities*.
+        In International Conference on Machine Learning.
+        DOI: `<https://doi.org/10.48550/arXiv.2006.02425>`_.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        hidden_dim=64,
+        n_layers=2,
+        activation=torch.nn.Tanh,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`RadialFieldNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int hidden_dim: The dimension of the hidden features.
+            Default is 64.
+        :param int n_layers: The number of layers in the network. Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.Tanh`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises AssertionError: If `node_feature_dim` is not a positive
+            integer.
+        :raises AssertionError: If `hidden_dim` is not a positive integer.
+        :raises AssertionError: If `n_layers` is not a positive integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check values
+        check_positive_integer(node_feature_dim, strict=True)
+        check_positive_integer(hidden_dim, strict=True)
+        check_positive_integer(n_layers, strict=True)
+
+        # Layer for processing node features
+        self.radial_net = FeedForward(
+            input_dimensions=1,
+            output_dimensions=1,
+            inner_size=hidden_dim,
+            n_layers=n_layers,
+            func=activation,
+        )
+
+    def forward(self, x, edge_index):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        edge_index, _ = remove_self_loops(edge_index)
+        return self.propagate(edge_index=edge_index, x=x)
+
+    def message(self, x_i, x_j):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_i: The node features of the recipient nodes.
+        :type x_i: torch.Tensor | LabelTensor
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        r = x_i - x_j
+        return self.radial_net(torch.norm(r, dim=1, keepdim=True)) * r
+
+    def update(self, message, x):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The message to be passed.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return x + message
diff --git a/pina/model/block/message_passing/schnet_block.py b/pina/model/block/message_passing/schnet_block.py
new file mode 100644
index 000000000..94fe06364
--- /dev/null
+++ b/pina/model/block/message_passing/schnet_block.py
@@ -0,0 +1,158 @@
+"""Module for the Schnet block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from torch_geometric.utils import remove_self_loops
+from ....utils import check_positive_integer
+from ....model import FeedForward
+
+
+class SchnetBlock(MessagePassing):
+    """
+    Implementation of the Schnet block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Schütt et al. in
+    2017. It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed as the
+    product of the output of an MLP applied to the distance between the node
+    positions, and the output of another MLP applied to the recipient node
+    features. Messages are then aggregated using an aggregation scheme (e.g.,
+    sum, mean, min, max, or product).
+
+    The update step is performed by applying another MLP to the concatenation
+    of the incoming messages and the node features.
+
+    .. seealso::
+
+        **Original reference**: Schütt, K., Kindermans, P. J., Sauceda Felix,
+        H. E., Chmiela, S., Tkatchenko, A., Müller, K. R. (2017).
+        *Schnet: A continuous-filter convolutional neural network for modeling
+        quantum interactions.*
+        Advances in Neural Information Processing Systems, 30.
+        DOI: `<https://doi.org/10.48550/arXiv.1706.08566>`_.
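+
+    In formulas, the message and update implemented here can be sketched as
+    follows (a sketch of the computation performed in the methods below, with
+    :math:`W` denoting the radial network, :math:`\phi` the message network,
+    and :math:`\psi` the update network):
+
+    .. math::
+        m_{ij} = W\left(\|r_i - r_j\|\right)\,\phi(h_i), \qquad
+        h_i' = \psi\left(h_i, \sum_{j} m_{ij}\right).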
+ """ + + def __init__( + self, + node_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + n_radial_layers=2, + activation=torch.nn.SiLU, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`SchnetBlock` class. + + :param int node_feature_dim: The dimension of the node features. + :param int hidden_dim: The dimension of the hidden features. + Default is 64. + :param int n_message_layers: The number of layers in the message + network. Default is 2. + :param int n_update_layers: The number of layers in the update network. + Default is 2. + :param int n_radial_layers: The number of layers in the radial field + network. Default is 2. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.SiLU`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". + The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises AssertionError: If `node_feature_dim` is not a positive integer. + :raises AssertionError: If `hidden_dim` is not a positive integer. + :raises AssertionError: If `n_message_layers` is not a positive integer. + :raises AssertionError: If `n_update_layers` is not a positive integer. + :raises AssertionError: If `n_radial_layers` is not a positive integer. + """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check values + check_positive_integer(node_feature_dim, strict=True) + check_positive_integer(hidden_dim, strict=True) + check_positive_integer(n_message_layers, strict=True) + check_positive_integer(n_update_layers, strict=True) + check_positive_integer(n_radial_layers, strict=True) + + # Layer for processing node distances + self.radial_net = FeedForward( + input_dimensions=1, + output_dimensions=1, + inner_size=hidden_dim, + n_layers=n_radial_layers, + func=activation, + ) + + # Layer for computing the message + self.message_net = FeedForward( + input_dimensions=node_feature_dim, + output_dimensions=node_feature_dim, + inner_size=hidden_dim, + n_layers=n_message_layers, + func=activation, + ) + + # Layer for updating the node features + self.update_net = FeedForward( + input_dimensions=2 * node_feature_dim, + output_dimensions=node_feature_dim, + inner_size=hidden_dim, + n_layers=n_update_layers, + func=activation, + ) + + def forward(self, x, pos, edge_index): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indices. + :return: The updated node features. + :rtype: torch.Tensor + """ + edge_index, _ = remove_self_loops(edge_index) + return self.propagate(edge_index=edge_index, x=x, pos=pos) + + def message(self, x_i, pos_i, pos_j): + """ + Compute the message to be passed between nodes and edges. + + :param x_i: Node features of the sender nodes. 
+ :type x_i: torch.Tensor | LabelTensor + :param pos_i: The node coordinates of the recipient nodes. + :type pos_i: torch.Tensor | LabelTensor + :param pos_j: The node coordinates of the sender nodes. + :type pos_j: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + rad = self.radial_net(torch.norm(pos_i - pos_j, dim=-1, keepdim=True)) + msg = self.message_net(x_i) + return rad * msg + + def update(self, message, x): + """ + Update the node features with the received messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + return self.update_net(torch.cat((x, message), dim=-1)) diff --git a/pina/utils.py b/pina/utils.py index e3126de45..569ba632c 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -193,3 +193,22 @@ def chebyshev_roots(n): k = torch.arange(n) nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0] return nodes + + +def check_positive_integer(value, strict=True): + """ + Check if the value is a positive integer. + + :param int value: The value to check. + :param bool strict: If True, the value must be strictly positive. + Default is True. + :raises AssertionError: If the value is not a positive integer. + """ + if strict: + assert ( + isinstance(value, int) and value > 0 + ), f"Expected a strictly positive integer, got {value}." + else: + assert ( + isinstance(value, int) and value >= 0 + ), f"Expected a non-negative integer, got {value}." diff --git a/tests/test_messagepassing/test_deep_tensor_network_block.py b/tests/test_messagepassing/test_deep_tensor_network_block.py new file mode 100644 index 000000000..aa295d2db --- /dev/null +++ b/tests/test_messagepassing/test_deep_tensor_network_block.py @@ -0,0 +1,59 @@ +import pytest +import torch +from pina.model.block.message_passing import DeepTensorNetworkBlock + +# Data for testing +x = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) +edge_attr = torch.randn(20, 2) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +@pytest.mark.parametrize("edge_feature_dim", [3, 5]) +def test_constructor(node_feature_dim, edge_feature_dim): + + DeepTensorNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + DeepTensorNetworkBlock( + node_feature_dim=-1, edge_feature_dim=edge_feature_dim + ) + + # Should fail if edge_feature_dim is negative + with pytest.raises(AssertionError): + DeepTensorNetworkBlock( + node_feature_dim=node_feature_dim, edge_feature_dim=-1 + ) + + +def test_forward(): + + model = DeepTensorNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_attr.shape[1], + ) + + output_ = model(edge_index=edge_index, x=x, edge_attr=edge_attr) + assert output_.shape == x.shape + + +def test_backward(): + + model = DeepTensorNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_attr.shape[1], + ) + + output_ = model( + edge_index=edge_index, + x=x.requires_grad_(), + edge_attr=edge_attr.requires_grad_(), + ) + + loss = torch.mean(output_) + loss.backward() + assert x.grad.shape == x.shape diff --git a/tests/test_messagepassing/test_equivariant_network_block.py b/tests/test_messagepassing/test_equivariant_network_block.py new file mode 100644 index 000000000..eea000a0e --- /dev/null +++ b/tests/test_messagepassing/test_equivariant_network_block.py @@ -0,0 +1,165 @@ +import 
pytest +import torch +from pina.model.block.message_passing import EnEquivariantNetworkBlock + +# Data for testing +x = torch.rand(10, 4) +pos = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) +edge_attr = torch.randn(20, 2) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +@pytest.mark.parametrize("pos_dim", [2, 3]) +def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): + + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=-1, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + ) + + # Should fail if edge_feature_dim is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=-1, + pos_dim=pos_dim, + ) + + # Should fail if pos_dim is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=-1, + ) + + # Should fail if hidden_dim is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + hidden_dim=-1, + ) + + # Should fail if n_message_layers is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + n_message_layers=-1, + ) + + # Should fail if n_update_layers is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + n_update_layers=-1, + ) + + +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_forward(edge_feature_dim): + + model = EnEquivariantNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_feature_dim, + pos_dim=pos.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + if edge_feature_dim == 0: + output_ = model(edge_index=edge_index, x=x, pos=pos) + else: + output_ = model( + edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr + ) + + assert output_[0].shape == x.shape + assert output_[1].shape == pos.shape + + +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_backward(edge_feature_dim): + + model = EnEquivariantNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_feature_dim, + pos_dim=pos.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + if edge_feature_dim == 0: + output_ = model( + edge_index=edge_index, + x=x.requires_grad_(), + pos=pos.requires_grad_(), + ) + else: + output_ = model( + edge_index=edge_index, + x=x.requires_grad_(), + pos=pos.requires_grad_(), + edge_attr=edge_attr.requires_grad_(), + ) + + loss = torch.mean(output_[0]) + loss.backward() + assert x.grad.shape == x.shape + assert pos.grad.shape == pos.shape + + +def test_equivariance(): + + # Graph to be fully connected and undirected + edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T + edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) + + # Random rotation (det(rotation) should be 1) + rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q + if torch.det(rotation) < 0: + 
rotation[:, 0] *= -1 + + # Random translation + translation = torch.rand(1, pos.shape[-1]) + + model = EnEquivariantNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=0, + pos_dim=pos.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ).eval() + + h1, pos1 = model(edge_index=edge_index, x=x, pos=pos) + h2, pos2 = model( + edge_index=edge_index, x=x, pos=pos @ rotation.T + translation + ) + + # Transform model output + pos1_transformed = (pos1 @ rotation.T) + translation + + assert torch.allclose(pos2, pos1_transformed, atol=1e-5) + assert torch.allclose(h1, h2, atol=1e-5) diff --git a/tests/test_messagepassing/test_interaction_network_block.py b/tests/test_messagepassing/test_interaction_network_block.py new file mode 100644 index 000000000..d121fb173 --- /dev/null +++ b/tests/test_messagepassing/test_interaction_network_block.py @@ -0,0 +1,84 @@ +import pytest +import torch +from pina.model.block.message_passing import InteractionNetworkBlock + +# Data for testing +x = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) +edge_attr = torch.randn(20, 2) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_constructor(node_feature_dim, edge_feature_dim): + + InteractionNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=-1) + + # Should fail if edge_feature_dim is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=3, edge_feature_dim=-1) + + # Should fail if hidden_dim is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=3, hidden_dim=-1) + + # Should fail if n_message_layers is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=3, n_message_layers=-1) + + # Should fail if n_update_layers is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=3, n_update_layers=-1) + + +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_forward(edge_feature_dim): + + model = InteractionNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + if edge_feature_dim == 0: + output_ = model(edge_index=edge_index, x=x) + else: + output_ = model(edge_index=edge_index, x=x, edge_attr=edge_attr) + assert output_.shape == x.shape + + +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_backward(edge_feature_dim): + + model = InteractionNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + if edge_feature_dim == 0: + output_ = model(edge_index=edge_index, x=x.requires_grad_()) + else: + output_ = model( + edge_index=edge_index, + x=x.requires_grad_(), + edge_attr=edge_attr.requires_grad_(), + ) + + loss = torch.mean(output_) + loss.backward() + assert x.grad.shape == x.shape diff --git a/tests/test_messagepassing/test_radial_field_network_block.py b/tests/test_messagepassing/test_radial_field_network_block.py new file mode 100644 index 000000000..4632ebfc9 --- /dev/null +++ b/tests/test_messagepassing/test_radial_field_network_block.py @@ -0,0 +1,92 @@ +import pytest +import torch +from pina.model.block.message_passing 
import RadialFieldNetworkBlock + +# Data for testing +x = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +def test_constructor(node_feature_dim): + + RadialFieldNetworkBlock( + node_feature_dim=node_feature_dim, + hidden_dim=64, + n_layers=2, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + RadialFieldNetworkBlock( + node_feature_dim=-1, + hidden_dim=64, + n_layers=2, + ) + + # Should fail if hidden_dim is negative + with pytest.raises(AssertionError): + RadialFieldNetworkBlock( + node_feature_dim=node_feature_dim, + hidden_dim=-1, + n_layers=2, + ) + + # Should fail if n_layers is negative + with pytest.raises(AssertionError): + RadialFieldNetworkBlock( + node_feature_dim=node_feature_dim, + hidden_dim=64, + n_layers=-1, + ) + + +def test_forward(): + + model = RadialFieldNetworkBlock( + node_feature_dim=x.shape[1], + hidden_dim=64, + n_layers=2, + ) + + output_ = model(edge_index=edge_index, x=x) + assert output_.shape == x.shape + + +def test_backward(): + + model = RadialFieldNetworkBlock( + node_feature_dim=x.shape[1], + hidden_dim=64, + n_layers=2, + ) + + output_ = model(edge_index=edge_index, x=x.requires_grad_()) + loss = torch.mean(output_) + loss.backward() + assert x.grad.shape == x.shape + + +def test_equivariance(): + + # Graph to be fully connected and undirected + edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T + edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) + + # Random rotation (det(rotation) should be 1) + rotation = torch.linalg.qr(torch.rand(x.shape[-1], x.shape[-1])).Q + if torch.det(rotation) < 0: + rotation[:, 0] *= -1 + + # Random translation + translation = torch.rand(1, x.shape[-1]) + + model = RadialFieldNetworkBlock(node_feature_dim=x.shape[1]).eval() + + pos1 = model(edge_index=edge_index, x=x) + pos2 = model(edge_index=edge_index, x=x @ rotation.T + translation) + + # Transform model output + pos1_transformed = (pos1 @ rotation.T) + translation + + assert torch.allclose(pos2, pos1_transformed, atol=1e-5) diff --git a/tests/test_messagepassing/test_schnet_block.py b/tests/test_messagepassing/test_schnet_block.py new file mode 100644 index 000000000..51073b0f3 --- /dev/null +++ b/tests/test_messagepassing/test_schnet_block.py @@ -0,0 +1,95 @@ +import pytest +import torch +from pina.model.block.message_passing import SchnetBlock + +# Data for testing +x = torch.rand(10, 4) +pos = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +def test_constructor(node_feature_dim): + + SchnetBlock( + node_feature_dim=node_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + n_radial_layers=2, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=-1) + + # Should fail if hidden_dim is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=node_feature_dim, hidden_dim=-1) + + # Should fail if n_message_layers is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=node_feature_dim, n_message_layers=-1) + + # Should fail if n_update_layers is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=node_feature_dim, n_update_layers=-1) + + # Should fail if n_radial_layers is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=node_feature_dim, n_radial_layers=-1) + + +def 
test_forward(): + + model = SchnetBlock( + node_feature_dim=x.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + n_radial_layers=2, + ) + + output_ = model(edge_index=edge_index, x=x, pos=pos) + assert output_.shape == x.shape + + +def test_backward(): + + model = SchnetBlock( + node_feature_dim=x.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + n_radial_layers=2, + ) + + output_ = model( + edge_index=edge_index, x=x.requires_grad_(), pos=pos.requires_grad_() + ) + + loss = torch.mean(output_) + loss.backward() + assert x.grad.shape == x.shape + + +def test_invariance(): + + # Graph to be fully connected and undirected + edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T + edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) + + # Random rotation (det(rotation) should be 1) + rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q + if torch.det(rotation) < 0: + rotation[:, 0] *= -1 + + # Random translation + translation = torch.rand(1, pos.shape[-1]) + + model = SchnetBlock(node_feature_dim=x.shape[1]).eval() + + out1 = model(edge_index=edge_index, x=x, pos=pos) + out2 = model(edge_index=edge_index, x=x, pos=pos @ rotation.T + translation) + + assert torch.allclose(out1, out2, atol=1e-5) diff --git a/tests/test_utils.py b/tests/test_utils.py index a641c3838..7e8518995 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,12 +1,9 @@ import torch +import pytest -from pina.utils import merge_tensors -from pina.label_tensor import LabelTensor from pina import LabelTensor -from pina.domain import EllipsoidDomain, CartesianDomain -from pina.utils import check_consistency -import pytest -from pina.domain import DomainInterface +from pina.utils import merge_tensors, check_consistency, check_positive_integer +from pina.domain import EllipsoidDomain, CartesianDomain, DomainInterface def test_merge_tensors(): @@ -50,3 +47,24 @@ def test_check_consistency_incorrect(): check_consistency(torch.Tensor, DomainInterface, subclass=True) with pytest.raises(ValueError): check_consistency(ellipsoid1, torch.Tensor) + + +@pytest.mark.parametrize("value", [0, 1, 2, 3, 10]) +@pytest.mark.parametrize("strict", [True, False]) +def test_check_positive_integer(value, strict): + if value != 0: + check_positive_integer(value, strict=strict) + else: + check_positive_integer(value, strict=False) + + # Should fail if value is negative + with pytest.raises(AssertionError): + check_positive_integer(-1, strict=strict) + + # Should fail if value is not an integer + with pytest.raises(AssertionError): + check_positive_integer(1.5, strict=strict) + + # Should fail if value is not a number + with pytest.raises(AssertionError): + check_positive_integer("string", strict=strict)
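For orientation, a minimal usage sketch of the new blocks (not part of the patch; the import path is the one added by this PR, and the tensor shapes mirror the tests above):

import torch
from pina.model.block.message_passing import (
    EnEquivariantNetworkBlock,
    InteractionNetworkBlock,
)

# Toy graph: 10 nodes with 4 features each, 3D positions, 20 random edges
x = torch.rand(10, 4)
pos = torch.rand(10, 3)
edge_index = torch.randint(0, 10, (2, 20))

# Node-feature refinement with the Interaction Network block
inter = InteractionNetworkBlock(node_feature_dim=4)
x_out = inter(x=x, edge_index=edge_index)  # shape: (10, 4)

# Joint feature/position update with the E(n) equivariant block
egnn = EnEquivariantNetworkBlock(
    node_feature_dim=4, edge_feature_dim=0, pos_dim=3
)
x_out, pos_out = egnn(x=x, pos=pos, edge_index=edge_index)  # (10, 4), (10, 3)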