Dev Update (#582)
* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
13
pina/model/block/message_passing/__init__.py
Normal file
13
pina/model/block/message_passing/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Module for the message passing blocks of the graph neural models."""
|
||||
|
||||
__all__ = [
|
||||
"InteractionNetworkBlock",
|
||||
"DeepTensorNetworkBlock",
|
||||
"EnEquivariantNetworkBlock",
|
||||
"RadialFieldNetworkBlock",
|
||||
]
|
||||
|
||||
from .interaction_network_block import InteractionNetworkBlock
|
||||
from .deep_tensor_network_block import DeepTensorNetworkBlock
|
||||
from .en_equivariant_network_block import EnEquivariantNetworkBlock
|
||||
from .radial_field_network_block import RadialFieldNetworkBlock
|
||||
138
pina/model/block/message_passing/deep_tensor_network_block.py
Normal file
138
pina/model/block/message_passing/deep_tensor_network_block.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Module for the Deep Tensor Network block."""
|
||||
|
||||
import torch
|
||||
from torch_geometric.nn import MessagePassing
|
||||
from ....utils import check_positive_integer
|
||||
|
||||
|
||||
class DeepTensorNetworkBlock(MessagePassing):
|
||||
"""
|
||||
Implementation of the Deep Tensor Network block.
|
||||
|
||||
This block is used to perform message-passing between nodes and edges in a
|
||||
graph neural network, following the scheme proposed by Schutt et al. in
|
||||
2017. It serves as an inner block in a larger graph neural network
|
||||
architecture.
|
||||
|
||||
The message between two nodes connected by an edge is computed by applying a
|
||||
linear transformation to the sender node features and the edge features,
|
||||
followed by a non-linear activation function. Messages are then aggregated
|
||||
using an aggregation scheme (e.g., sum, mean, min, max, or product).
|
||||
|
||||
The update step is performed by a simple addition of the incoming messages
|
||||
to the node features.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al.
|
||||
(2017). *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
|
||||
Nature Communications 8, 13890 (2017).
|
||||
DOI: `<https://doi.org/10.1038/ncomms13890>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_feature_dim,
|
||||
edge_feature_dim,
|
||||
activation=torch.nn.Tanh,
|
||||
aggr="add",
|
||||
node_dim=-2,
|
||||
flow="source_to_target",
|
||||
):
|
||||
"""
|
||||
Initialization of the :class:`DeepTensorNetworkBlock` class.
|
||||
|
||||
:param int node_feature_dim: The dimension of the node features.
|
||||
:param int edge_feature_dim: The dimension of the edge features.
|
||||
:param torch.nn.Module activation: The activation function.
|
||||
Default is :class:`torch.nn.Tanh`.
|
||||
:param str aggr: The aggregation scheme to use for message passing.
|
||||
Available options are "add", "mean", "min", "max", "mul".
|
||||
See :class:`torch_geometric.nn.MessagePassing` for more details.
|
||||
Default is "add".
|
||||
:param int node_dim: The axis along which to propagate. Default is -2.
|
||||
:param str flow: The direction of message passing. Available options
|
||||
are "source_to_target" and "target_to_source".
|
||||
The "source_to_target" flow means that messages are sent from
|
||||
the source node to the target node, while the "target_to_source"
|
||||
flow means that messages are sent from the target node to the
|
||||
source node. See :class:`torch_geometric.nn.MessagePassing` for more
|
||||
details. Default is "source_to_target".
|
||||
:raises AssertionError: If `node_feature_dim` is not a positive integer.
|
||||
:raises AssertionError: If `edge_feature_dim` is not a positive integer.
|
||||
"""
|
||||
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
|
||||
|
||||
# Check values
|
||||
check_positive_integer(node_feature_dim, strict=True)
|
||||
check_positive_integer(edge_feature_dim, strict=True)
|
||||
|
||||
# Activation function
|
||||
self.activation = activation()
|
||||
|
||||
# Layer for processing node features
|
||||
self.node_layer = torch.nn.Linear(
|
||||
in_features=node_feature_dim,
|
||||
out_features=node_feature_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# Layer for processing edge features
|
||||
self.edge_layer = torch.nn.Linear(
|
||||
in_features=edge_feature_dim,
|
||||
out_features=node_feature_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# Layer for computing the message
|
||||
self.message_layer = torch.nn.Linear(
|
||||
in_features=node_feature_dim,
|
||||
out_features=node_feature_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x, edge_index, edge_attr):
|
||||
"""
|
||||
Forward pass of the block, triggering the message-passing routine.
|
||||
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: The edge indeces.
|
||||
:param edge_attr: The edge attributes.
|
||||
:type edge_attr: torch.Tensor | LabelTensor
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
|
||||
|
||||
def message(self, x_j, edge_attr):
|
||||
"""
|
||||
Compute the message to be passed between nodes and edges.
|
||||
|
||||
:param x_j: The node features of the sender nodes.
|
||||
:type x_j: torch.Tensor | LabelTensor
|
||||
:param edge_attr: The edge attributes.
|
||||
:type edge_attr: torch.Tensor | LabelTensor
|
||||
:return: The message to be passed.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# Process node and edge features
|
||||
filter_node = self.node_layer(x_j)
|
||||
filter_edge = self.edge_layer(edge_attr)
|
||||
|
||||
# Compute the message to be passed
|
||||
message = self.message_layer(filter_node * filter_edge)
|
||||
|
||||
return self.activation(message)
|
||||
|
||||
def update(self, message, x):
|
||||
"""
|
||||
Update the node features with the received messages.
|
||||
|
||||
:param torch.Tensor message: The message to be passed.
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return x + message
|
||||
229
pina/model/block/message_passing/en_equivariant_network_block.py
Normal file
229
pina/model/block/message_passing/en_equivariant_network_block.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Module for the E(n) Equivariant Graph Neural Network block."""
|
||||
|
||||
import torch
|
||||
from torch_geometric.nn import MessagePassing
|
||||
from torch_geometric.utils import degree
|
||||
from ....utils import check_positive_integer
|
||||
from ....model import FeedForward
|
||||
|
||||
|
||||
class EnEquivariantNetworkBlock(MessagePassing):
|
||||
"""
|
||||
Implementation of the E(n) Equivariant Graph Neural Network block.
|
||||
This block is used to perform message-passing between nodes and edges in a
|
||||
graph neural network, following the scheme proposed by Satorras et al. in
|
||||
2021. It serves as an inner block in a larger graph neural network
|
||||
architecture.
|
||||
|
||||
The message between two nodes connected by an edge is computed by applying a
|
||||
linear transformation to the sender node features and the edge features,
|
||||
together with the squared euclidean distance between the sender and
|
||||
recipient node positions, followed by a non-linear activation function.
|
||||
Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
|
||||
min, max, or product).
|
||||
|
||||
The update step is performed by applying another MLP to the concatenation of
|
||||
the incoming messages and the node features. Here, also the node
|
||||
positions are updated by adding the incoming messages divided by the
|
||||
degree of the recipient node.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
|
||||
(2021). *E(n) Equivariant Graph Neural Networks.*
|
||||
In International Conference on Machine Learning.
|
||||
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_feature_dim,
|
||||
edge_feature_dim,
|
||||
pos_dim,
|
||||
hidden_dim=64,
|
||||
n_message_layers=2,
|
||||
n_update_layers=2,
|
||||
activation=torch.nn.SiLU,
|
||||
aggr="add",
|
||||
node_dim=-2,
|
||||
flow="source_to_target",
|
||||
):
|
||||
"""
|
||||
Initialization of the :class:`EnEquivariantNetworkBlock` class.
|
||||
|
||||
:param int node_feature_dim: The dimension of the node features.
|
||||
:param int edge_feature_dim: The dimension of the edge features.
|
||||
:param int pos_dim: The dimension of the position features.
|
||||
:param int hidden_dim: The dimension of the hidden features.
|
||||
Default is 64.
|
||||
:param int n_message_layers: The number of layers in the message
|
||||
network. Default is 2.
|
||||
:param int n_update_layers: The number of layers in the update network.
|
||||
Default is 2.
|
||||
:param torch.nn.Module activation: The activation function.
|
||||
Default is :class:`torch.nn.SiLU`.
|
||||
:param str aggr: The aggregation scheme to use for message passing.
|
||||
Available options are "add", "mean", "min", "max", "mul".
|
||||
See :class:`torch_geometric.nn.MessagePassing` for more details.
|
||||
Default is "add".
|
||||
:param int node_dim: The axis along which to propagate. Default is -2.
|
||||
:param str flow: The direction of message passing. Available options
|
||||
are "source_to_target" and "target_to_source".
|
||||
The "source_to_target" flow means that messages are sent from
|
||||
the source node to the target node, while the "target_to_source"
|
||||
flow means that messages are sent from the target node to the
|
||||
source node. See :class:`torch_geometric.nn.MessagePassing` for more
|
||||
details. Default is "source_to_target".
|
||||
:raises AssertionError: If `node_feature_dim` is not a positive integer.
|
||||
:raises AssertionError: If `edge_feature_dim` is a negative integer.
|
||||
:raises AssertionError: If `pos_dim` is not a positive integer.
|
||||
:raises AssertionError: If `hidden_dim` is not a positive integer.
|
||||
:raises AssertionError: If `n_message_layers` is not a positive integer.
|
||||
:raises AssertionError: If `n_update_layers` is not a positive integer.
|
||||
"""
|
||||
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
|
||||
|
||||
# Check values
|
||||
check_positive_integer(node_feature_dim, strict=True)
|
||||
check_positive_integer(edge_feature_dim, strict=False)
|
||||
check_positive_integer(pos_dim, strict=True)
|
||||
check_positive_integer(hidden_dim, strict=True)
|
||||
check_positive_integer(n_message_layers, strict=True)
|
||||
check_positive_integer(n_update_layers, strict=True)
|
||||
|
||||
# Layer for computing the message
|
||||
self.message_net = FeedForward(
|
||||
input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
|
||||
output_dimensions=pos_dim,
|
||||
inner_size=hidden_dim,
|
||||
n_layers=n_message_layers,
|
||||
func=activation,
|
||||
)
|
||||
|
||||
# Layer for updating the node features
|
||||
self.update_feat_net = FeedForward(
|
||||
input_dimensions=node_feature_dim + pos_dim,
|
||||
output_dimensions=node_feature_dim,
|
||||
inner_size=hidden_dim,
|
||||
n_layers=n_update_layers,
|
||||
func=activation,
|
||||
)
|
||||
|
||||
# Layer for updating the node positions
|
||||
# The output dimension is set to 1 for equivariant updates
|
||||
self.update_pos_net = FeedForward(
|
||||
input_dimensions=pos_dim,
|
||||
output_dimensions=1,
|
||||
inner_size=hidden_dim,
|
||||
n_layers=n_update_layers,
|
||||
func=activation,
|
||||
)
|
||||
|
||||
def forward(self, x, pos, edge_index, edge_attr=None):
|
||||
"""
|
||||
Forward pass of the block, triggering the message-passing routine.
|
||||
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:param pos: The euclidean coordinates of the nodes.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: The edge indices.
|
||||
:param edge_attr: The edge attributes. Default is None.
|
||||
:type edge_attr: torch.Tensor | LabelTensor
|
||||
:return: The updated node features and node positions.
|
||||
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||
"""
|
||||
return self.propagate(
|
||||
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
|
||||
)
|
||||
|
||||
def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
|
||||
"""
|
||||
Compute the message to be passed between nodes and edges.
|
||||
|
||||
:param x_i: The node features of the recipient nodes.
|
||||
:type x_i: torch.Tensor | LabelTensor
|
||||
:param x_j: The node features of the sender nodes.
|
||||
:type x_j: torch.Tensor | LabelTensor
|
||||
:param pos_i: The node coordinates of the recipient nodes.
|
||||
:type pos_i: torch.Tensor | LabelTensor
|
||||
:param pos_j: The node coordinates of the sender nodes.
|
||||
:type pos_j: torch.Tensor | LabelTensor
|
||||
:param edge_attr: The edge attributes.
|
||||
:type edge_attr: torch.Tensor | LabelTensor
|
||||
:return: The message to be passed.
|
||||
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||
"""
|
||||
# Compute the euclidean distance between the sender and recipient nodes
|
||||
diff = pos_i - pos_j
|
||||
dist = torch.norm(diff, dim=-1, keepdim=True) ** 2
|
||||
|
||||
# Compute the message input
|
||||
if edge_attr is None:
|
||||
input_ = torch.cat((x_i, x_j, dist), dim=-1)
|
||||
else:
|
||||
input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
|
||||
|
||||
# Compute the messages and their equivariant counterpart
|
||||
m_ij = self.message_net(input_)
|
||||
message = diff * self.update_pos_net(m_ij)
|
||||
|
||||
return message, m_ij
|
||||
|
||||
def aggregate(self, inputs, index, ptr=None, dim_size=None):
|
||||
"""
|
||||
Aggregate the messages at the nodes during message passing.
|
||||
|
||||
This method receives a tuple of tensors corresponding to the messages
|
||||
to be aggregated. Both messages are aggregated separately according to
|
||||
the specified aggregation scheme.
|
||||
|
||||
:param tuple(torch.Tensor) inputs: Tuple containing two messages to
|
||||
aggregate.
|
||||
:param index: The indices of target nodes for each message. This tensor
|
||||
specifies which node each message is aggregated into.
|
||||
:type index: torch.Tensor | LabelTensor
|
||||
:param ptr: Optional tensor to specify the slices of messages for each
|
||||
node (used in some aggregation strategies). Default is None.
|
||||
:type ptr: torch.Tensor | LabelTensor
|
||||
:param int dim_size: Optional size of the output dimension, i.e.,
|
||||
number of nodes. Default is None.
|
||||
:return: Tuple of aggregated tensors corresponding to (aggregated
|
||||
messages for position updates, aggregated messages for feature
|
||||
updates).
|
||||
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||
"""
|
||||
# Unpack the messages from the inputs
|
||||
message, m_ij = inputs
|
||||
|
||||
# Aggregate messages as usual using self.aggr method
|
||||
agg_message = super().aggregate(message, index, ptr, dim_size)
|
||||
agg_m_ij = super().aggregate(m_ij, index, ptr, dim_size)
|
||||
|
||||
return agg_message, agg_m_ij
|
||||
|
||||
def update(self, aggregated_inputs, x, pos, edge_index):
|
||||
"""
|
||||
Update the node features and the node coordinates with the received
|
||||
messages.
|
||||
|
||||
:param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:param pos: The euclidean coordinates of the nodes.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: The edge indices.
|
||||
:return: The updated node features and node positions.
|
||||
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||
"""
|
||||
# aggregated_inputs is tuple (agg_message, agg_m_ij)
|
||||
agg_message, agg_m_ij = aggregated_inputs
|
||||
|
||||
# Update node features with aggregated m_ij
|
||||
x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))
|
||||
|
||||
# Degree for normalization of position updates
|
||||
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
|
||||
pos = pos + agg_message / c
|
||||
|
||||
return x, pos
|
||||
149
pina/model/block/message_passing/interaction_network_block.py
Normal file
149
pina/model/block/message_passing/interaction_network_block.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Module for the Interaction Network block."""
|
||||
|
||||
import torch
|
||||
from torch_geometric.nn import MessagePassing
|
||||
from ....utils import check_positive_integer
|
||||
from ....model import FeedForward
|
||||
|
||||
|
||||
class InteractionNetworkBlock(MessagePassing):
|
||||
"""
|
||||
Implementation of the Interaction Network block.
|
||||
|
||||
This block is used to perform message-passing between nodes and edges in a
|
||||
graph neural network, following the scheme proposed by Battaglia et al. in
|
||||
2016. It serves as an inner block in a larger graph neural network
|
||||
architecture.
|
||||
|
||||
The message between two nodes connected by an edge is computed by applying a
|
||||
multi-layer perceptron (MLP) to the concatenation of the sender and
|
||||
recipient node features. Messages are then aggregated using an aggregation
|
||||
scheme (e.g., sum, mean, min, max, or product).
|
||||
|
||||
The update step is performed by applying another MLP to the concatenation of
|
||||
the incoming messages and the node features.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Battaglia, P. W., et al. (2016).
|
||||
*Interaction Networks for Learning about Objects, Relations and
|
||||
Physics*.
|
||||
In Advances in Neural Information Processing Systems (NeurIPS 2016).
|
||||
DOI: `<https://doi.org/10.48550/arXiv.1612.00222>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_feature_dim,
|
||||
edge_feature_dim=0,
|
||||
hidden_dim=64,
|
||||
n_message_layers=2,
|
||||
n_update_layers=2,
|
||||
activation=torch.nn.SiLU,
|
||||
aggr="add",
|
||||
node_dim=-2,
|
||||
flow="source_to_target",
|
||||
):
|
||||
"""
|
||||
Initialization of the :class:`InteractionNetworkBlock` class.
|
||||
|
||||
:param int node_feature_dim: The dimension of the node features.
|
||||
:param int edge_feature_dim: The dimension of the edge features.
|
||||
If edge_attr is not provided, it is assumed to be 0.
|
||||
Default is 0.
|
||||
:param int hidden_dim: The dimension of the hidden features.
|
||||
Default is 64.
|
||||
:param int n_message_layers: The number of layers in the message
|
||||
network. Default is 2.
|
||||
:param int n_update_layers: The number of layers in the update network.
|
||||
Default is 2.
|
||||
:param torch.nn.Module activation: The activation function.
|
||||
Default is :class:`torch.nn.SiLU`.
|
||||
:param str aggr: The aggregation scheme to use for message passing.
|
||||
Available options are "add", "mean", "min", "max", "mul".
|
||||
See :class:`torch_geometric.nn.MessagePassing` for more details.
|
||||
Default is "add".
|
||||
:param int node_dim: The axis along which to propagate. Default is -2.
|
||||
:param str flow: The direction of message passing. Available options
|
||||
are "source_to_target" and "target_to_source".
|
||||
The "source_to_target" flow means that messages are sent from
|
||||
the source node to the target node, while the "target_to_source"
|
||||
flow means that messages are sent from the target node to the
|
||||
source node. See :class:`torch_geometric.nn.MessagePassing` for more
|
||||
details. Default is "source_to_target".
|
||||
:raises AssertionError: If `node_feature_dim` is not a positive integer.
|
||||
:raises AssertionError: If `hidden_dim` is not a positive integer.
|
||||
:raises AssertionError: If `n_message_layers` is not a positive integer.
|
||||
:raises AssertionError: If `n_update_layers` is not a positive integer.
|
||||
:raises AssertionError: If `edge_feature_dim` is not a non-negative
|
||||
integer.
|
||||
"""
|
||||
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
|
||||
|
||||
# Check values
|
||||
check_positive_integer(node_feature_dim, strict=True)
|
||||
check_positive_integer(hidden_dim, strict=True)
|
||||
check_positive_integer(n_message_layers, strict=True)
|
||||
check_positive_integer(n_update_layers, strict=True)
|
||||
check_positive_integer(edge_feature_dim, strict=False)
|
||||
|
||||
# Message network
|
||||
self.message_net = FeedForward(
|
||||
input_dimensions=2 * node_feature_dim + edge_feature_dim,
|
||||
output_dimensions=hidden_dim,
|
||||
inner_size=hidden_dim,
|
||||
n_layers=n_message_layers,
|
||||
func=activation,
|
||||
)
|
||||
|
||||
# Update network
|
||||
self.update_net = FeedForward(
|
||||
input_dimensions=node_feature_dim + hidden_dim,
|
||||
output_dimensions=node_feature_dim,
|
||||
inner_size=hidden_dim,
|
||||
n_layers=n_update_layers,
|
||||
func=activation,
|
||||
)
|
||||
|
||||
def forward(self, x, edge_index, edge_attr=None):
|
||||
"""
|
||||
Forward pass of the block, triggering the message-passing routine.
|
||||
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: The edge indeces.
|
||||
:param edge_attr: The edge attributes. Default is None.
|
||||
:type edge_attr: torch.Tensor | LabelTensor
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
|
||||
|
||||
def message(self, x_i, x_j, edge_attr):
|
||||
"""
|
||||
Compute the message to be passed between nodes and edges.
|
||||
|
||||
:param x_i: The node features of the recipient nodes.
|
||||
:type x_i: torch.Tensor | LabelTensor
|
||||
:param x_j: The node features of the sender nodes.
|
||||
:type x_j: torch.Tensor | LabelTensor
|
||||
:return: The message to be passed.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if edge_attr is None:
|
||||
input_ = torch.cat((x_i, x_j), dim=-1)
|
||||
else:
|
||||
input_ = torch.cat((x_i, x_j, edge_attr), dim=-1)
|
||||
return self.message_net(input_)
|
||||
|
||||
def update(self, message, x):
|
||||
"""
|
||||
Update the node features with the received messages.
|
||||
|
||||
:param torch.Tensor message: The message to be passed.
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.update_net(torch.cat((x, message), dim=-1))
|
||||
126
pina/model/block/message_passing/radial_field_network_block.py
Normal file
126
pina/model/block/message_passing/radial_field_network_block.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Module for the Radial Field Network block."""
|
||||
|
||||
import torch
|
||||
from torch_geometric.nn import MessagePassing
|
||||
from torch_geometric.utils import remove_self_loops
|
||||
from ....utils import check_positive_integer
|
||||
from ....model import FeedForward
|
||||
|
||||
|
||||
class RadialFieldNetworkBlock(MessagePassing):
|
||||
"""
|
||||
Implementation of the Radial Field Network block.
|
||||
|
||||
This block is used to perform message-passing between nodes and edges in a
|
||||
graph neural network, following the scheme proposed by Köhler et al. in
|
||||
2020. It serves as an inner block in a larger graph neural network
|
||||
architecture.
|
||||
|
||||
The message between two nodes connected by an edge is computed by applying a
|
||||
linear transformation to the norm of the difference between the sender and
|
||||
recipient node features, together with the radial distance between the
|
||||
sender and recipient node features, followed by a non-linear activation
|
||||
function. Messages are then aggregated using an aggregation scheme
|
||||
(e.g., sum, mean, min, max, or product).
|
||||
|
||||
The update step is performed by a simple addition of the incoming messages
|
||||
to the node features.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference** Köhler, J., Klein, L., Noé, F. (2020).
|
||||
*Equivariant Flows: Exact Likelihood Generative Learning for Symmetric
|
||||
Densities*.
|
||||
In International Conference on Machine Learning.
|
||||
DOI: `<https://doi.org/10.48550/arXiv.2006.02425>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_feature_dim,
|
||||
hidden_dim=64,
|
||||
n_layers=2,
|
||||
activation=torch.nn.Tanh,
|
||||
aggr="add",
|
||||
node_dim=-2,
|
||||
flow="source_to_target",
|
||||
):
|
||||
"""
|
||||
Initialization of the :class:`RadialFieldNetworkBlock` class.
|
||||
|
||||
:param int node_feature_dim: The dimension of the node features.
|
||||
:param int hidden_dim: The dimension of the hidden features.
|
||||
Default is 64.
|
||||
:param int n_layers: The number of layers in the network. Default is 2.
|
||||
:param torch.nn.Module activation: The activation function.
|
||||
Default is :class:`torch.nn.Tanh`.
|
||||
:param str aggr: The aggregation scheme to use for message passing.
|
||||
Available options are "add", "mean", "min", "max", "mul".
|
||||
See :class:`torch_geometric.nn.MessagePassing` for more details.
|
||||
Default is "add".
|
||||
:param int node_dim: The axis along which to propagate. Default is -2.
|
||||
:param str flow: The direction of message passing. Available options
|
||||
are "source_to_target" and "target_to_source".
|
||||
The "source_to_target" flow means that messages are sent from
|
||||
the source node to the target node, while the "target_to_source"
|
||||
flow means that messages are sent from the target node to the
|
||||
source node. See :class:`torch_geometric.nn.MessagePassing` for more
|
||||
details. Default is "source_to_target".
|
||||
:raises AssertionError: If `node_feature_dim` is not a positive integer.
|
||||
:raises AssertionError: If `hidden_dim` is not a positive integer.
|
||||
:raises AssertionError: If `n_layers` is not a positive integer.
|
||||
"""
|
||||
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
|
||||
|
||||
# Check values
|
||||
check_positive_integer(node_feature_dim, strict=True)
|
||||
check_positive_integer(hidden_dim, strict=True)
|
||||
check_positive_integer(n_layers, strict=True)
|
||||
|
||||
# Layer for processing node features
|
||||
self.radial_net = FeedForward(
|
||||
input_dimensions=1,
|
||||
output_dimensions=1,
|
||||
inner_size=hidden_dim,
|
||||
n_layers=n_layers,
|
||||
func=activation,
|
||||
)
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
"""
|
||||
Forward pass of the block, triggering the message-passing routine.
|
||||
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: The edge indices.
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
edge_index, _ = remove_self_loops(edge_index)
|
||||
return self.propagate(edge_index=edge_index, x=x)
|
||||
|
||||
def message(self, x_i, x_j):
|
||||
"""
|
||||
Compute the message to be passed between nodes and edges.
|
||||
|
||||
:param x_i: The node features of the recipient nodes.
|
||||
:type x_i: torch.Tensor | LabelTensor
|
||||
:param x_j: The node features of the sender nodes.
|
||||
:type x_j: torch.Tensor | LabelTensor
|
||||
:return: The message to be passed.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
r = x_i - x_j
|
||||
return self.radial_net(torch.norm(r, dim=1, keepdim=True)) * r
|
||||
|
||||
def update(self, message, x):
|
||||
"""
|
||||
Update the node features with the received messages.
|
||||
|
||||
:param torch.Tensor message: The message to be passed.
|
||||
:param x: The node features.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return x + message
|
||||
Reference in New Issue
Block a user