Dev Update (#582)

* Fix adaptive refinement (#571)


---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>

* Remove collector

* Fixes

* Fixes

* rm unnecessary comment

* fix advection (#581)

* Fix tutorial .html link (#580)

* fix problem data collection for v0.1 (#584)

* Message Passing Module (#516)

* add deep tensor network block

* add interaction network block

* add radial field network block

* add schnet block

* add equivariant network block

* fix + tests + doc files

* fix egnn + equivariance/invariance tests

Co-authored-by: Dario Coscia <dariocos99@gmail.com>

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>

* add type checker (#527)

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
Dario Coscia committed via GitHub on 2025-06-13 17:34:37 +02:00
parent 6b355b45de · commit 7bf7d34d0f
40 changed files with 1963 additions and 581 deletions

File: __init__.py

@@ -0,0 +1,13 @@
"""Module for the message passing blocks of the graph neural models."""
__all__ = [
"InteractionNetworkBlock",
"DeepTensorNetworkBlock",
"EnEquivariantNetworkBlock",
"RadialFieldNetworkBlock",
]
from .interaction_network_block import InteractionNetworkBlock
from .deep_tensor_network_block import DeepTensorNetworkBlock
from .en_equivariant_network_block import EnEquivariantNetworkBlock
from .radial_field_network_block import RadialFieldNetworkBlock

File: deep_tensor_network_block.py

@@ -0,0 +1,138 @@
"""Module for the Deep Tensor Network block."""
import torch
from torch_geometric.nn import MessagePassing
from ....utils import check_positive_integer
class DeepTensorNetworkBlock(MessagePassing):
"""
Implementation of the Deep Tensor Network block.
This block is used to perform message-passing between nodes and edges in a
    graph neural network, following the scheme proposed by Schütt et al. in
2017. It serves as an inner block in a larger graph neural network
architecture.
    The message between two nodes connected by an edge is computed by applying
    separate linear transformations to the sender node features and to the edge
    features, taking their element-wise product, passing the result through a
    further linear layer, and finally applying a non-linear activation function.
    Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
    min, max, or product).
The update step is performed by a simple addition of the incoming messages
to the node features.
.. seealso::
        **Original reference**: Schütt, K. T., Arbabzadah, F., Chmiela, S. et al.
(2017). *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
Nature Communications 8, 13890 (2017).
DOI: `<https://doi.org/10.1038/ncomms13890>`_.
"""
def __init__(
self,
node_feature_dim,
edge_feature_dim,
activation=torch.nn.Tanh,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`DeepTensorNetworkBlock` class.
:param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.Tanh`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If `node_feature_dim` is not a positive integer.
:raises AssertionError: If `edge_feature_dim` is not a positive integer.
"""
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
# Check values
check_positive_integer(node_feature_dim, strict=True)
check_positive_integer(edge_feature_dim, strict=True)
# Activation function
self.activation = activation()
# Layer for processing node features
self.node_layer = torch.nn.Linear(
in_features=node_feature_dim,
out_features=node_feature_dim,
bias=True,
)
# Layer for processing edge features
self.edge_layer = torch.nn.Linear(
in_features=edge_feature_dim,
out_features=node_feature_dim,
bias=True,
)
# Layer for computing the message
self.message_layer = torch.nn.Linear(
in_features=node_feature_dim,
out_features=node_feature_dim,
bias=False,
)
def forward(self, x, edge_index, edge_attr):
"""
Forward pass of the block, triggering the message-passing routine.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
        :param torch.Tensor edge_index: The edge indices.
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor | LabelTensor
:return: The updated node features.
:rtype: torch.Tensor
"""
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
def message(self, x_j, edge_attr):
"""
Compute the message to be passed between nodes and edges.
:param x_j: The node features of the sender nodes.
:type x_j: torch.Tensor | LabelTensor
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor | LabelTensor
:return: The message to be passed.
:rtype: torch.Tensor
"""
# Process node and edge features
filter_node = self.node_layer(x_j)
filter_edge = self.edge_layer(edge_attr)
# Compute the message to be passed
message = self.message_layer(filter_node * filter_edge)
return self.activation(message)
def update(self, message, x):
"""
Update the node features with the received messages.
:param torch.Tensor message: The message to be passed.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
:return: The updated node features.
:rtype: torch.Tensor
"""
return x + message
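
Below is a minimal usage sketch for this block (not part of the diff). It assumes the class defined above is in scope; the toy shapes are arbitrary:

import torch

# Toy graph: 4 nodes with 8 features each, 3 directed edges with 5 features.
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # row 0: senders, row 1: recipients
edge_attr = torch.randn(3, 5)

block = DeepTensorNetworkBlock(node_feature_dim=8, edge_feature_dim=5)
out = block(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([4, 8]) -- the node feature dimension is preserved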

File: en_equivariant_network_block.py

@@ -0,0 +1,229 @@
"""Module for the E(n) Equivariant Graph Neural Network block."""
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from ....utils import check_positive_integer
from ....model import FeedForward
class EnEquivariantNetworkBlock(MessagePassing):
"""
Implementation of the E(n) Equivariant Graph Neural Network block.
This block is used to perform message-passing between nodes and edges in a
graph neural network, following the scheme proposed by Satorras et al. in
2021. It serves as an inner block in a larger graph neural network
architecture.
    The message between two nodes connected by an edge is computed by applying
    an MLP to the concatenation of the sender and recipient node features, the
    squared Euclidean distance between their positions, and, if provided, the
    edge features. Messages are then aggregated using an aggregation scheme
    (e.g., sum, mean, min, max, or product).
    The update step applies another MLP to the concatenation of the aggregated
    messages and the node features. The node positions are also updated, by
    adding the aggregated positional messages divided by the degree of the
    recipient node.
.. seealso::
        **Original reference**: Satorras, V. G., Hoogeboom, E., Welling, M.
(2021). *E(n) Equivariant Graph Neural Networks.*
In International Conference on Machine Learning.
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
"""
def __init__(
self,
node_feature_dim,
edge_feature_dim,
pos_dim,
hidden_dim=64,
n_message_layers=2,
n_update_layers=2,
activation=torch.nn.SiLU,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`EnEquivariantNetworkBlock` class.
:param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features.
:param int pos_dim: The dimension of the position features.
:param int hidden_dim: The dimension of the hidden features.
Default is 64.
:param int n_message_layers: The number of layers in the message
network. Default is 2.
:param int n_update_layers: The number of layers in the update network.
Default is 2.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.SiLU`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If `node_feature_dim` is not a positive integer.
        :raises AssertionError: If `edge_feature_dim` is not a non-negative
            integer.
:raises AssertionError: If `pos_dim` is not a positive integer.
:raises AssertionError: If `hidden_dim` is not a positive integer.
:raises AssertionError: If `n_message_layers` is not a positive integer.
:raises AssertionError: If `n_update_layers` is not a positive integer.
"""
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
# Check values
check_positive_integer(node_feature_dim, strict=True)
check_positive_integer(edge_feature_dim, strict=False)
check_positive_integer(pos_dim, strict=True)
check_positive_integer(hidden_dim, strict=True)
check_positive_integer(n_message_layers, strict=True)
check_positive_integer(n_update_layers, strict=True)
# Layer for computing the message
self.message_net = FeedForward(
input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
output_dimensions=pos_dim,
inner_size=hidden_dim,
n_layers=n_message_layers,
func=activation,
)
# Layer for updating the node features
self.update_feat_net = FeedForward(
input_dimensions=node_feature_dim + pos_dim,
output_dimensions=node_feature_dim,
inner_size=hidden_dim,
n_layers=n_update_layers,
func=activation,
)
# Layer for updating the node positions
# The output dimension is set to 1 for equivariant updates
self.update_pos_net = FeedForward(
input_dimensions=pos_dim,
output_dimensions=1,
inner_size=hidden_dim,
n_layers=n_update_layers,
func=activation,
)
def forward(self, x, pos, edge_index, edge_attr=None):
"""
Forward pass of the block, triggering the message-passing routine.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
        :param pos: The Euclidean coordinates of the nodes.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge indices.
:param edge_attr: The edge attributes. Default is None.
:type edge_attr: torch.Tensor | LabelTensor
:return: The updated node features and node positions.
:rtype: tuple(torch.Tensor, torch.Tensor)
"""
return self.propagate(
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
)
def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
"""
Compute the message to be passed between nodes and edges.
:param x_i: The node features of the recipient nodes.
:type x_i: torch.Tensor | LabelTensor
:param x_j: The node features of the sender nodes.
:type x_j: torch.Tensor | LabelTensor
:param pos_i: The node coordinates of the recipient nodes.
:type pos_i: torch.Tensor | LabelTensor
:param pos_j: The node coordinates of the sender nodes.
:type pos_j: torch.Tensor | LabelTensor
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor | LabelTensor
:return: The message to be passed.
:rtype: tuple(torch.Tensor, torch.Tensor)
"""
        # Compute the squared Euclidean distance between sender and recipient
diff = pos_i - pos_j
dist = torch.norm(diff, dim=-1, keepdim=True) ** 2
# Compute the message input
if edge_attr is None:
input_ = torch.cat((x_i, x_j, dist), dim=-1)
else:
input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
# Compute the messages and their equivariant counterpart
m_ij = self.message_net(input_)
message = diff * self.update_pos_net(m_ij)
return message, m_ij
def aggregate(self, inputs, index, ptr=None, dim_size=None):
"""
Aggregate the messages at the nodes during message passing.
This method receives a tuple of tensors corresponding to the messages
to be aggregated. Both messages are aggregated separately according to
the specified aggregation scheme.
:param tuple(torch.Tensor) inputs: Tuple containing two messages to
aggregate.
:param index: The indices of target nodes for each message. This tensor
specifies which node each message is aggregated into.
:type index: torch.Tensor | LabelTensor
:param ptr: Optional tensor to specify the slices of messages for each
node (used in some aggregation strategies). Default is None.
:type ptr: torch.Tensor | LabelTensor
:param int dim_size: Optional size of the output dimension, i.e.,
number of nodes. Default is None.
:return: Tuple of aggregated tensors corresponding to (aggregated
messages for position updates, aggregated messages for feature
updates).
:rtype: tuple(torch.Tensor, torch.Tensor)
"""
# Unpack the messages from the inputs
message, m_ij = inputs
# Aggregate messages as usual using self.aggr method
agg_message = super().aggregate(message, index, ptr, dim_size)
agg_m_ij = super().aggregate(m_ij, index, ptr, dim_size)
return agg_message, agg_m_ij
def update(self, aggregated_inputs, x, pos, edge_index):
"""
Update the node features and the node coordinates with the received
messages.
        :param tuple(torch.Tensor) aggregated_inputs: The aggregated messages.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
        :param pos: The Euclidean coordinates of the nodes.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge indices.
:return: The updated node features and node positions.
:rtype: tuple(torch.Tensor, torch.Tensor)
"""
        # aggregated_inputs is the tuple (agg_message, agg_m_ij)
agg_message, agg_m_ij = aggregated_inputs
# Update node features with aggregated m_ij
x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))
# Degree for normalization of position updates
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
pos = pos + agg_message / c
return x, pos
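
A minimal sketch of the forward pass and of the E(n)-equivariance property the block is designed for (not part of the diff; assumes the class defined above is in scope):

import torch

# Toy graph: 5 nodes with 8 features, coordinates in R^3, no edge attributes.
x = torch.randn(5, 8)
pos = torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

block = EnEquivariantNetworkBlock(node_feature_dim=8, edge_feature_dim=0, pos_dim=3)
x_new, pos_new = block(x, pos, edge_index)

# Applying an orthogonal transformation to the input coordinates should
# transform the output coordinates the same way, while leaving the output
# node features unchanged (up to floating-point error).
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix
x_rot, pos_rot = block(x, pos @ Q.T, edge_index)
print(torch.allclose(pos_rot, pos_new @ Q.T, atol=1e-4))  # expected: True
print(torch.allclose(x_rot, x_new, atol=1e-4))            # expected: True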

File: interaction_network_block.py

@@ -0,0 +1,149 @@
"""Module for the Interaction Network block."""
import torch
from torch_geometric.nn import MessagePassing
from ....utils import check_positive_integer
from ....model import FeedForward
class InteractionNetworkBlock(MessagePassing):
"""
Implementation of the Interaction Network block.
This block is used to perform message-passing between nodes and edges in a
graph neural network, following the scheme proposed by Battaglia et al. in
2016. It serves as an inner block in a larger graph neural network
architecture.
    The message between two nodes connected by an edge is computed by applying
    a multi-layer perceptron (MLP) to the concatenation of the sender and
    recipient node features and, if provided, the edge features. Messages are
    then aggregated using an aggregation scheme (e.g., sum, mean, min, max, or
    product).
The update step is performed by applying another MLP to the concatenation of
the incoming messages and the node features.
.. seealso::
**Original reference**: Battaglia, P. W., et al. (2016).
*Interaction Networks for Learning about Objects, Relations and
Physics*.
In Advances in Neural Information Processing Systems (NeurIPS 2016).
DOI: `<https://doi.org/10.48550/arXiv.1612.00222>`_.
"""
def __init__(
self,
node_feature_dim,
edge_feature_dim=0,
hidden_dim=64,
n_message_layers=2,
n_update_layers=2,
activation=torch.nn.SiLU,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`InteractionNetworkBlock` class.
:param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features.
If edge_attr is not provided, it is assumed to be 0.
Default is 0.
:param int hidden_dim: The dimension of the hidden features.
Default is 64.
:param int n_message_layers: The number of layers in the message
network. Default is 2.
:param int n_update_layers: The number of layers in the update network.
Default is 2.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.SiLU`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If `node_feature_dim` is not a positive integer.
:raises AssertionError: If `hidden_dim` is not a positive integer.
:raises AssertionError: If `n_message_layers` is not a positive integer.
:raises AssertionError: If `n_update_layers` is not a positive integer.
:raises AssertionError: If `edge_feature_dim` is not a non-negative
integer.
"""
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
# Check values
check_positive_integer(node_feature_dim, strict=True)
check_positive_integer(hidden_dim, strict=True)
check_positive_integer(n_message_layers, strict=True)
check_positive_integer(n_update_layers, strict=True)
check_positive_integer(edge_feature_dim, strict=False)
# Message network
self.message_net = FeedForward(
input_dimensions=2 * node_feature_dim + edge_feature_dim,
output_dimensions=hidden_dim,
inner_size=hidden_dim,
n_layers=n_message_layers,
func=activation,
)
# Update network
self.update_net = FeedForward(
input_dimensions=node_feature_dim + hidden_dim,
output_dimensions=node_feature_dim,
inner_size=hidden_dim,
n_layers=n_update_layers,
func=activation,
)
def forward(self, x, edge_index, edge_attr=None):
"""
Forward pass of the block, triggering the message-passing routine.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
        :param torch.Tensor edge_index: The edge indices.
:param edge_attr: The edge attributes. Default is None.
:type edge_attr: torch.Tensor | LabelTensor
:return: The updated node features.
:rtype: torch.Tensor
"""
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
def message(self, x_i, x_j, edge_attr):
"""
Compute the message to be passed between nodes and edges.
:param x_i: The node features of the recipient nodes.
:type x_i: torch.Tensor | LabelTensor
        :param x_j: The node features of the sender nodes.
        :type x_j: torch.Tensor | LabelTensor
        :param edge_attr: The edge attributes. Default is None.
        :type edge_attr: torch.Tensor | LabelTensor
:return: The message to be passed.
:rtype: torch.Tensor
"""
if edge_attr is None:
input_ = torch.cat((x_i, x_j), dim=-1)
else:
input_ = torch.cat((x_i, x_j, edge_attr), dim=-1)
return self.message_net(input_)
def update(self, message, x):
"""
Update the node features with the received messages.
:param torch.Tensor message: The message to be passed.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
:return: The updated node features.
:rtype: torch.Tensor
"""
return self.update_net(torch.cat((x, message), dim=-1))
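
A minimal usage sketch (not part of the diff; assumes the class defined above is in scope), showing both call signatures:

import torch

# Toy graph: 6 nodes with 16 features, 4 directed edges.
x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 2, 4], [1, 2, 3, 5]])

# Without edge attributes (edge_feature_dim defaults to 0) ...
block = InteractionNetworkBlock(node_feature_dim=16)
out = block(x, edge_index)
print(out.shape)  # torch.Size([6, 16])

# ... and with 4-dimensional edge attributes.
block = InteractionNetworkBlock(node_feature_dim=16, edge_feature_dim=4)
out = block(x, edge_index, edge_attr=torch.randn(4, 4))
print(out.shape)  # torch.Size([6, 16])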

File: radial_field_network_block.py

@@ -0,0 +1,126 @@
"""Module for the Radial Field Network block."""
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops
from ....utils import check_positive_integer
from ....model import FeedForward
class RadialFieldNetworkBlock(MessagePassing):
"""
Implementation of the Radial Field Network block.
This block is used to perform message-passing between nodes and edges in a
graph neural network, following the scheme proposed by Köhler et al. in
2020. It serves as an inner block in a larger graph neural network
architecture.
    The message between two nodes connected by an edge is the difference
    between the recipient and sender node features, rescaled by a learned
    function (an MLP) of its norm, i.e. of the radial distance between the
    two nodes. Messages are then aggregated using an aggregation scheme
    (e.g., sum, mean, min, max, or product).
The update step is performed by a simple addition of the incoming messages
to the node features.
.. seealso::
        **Original reference**: Köhler, J., Klein, L., Noé, F. (2020).
*Equivariant Flows: Exact Likelihood Generative Learning for Symmetric
Densities*.
In International Conference on Machine Learning.
DOI: `<https://doi.org/10.48550/arXiv.2006.02425>`_.
"""
def __init__(
self,
node_feature_dim,
hidden_dim=64,
n_layers=2,
activation=torch.nn.Tanh,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`RadialFieldNetworkBlock` class.
:param int node_feature_dim: The dimension of the node features.
:param int hidden_dim: The dimension of the hidden features.
Default is 64.
:param int n_layers: The number of layers in the network. Default is 2.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.Tanh`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If `node_feature_dim` is not a positive integer.
:raises AssertionError: If `hidden_dim` is not a positive integer.
:raises AssertionError: If `n_layers` is not a positive integer.
"""
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
# Check values
check_positive_integer(node_feature_dim, strict=True)
check_positive_integer(hidden_dim, strict=True)
check_positive_integer(n_layers, strict=True)
# Layer for processing node features
self.radial_net = FeedForward(
input_dimensions=1,
output_dimensions=1,
inner_size=hidden_dim,
n_layers=n_layers,
func=activation,
)
def forward(self, x, edge_index):
"""
Forward pass of the block, triggering the message-passing routine.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge indices.
:return: The updated node features.
:rtype: torch.Tensor
"""
edge_index, _ = remove_self_loops(edge_index)
return self.propagate(edge_index=edge_index, x=x)
def message(self, x_i, x_j):
"""
Compute the message to be passed between nodes and edges.
:param x_i: The node features of the recipient nodes.
:type x_i: torch.Tensor | LabelTensor
:param x_j: The node features of the sender nodes.
:type x_j: torch.Tensor | LabelTensor
:return: The message to be passed.
:rtype: torch.Tensor
"""
        # Rescale the feature difference by a learned function of its norm
        r = x_i - x_j
        return self.radial_net(torch.norm(r, dim=-1, keepdim=True)) * r
def update(self, message, x):
"""
Update the node features with the received messages.
:param torch.Tensor message: The message to be passed.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
:return: The updated node features.
:rtype: torch.Tensor
"""
return x + message
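
A minimal usage sketch (not part of the diff; assumes the class defined above is in scope). Since each message is the difference vector x_i - x_j rescaled by a learned function of its norm, the node features typically play the role of coordinates:

import torch

# Toy graph: 5 nodes with 3-dimensional features; the self-loop on node 0
# is removed internally by forward() before propagation.
x = torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 0], [1, 2, 3, 4, 0]])

block = RadialFieldNetworkBlock(node_feature_dim=3)
out = block(x, edge_index)
print(out.shape)  # torch.Size([5, 3])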