Dev Update (#582)
* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
149
pina/model/block/message_passing/interaction_network_block.py
Normal file
149
pina/model/block/message_passing/interaction_network_block.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Module for the Interaction Network block."""
|
||||
|
||||
import torch
|
||||
from torch_geometric.nn import MessagePassing
|
||||
from ....utils import check_positive_integer
|
||||
from ....model import FeedForward
|
||||
|
||||
|
||||
class InteractionNetworkBlock(MessagePassing):
|
||||
"""
|
||||
Implementation of the Interaction Network block.
|
||||
|
||||
This block is used to perform message-passing between nodes and edges in a
|
||||
graph neural network, following the scheme proposed by Battaglia et al. in
|
||||
2016. It serves as an inner block in a larger graph neural network
|
||||
architecture.
|
||||
|
||||
The message between two nodes connected by an edge is computed by applying a
|
||||
multi-layer perceptron (MLP) to the concatenation of the sender and
|
||||
recipient node features. Messages are then aggregated using an aggregation
|
||||
scheme (e.g., sum, mean, min, max, or product).
|
||||
|
||||
The update step is performed by applying another MLP to the concatenation of
|
||||
the incoming messages and the node features.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Battaglia, P. W., et al. (2016).
|
||||
*Interaction Networks for Learning about Objects, Relations and
|
||||
Physics*.
|
||||
In Advances in Neural Information Processing Systems (NeurIPS 2016).
|
||||
DOI: `<https://doi.org/10.48550/arXiv.1612.00222>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_feature_dim,
|
||||
edge_feature_dim=0,
|
||||
hidden_dim=64,
|
||||
n_message_layers=2,
|
||||
n_update_layers=2,
|
||||
activation=torch.nn.SiLU,
|
||||
aggr="add",
|
||||
node_dim=-2,
|
||||
flow="source_to_target",
|
||||
):
|
||||
"""
|
||||
Initialization of the :class:`InteractionNetworkBlock` class.
|
||||
|
||||
:param int node_feature_dim: The dimension of the node features.
|
||||
:param int edge_feature_dim: The dimension of the edge features.
|
||||
If edge_attr is not provided, it is assumed to be 0.
|
||||
Default is 0.
|
||||
:param int hidden_dim: The dimension of the hidden features.
|
||||
Default is 64.
|
||||
:param int n_message_layers: The number of layers in the message
|
||||
network. Default is 2.
|
||||
:param int n_update_layers: The number of layers in the update network.
|
||||
Default is 2.
|
||||
:param torch.nn.Module activation: The activation function.
|
||||
Default is :class:`torch.nn.SiLU`.
|
||||
:param str aggr: The aggregation scheme to use for message passing.
|
||||
Available options are "add", "mean", "min", "max", "mul".
|
||||
See :class:`torch_geometric.nn.MessagePassing` for more details.
|
||||
Default is "add".
|
||||
:param int node_dim: The axis along which to propagate. Default is -2.
|
||||
:param str flow: The direction of message passing. Available options
|
||||
are "source_to_target" and "target_to_source".
|
||||
The "source_to_target" flow means that messages are sent from
|
||||
the source node to the target node, while the "target_to_source"
|
||||
flow means that messages are sent from the target node to the
|
||||
source node. See :class:`torch_geometric.nn.MessagePassing` for more
|
||||
details. Default is "source_to_target".
|
||||
:raises AssertionError: If `node_feature_dim` is not a positive integer.
|
||||
:raises AssertionError: If `hidden_dim` is not a positive integer.
|
||||
:raises AssertionError: If `n_message_layers` is not a positive integer.
|
||||
:raises AssertionError: If `n_update_layers` is not a positive integer.
|
||||
:raises AssertionError: If `edge_feature_dim` is not a non-negative
|
||||
integer.
|
||||
"""
|
||||
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
|
||||
|
||||
# Check values
|
||||
check_positive_integer(node_feature_dim, strict=True)
|
||||
check_positive_integer(hidden_dim, strict=True)
|
||||
check_positive_integer(n_message_layers, strict=True)
|
||||
check_positive_integer(n_update_layers, strict=True)
|
||||
check_positive_integer(edge_feature_dim, strict=False)
|
||||
|
||||
# Message network
|
||||
self.message_net = FeedForward(
|
||||
input_dimensions=2 * node_feature_dim + edge_feature_dim,
|
||||
output_dimensions=hidden_dim,
|
||||
inner_size=hidden_dim,
|
||||
n_layers=n_message_layers,
|
||||
func=activation,
|
||||
)
|
||||
|
||||
# Update network
|
||||
self.update_net = FeedForward(
|
||||
input_dimensions=node_feature_dim + hidden_dim,
|
||||
output_dimensions=node_feature_dim,
|
||||
inner_size=hidden_dim,
|
||||
n_layers=n_update_layers,
|
||||
func=activation,
|
||||
)
|
||||
|
||||
def forward(self, x, edge_index, edge_attr=None):
|
||||
"""
|
||||
Forward pass of the block, triggering the message-passing routine.
|
||||
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: The edge indeces.
|
||||
:param edge_attr: The edge attributes. Default is None.
|
||||
:type edge_attr: torch.Tensor | LabelTensor
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
|
||||
|
||||
def message(self, x_i, x_j, edge_attr):
|
||||
"""
|
||||
Compute the message to be passed between nodes and edges.
|
||||
|
||||
:param x_i: The node features of the recipient nodes.
|
||||
:type x_i: torch.Tensor | LabelTensor
|
||||
:param x_j: The node features of the sender nodes.
|
||||
:type x_j: torch.Tensor | LabelTensor
|
||||
:return: The message to be passed.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if edge_attr is None:
|
||||
input_ = torch.cat((x_i, x_j), dim=-1)
|
||||
else:
|
||||
input_ = torch.cat((x_i, x_j, edge_attr), dim=-1)
|
||||
return self.message_net(input_)
|
||||
|
||||
def update(self, message, x):
|
||||
"""
|
||||
Update the node features with the received messages.
|
||||
|
||||
:param torch.Tensor message: The message to be passed.
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.update_net(torch.cat((x, message), dim=-1))
|
||||
Reference in New Issue
Block a user