Co-authored-by: GiovanniCanali <giovanni.canali98@yahoo.it>
This commit is contained in:
avisquid
2025-10-03 14:37:56 -04:00
committed by GitHub
parent b5e4d13663
commit 2108c76d14
11 changed files with 885 additions and 39 deletions

View File

@@ -105,6 +105,7 @@ Models
GraphNeuralOperator <model/graph_neural_operator.rst> GraphNeuralOperator <model/graph_neural_operator.rst>
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst> GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
PirateNet <model/pirate_network.rst> PirateNet <model/pirate_network.rst>
EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>
Blocks Blocks
------------- -------------
@@ -134,6 +135,7 @@ Message Passing
E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst> E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst>
Interaction Network Block <model/block/message_passing/interaction_network_block.rst> Interaction Network Block <model/block/message_passing/interaction_network_block.rst>
Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst> Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst>
EquivariantGraphNeuralOperatorBlock <model/block/message_passing/equivariant_graph_neural_operator_block.rst>
Reduction and Embeddings Reduction and Embeddings

View File

@@ -0,0 +1,7 @@
EquivariantGraphNeuralOperatorBlock
=====================================
.. currentmodule:: pina.model.block.message_passing.equivariant_graph_neural_operator_block
.. autoclass:: EquivariantGraphNeuralOperatorBlock
:members:
:show-inheritance:

View File

@@ -0,0 +1,7 @@
EquivariantGraphNeuralOperator
=================================
.. currentmodule:: pina.model.equivariant_graph_neural_operator
.. autoclass:: EquivariantGraphNeuralOperator
:members:
:show-inheritance:

View File

@@ -14,6 +14,7 @@ __all__ = [
"Spline", "Spline",
"GraphNeuralOperator", "GraphNeuralOperator",
"PirateNet", "PirateNet",
"EquivariantGraphNeuralOperator",
] ]
from .feed_forward import FeedForward, ResidualFeedForward from .feed_forward import FeedForward, ResidualFeedForward
@@ -26,3 +27,4 @@ from .low_rank_neural_operator import LowRankNeuralOperator
from .spline import Spline from .spline import Spline
from .graph_neural_operator import GraphNeuralOperator from .graph_neural_operator import GraphNeuralOperator
from .pirate_network import PirateNet from .pirate_network import PirateNet
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator

View File

@@ -5,9 +5,13 @@ __all__ = [
"DeepTensorNetworkBlock", "DeepTensorNetworkBlock",
"EnEquivariantNetworkBlock", "EnEquivariantNetworkBlock",
"RadialFieldNetworkBlock", "RadialFieldNetworkBlock",
"EquivariantGraphNeuralOperatorBlock",
] ]
from .interaction_network_block import InteractionNetworkBlock from .interaction_network_block import InteractionNetworkBlock
from .deep_tensor_network_block import DeepTensorNetworkBlock from .deep_tensor_network_block import DeepTensorNetworkBlock
from .en_equivariant_network_block import EnEquivariantNetworkBlock from .en_equivariant_network_block import EnEquivariantNetworkBlock
from .radial_field_network_block import RadialFieldNetworkBlock from .radial_field_network_block import RadialFieldNetworkBlock
from .equivariant_graph_neural_operator_block import (
EquivariantGraphNeuralOperatorBlock,
)

View File

@@ -3,7 +3,7 @@
import torch import torch
from torch_geometric.nn import MessagePassing from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree from torch_geometric.utils import degree
from ....utils import check_positive_integer from ....utils import check_positive_integer, check_consistency
from ....model import FeedForward from ....model import FeedForward
@@ -27,6 +27,12 @@ class EnEquivariantNetworkBlock(MessagePassing):
positions are updated by adding the incoming messages divided by the positions are updated by adding the incoming messages divided by the
degree of the recipient node. degree of the recipient node.
When velocity features are used, node velocities are passed through a small
MLP to compute updates, which are then combined with the aggregated position
messages. The node positions are updated both by the normalized position
messages and by the updated velocities, ensuring equivariance while
incorporating dynamic information.
.. seealso:: .. seealso::
**Original reference** Satorras, V. G., Hoogeboom, E., Welling, M. **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
@@ -40,6 +46,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
node_feature_dim, node_feature_dim,
edge_feature_dim, edge_feature_dim,
pos_dim, pos_dim,
use_velocity=False,
hidden_dim=64, hidden_dim=64,
n_message_layers=2, n_message_layers=2,
n_update_layers=2, n_update_layers=2,
@@ -54,6 +61,8 @@ class EnEquivariantNetworkBlock(MessagePassing):
:param int node_feature_dim: The dimension of the node features. :param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features. :param int edge_feature_dim: The dimension of the edge features.
:param int pos_dim: The dimension of the position features. :param int pos_dim: The dimension of the position features.
:param bool use_velocity: Whether to use velocity features in the
message passing. Default is False.
:param int hidden_dim: The dimension of the hidden features. :param int hidden_dim: The dimension of the hidden features.
Default is 64. Default is 64.
:param int n_message_layers: The number of layers in the message :param int n_message_layers: The number of layers in the message
@@ -80,6 +89,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
:raises AssertionError: If `hidden_dim` is not a positive integer. :raises AssertionError: If `hidden_dim` is not a positive integer.
:raises AssertionError: If `n_message_layers` is not a positive integer. :raises AssertionError: If `n_message_layers` is not a positive integer.
:raises AssertionError: If `n_update_layers` is not a positive integer. :raises AssertionError: If `n_update_layers` is not a positive integer.
:raises AssertionError: If `use_velocity` is not a boolean.
""" """
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
@@ -90,6 +100,10 @@ class EnEquivariantNetworkBlock(MessagePassing):
check_positive_integer(hidden_dim, strict=True) check_positive_integer(hidden_dim, strict=True)
check_positive_integer(n_message_layers, strict=True) check_positive_integer(n_message_layers, strict=True)
check_positive_integer(n_update_layers, strict=True) check_positive_integer(n_update_layers, strict=True)
check_consistency(use_velocity, bool)
# Initialization
self.use_velocity = use_velocity
# Layer for computing the message # Layer for computing the message
self.message_net = FeedForward( self.message_net = FeedForward(
@@ -119,7 +133,17 @@ class EnEquivariantNetworkBlock(MessagePassing):
func=activation, func=activation,
) )
def forward(self, x, pos, edge_index, edge_attr=None): # If velocity is used, instantiate layer for velocity updates
if self.use_velocity:
self.update_vel_net = FeedForward(
input_dimensions=node_feature_dim,
output_dimensions=1,
inner_size=hidden_dim,
n_layers=n_update_layers,
func=activation,
)
def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
""" """
Forward pass of the block, triggering the message-passing routine. Forward pass of the block, triggering the message-passing routine.
@@ -130,11 +154,19 @@ class EnEquivariantNetworkBlock(MessagePassing):
:param torch.Tensor edge_index: The edge indices. :param torch.Tensor edge_index: The edge indices.
:param edge_attr: The edge attributes. Default is None. :param edge_attr: The edge attributes. Default is None.
:type edge_attr: torch.Tensor | LabelTensor :type edge_attr: torch.Tensor | LabelTensor
:param vel: The velocity of the nodes. Default is None.
:type vel: torch.Tensor | LabelTensor
:return: The updated node features and node positions. :return: The updated node features and node positions.
:rtype: tuple(torch.Tensor, torch.Tensor) :rtype: tuple(torch.Tensor, torch.Tensor)
:raises: ValueError: If ``use_velocity`` is True and ``vel`` is None.
""" """
if self.use_velocity and vel is None:
raise ValueError(
"Velocity features are enabled, but no velocity is passed."
)
return self.propagate( return self.propagate(
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
) )
def message(self, x_i, x_j, pos_i, pos_j, edge_attr): def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
@@ -202,10 +234,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
return agg_message, agg_m_ij return agg_message, agg_m_ij
def update(self, aggregated_inputs, x, pos, edge_index): def update(self, aggregated_inputs, x, pos, edge_index, vel):
""" """
Update the node features and the node coordinates with the received Update node features, positions, and optionally velocities.
messages.
:param tuple(torch.Tensor) aggregated_inputs: The messages to be passed. :param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
:param x: The node features. :param x: The node features.
@@ -213,17 +244,26 @@ class EnEquivariantNetworkBlock(MessagePassing):
:param pos: The euclidean coordinates of the nodes. :param pos: The euclidean coordinates of the nodes.
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge indices. :param torch.Tensor edge_index: The edge indices.
:param vel: The velocity of the nodes.
:type vel: torch.Tensor | LabelTensor
:return: The updated node features and node positions. :return: The updated node features and node positions.
:rtype: tuple(torch.Tensor, torch.Tensor) :rtype: tuple(torch.Tensor, torch.Tensor) |
tuple(torch.Tensor, torch.Tensor, torch.Tensor)
""" """
# aggregated_inputs is tuple (agg_message, agg_m_ij) # aggregated_inputs is tuple (agg_message, agg_m_ij)
agg_message, agg_m_ij = aggregated_inputs agg_message, agg_m_ij = aggregated_inputs
# Degree for normalization of position updates
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
# If velocity is used, update it and use it to update positions
if self.use_velocity:
vel = self.update_vel_net(x) * vel
# Update node features with aggregated m_ij # Update node features with aggregated m_ij
x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1)) x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))
# Degree for normalization of position updates # Update positions with aggregated messages m_ij and velocities
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1) pos = pos + agg_message / c + (vel if self.use_velocity else 0)
pos = pos + agg_message / c
return x, pos return (x, pos, vel) if self.use_velocity else (x, pos)

View File

@@ -0,0 +1,188 @@
"""Module for the Equivariant Graph Neural Operator block."""
import torch
from ....utils import check_positive_integer
from .en_equivariant_network_block import EnEquivariantNetworkBlock
class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
"""
A single block of the Equivariant Graph Neural Operator (EGNO).
This block combines a temporal convolution with an equivariant graph neural
network (EGNN) layer. It preserves equivariance while modeling complex
interactions between nodes in a graph over time.
.. seealso::
**Original reference**
Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli,
K., Leskovec, J., Ermon, S., Anandkumar, A. (2024).
*Equivariant Graph Neural Operator for Modeling 3D Dynamics*
DOI: `arXiv preprint arXiv:2401.11037.
<https://arxiv.org/abs/2401.11037>`_
"""
def __init__(
self,
node_feature_dim,
edge_feature_dim,
pos_dim,
modes,
hidden_dim=64,
n_message_layers=2,
n_update_layers=2,
activation=torch.nn.SiLU,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`EquivariantGraphNeuralOperatorBlock`
class.
:param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features.
:param int pos_dim: The dimension of the position features.
:param int modes: The number of Fourier modes to use in the temporal
convolution.
:param int hidden_dim: The dimension of the hidden features.
Default is 64.
:param int n_message_layers: The number of layers in the message
network. Default is 2.
:param int n_update_layers: The number of layers in the update network.
Default is 2.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.SiLU`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If ``modes`` is not a positive integer.
"""
super().__init__()
# Check consistency
check_positive_integer(modes, strict=True)
# Initialization
self.modes = modes
# Temporal convolution weights - real and imaginary parts
self.weight_scalar_r = torch.nn.Parameter(
torch.rand(node_feature_dim, node_feature_dim, modes)
)
self.weight_scalar_i = torch.nn.Parameter(
torch.rand(node_feature_dim, node_feature_dim, modes)
)
self.weight_vector_r = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)
self.weight_vector_i = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)
# EGNN block
self.egnn = EnEquivariantNetworkBlock(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim,
use_velocity=True,
hidden_dim=hidden_dim,
n_message_layers=n_message_layers,
n_update_layers=n_update_layers,
activation=activation,
aggr=aggr,
node_dim=node_dim,
flow=flow,
)
def forward(self, x, pos, vel, edge_index, edge_attr=None):
"""
Forward pass of the Equivariant Graph Neural Operator block.
:param x: The node feature tensor of shape
``[time_steps, num_nodes, node_feature_dim]``.
:type x: torch.Tensor | LabelTensor
:param pos: The node position tensor (Euclidean coordinates) of shape
``[time_steps, num_nodes, pos_dim]``.
:type pos: torch.Tensor | LabelTensor
:param vel: The node velocity tensor of shape
``[time_steps, num_nodes, pos_dim]``.
:type vel: torch.Tensor | LabelTensor
:param edge_index: The edge connectivity of shape ``[2, num_edges]``.
:type edge_index: torch.Tensor
:param edge_attr: The edge feature tensor of shape
``[time_steps, num_edges, edge_feature_dim]``. Default is None.
:type edge_attr: torch.Tensor | LabelTensor, optional
:return: The updated node features, positions, and velocities, each with
the same shape as the inputs.
:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
"""
# Prepare features
center = pos.mean(dim=1, keepdim=True)
vector = torch.stack((pos - center, vel), dim=-1)
# Compute temporal convolution
x = x + self._convolution(
x, "mni, iom -> mno", self.weight_scalar_r, self.weight_scalar_i
)
vector = vector + self._convolution(
vector,
"mndi, iom -> mndo",
self.weight_vector_r,
self.weight_vector_i,
)
# Split position and velocity
pos, vel = vector.unbind(dim=-1)
pos = pos + center
# Reshape to (time * nodes, feature) for egnn
x = x.reshape(-1, x.shape[-1])
pos = pos.reshape(-1, pos.shape[-1])
vel = vel.reshape(-1, vel.shape[-1])
if edge_attr is not None:
edge_attr = edge_attr.reshape(-1, edge_attr.shape[-1])
x, pos, vel = self.egnn(
x=x,
pos=pos,
edge_index=edge_index,
edge_attr=edge_attr,
vel=vel,
)
# Reshape back to (time, nodes, feature)
x = x.reshape(center.shape[0], -1, x.shape[-1])
pos = pos.reshape(center.shape[0], -1, pos.shape[-1])
vel = vel.reshape(center.shape[0], -1, vel.shape[-1])
return x, pos, vel
def _convolution(self, x, einsum_idx, real, img):
"""
Compute the temporal convolution.
:param torch.Tensor x: The input features.
:param str einsum_idx: The indices for the einsum operation.
:param torch.Tensor real: The real part of the convolution weights.
:param torch.Tensor img: The imaginary part of the convolution weights.
:return: The convolved features.
:rtype: torch.Tensor
"""
# Number of modes to use
modes = min(self.modes, (x.shape[0] // 2) + 1)
# Build complex weights
weights = torch.complex(real[..., :modes], img[..., :modes])
# Convolution in Fourier space
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
out = torch.einsum(einsum_idx, fourier, weights)
return torch.fft.irfftn(out, s=x.shape[0], dim=0)

View File

@@ -0,0 +1,219 @@
"""Module for the Equivariant Graph Neural Operator model."""
import torch
from ..utils import check_positive_integer
from .block.message_passing import EquivariantGraphNeuralOperatorBlock
class EquivariantGraphNeuralOperator(torch.nn.Module):
"""
Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.
EGNO is a graph-based neural operator that preserves equivariance with
respect to 3D transformations while modeling temporal and spatial
interactions between nodes. It combines:
1. Temporal convolution in the Fourier domain to capture long-range
temporal dependencies efficiently.
2. Equivariant Graph Neural Network (EGNN) layers to model interactions
between nodes while respecting geometric symmetries.
This design allows EGNO to learn complex spatiotemporal dynamics of
physical systems, molecules, or particles while enforcing physically
meaningful constraints.
.. seealso::
**Original reference**
Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli,
K., Leskovec, J., Ermon, S., Anandkumar, A. (2024).
*Equivariant Graph Neural Operator for Modeling 3D Dynamics*
DOI: `arXiv preprint arXiv:2401.11037.
<https://arxiv.org/abs/2401.11037>`_
"""
def __init__(
self,
n_egno_layers,
node_feature_dim,
edge_feature_dim,
pos_dim,
modes,
time_steps=2,
hidden_dim=64,
time_emb_dim=16,
max_time_idx=10000,
n_message_layers=2,
n_update_layers=2,
activation=torch.nn.SiLU,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`EquivariantGraphNeuralOperator` class.
:param int n_egno_layers: The number of EGNO layers.
:param int node_feature_dim: The dimension of the node features in each
EGNO layer.
:param int edge_feature_dim: The dimension of the edge features in each
EGNO layer.
:param int pos_dim: The dimension of the position features in each
EGNO layer.
:param int modes: The number of Fourier modes to use in the temporal
convolution.
:param int time_steps: The number of time steps to consider in the
temporal convolution. Default is 2.
:param int hidden_dim: The dimension of the hidden features in each EGNO
layer. Default is 64.
:param int time_emb_dim: The dimension of the sinusoidal time
embeddings. Default is 16.
:param int max_time_idx: The maximum time index for the sinusoidal
embeddings. Default is 10000.
:param int n_message_layers: The number of layers in the message
network of each EGNO layer. Default is 2.
:param int n_update_layers: The number of layers in the update network
of each EGNO layer. Default is 2.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.SiLU`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If ``n_egno_layers`` is not a positive integer.
:raises AssertionError: If ``time_emb_dim`` is not a positive integer.
:raises AssertionError: If ``max_time_idx`` is not a positive integer.
:raises AssertionError: If ``time_steps`` is not a positive integer.
"""
super().__init__()
# Check consistency
check_positive_integer(n_egno_layers, strict=True)
check_positive_integer(time_emb_dim, strict=True)
check_positive_integer(max_time_idx, strict=True)
check_positive_integer(time_steps, strict=True)
# Initialize parameters
self.time_steps = time_steps
self.time_emb_dim = time_emb_dim
self.max_time_idx = max_time_idx
# Initialize EGNO layers
self.egno_layers = torch.nn.ModuleList()
for _ in range(n_egno_layers):
self.egno_layers.append(
EquivariantGraphNeuralOperatorBlock(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim,
modes=modes,
hidden_dim=hidden_dim,
n_message_layers=n_message_layers,
n_update_layers=n_update_layers,
activation=activation,
aggr=aggr,
node_dim=node_dim,
flow=flow,
)
)
# Linear layer to adjust the scalar feature dimension
self.linear = torch.nn.Linear(
node_feature_dim + time_emb_dim, node_feature_dim
)
def forward(self, graph):
"""
Forward pass of the :class:`EquivariantGraphNeuralOperator` class.
:param graph: The input graph object with the following attributes:
- 'x': Node features, shape ``[num_nodes, node_feature_dim]``.
- 'pos': Node positions, shape ``[num_nodes, pos_dim]``.
- 'vel': Node velocities, shape ``[num_nodes, pos_dim]``.
- 'edge_index': Graph connectivity, shape ``[2, num_edges]``.
- 'edge_attr': Edge attrs, shape ``[num_edges, edge_feature_dim]``.
:type graph: Data | Graph
:return: The output graph object with updated node features,
positions, and velocities. The output graph adds to 'x', 'pos',
'vel', and 'edge_attr' the time dimension, resulting in shapes:
- 'x': ``[time_steps, num_nodes, node_feature_dim]``
- 'pos': ``[time_steps, num_nodes, pos_dim]``
- 'vel': ``[time_steps, num_nodes, pos_dim]``
- 'edge_attr': ``[time_steps, num_edges, edge_feature_dim]``
:rtype: Data | Graph
:raises ValueError: If the input graph does not have a 'vel' attribute.
"""
# Check that the graph has the required attributes
if "vel" not in graph:
raise ValueError("The input graph must have a 'vel' attribute.")
# Compute the temporal embedding
emb = self._embedding(torch.arange(self.time_steps)).to(graph.x.device)
emb = emb.unsqueeze(1).repeat(1, graph.x.shape[0], 1)
# Expand dimensions
x = graph.x.unsqueeze(0).repeat(self.time_steps, 1, 1)
x = self.linear(torch.cat((x, emb), dim=-1))
pos = graph.pos.unsqueeze(0).repeat(self.time_steps, 1, 1)
vel = graph.vel.unsqueeze(0).repeat(self.time_steps, 1, 1)
# Manage edge index
offset = torch.arange(self.time_steps).reshape(-1, 1)
offset = offset.to(graph.x.device) * graph.x.shape[0]
src = graph.edge_index[0].unsqueeze(0) + offset
dst = graph.edge_index[1].unsqueeze(0) + offset
edge_index = torch.stack([src, dst], dim=0).reshape(2, -1)
# Manage edge attributes
if graph.edge_attr is not None:
edge_attr = graph.edge_attr.unsqueeze(0)
edge_attr = edge_attr.repeat(self.time_steps, 1, 1)
else:
edge_attr = None
# Iteratively apply EGNO layers
for layer in self.egno_layers:
x, pos, vel = layer(
x=x,
pos=pos,
vel=vel,
edge_index=edge_index,
edge_attr=edge_attr,
)
# Build new graph
new_graph = graph.clone()
new_graph.x, new_graph.pos, new_graph.vel = x, pos, vel
if edge_attr is not None:
new_graph.edge_attr = edge_attr
return new_graph
def _embedding(self, time):
"""
Generate sinusoidal temporal embeddings.
:param torch.Tensor time: The time instances.
:return: The sinusoidal embedding tensor.
:rtype: torch.Tensor
"""
# Compute the sinusoidal embeddings
half_dim = self.time_emb_dim // 2
logs = torch.log(torch.as_tensor(self.max_time_idx)) / (half_dim - 1)
freqs = torch.exp(-torch.arange(half_dim) * logs)
args = torch.as_tensor(time)[:, None] * freqs[None, :]
emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
# Apply padding if the embedding dimension is odd
if self.time_emb_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1), mode="constant")
return emb

View File

@@ -5,19 +5,22 @@ from pina.model.block.message_passing import EnEquivariantNetworkBlock
# Data for testing # Data for testing
x = torch.rand(10, 4) x = torch.rand(10, 4)
pos = torch.rand(10, 3) pos = torch.rand(10, 3)
edge_index = torch.randint(0, 10, (2, 20)) velocity = torch.rand(10, 3)
edge_attr = torch.randn(20, 2) edge_idx = torch.randint(0, 10, (2, 20))
edge_attributes = torch.randn(20, 2)
@pytest.mark.parametrize("node_feature_dim", [1, 3]) @pytest.mark.parametrize("node_feature_dim", [1, 3])
@pytest.mark.parametrize("edge_feature_dim", [0, 2]) @pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("pos_dim", [2, 3]) @pytest.mark.parametrize("pos_dim", [2, 3])
def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): @pytest.mark.parametrize("use_velocity", [True, False])
def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, use_velocity):
EnEquivariantNetworkBlock( EnEquivariantNetworkBlock(
node_feature_dim=node_feature_dim, node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim, edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim, pos_dim=pos_dim,
use_velocity=use_velocity,
hidden_dim=64, hidden_dim=64,
n_message_layers=2, n_message_layers=2,
n_update_layers=2, n_update_layers=2,
@@ -29,6 +32,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
node_feature_dim=-1, node_feature_dim=-1,
edge_feature_dim=edge_feature_dim, edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim, pos_dim=pos_dim,
use_velocity=use_velocity,
) )
# Should fail if edge_feature_dim is negative # Should fail if edge_feature_dim is negative
@@ -37,6 +41,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
node_feature_dim=node_feature_dim, node_feature_dim=node_feature_dim,
edge_feature_dim=-1, edge_feature_dim=-1,
pos_dim=pos_dim, pos_dim=pos_dim,
use_velocity=use_velocity,
) )
# Should fail if pos_dim is negative # Should fail if pos_dim is negative
@@ -45,6 +50,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
node_feature_dim=node_feature_dim, node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim, edge_feature_dim=edge_feature_dim,
pos_dim=-1, pos_dim=-1,
use_velocity=use_velocity,
) )
# Should fail if hidden_dim is negative # Should fail if hidden_dim is negative
@@ -54,6 +60,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
edge_feature_dim=edge_feature_dim, edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim, pos_dim=pos_dim,
hidden_dim=-1, hidden_dim=-1,
use_velocity=use_velocity,
) )
# Should fail if n_message_layers is negative # Should fail if n_message_layers is negative
@@ -63,6 +70,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
edge_feature_dim=edge_feature_dim, edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim, pos_dim=pos_dim,
n_message_layers=-1, n_message_layers=-1,
use_velocity=use_velocity,
) )
# Should fail if n_update_layers is negative # Should fail if n_update_layers is negative
@@ -72,11 +80,22 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
edge_feature_dim=edge_feature_dim, edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim, pos_dim=pos_dim,
n_update_layers=-1, n_update_layers=-1,
use_velocity=use_velocity,
)
# Should fail if use_velocity is not boolean
with pytest.raises(ValueError):
EnEquivariantNetworkBlock(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim,
use_velocity="False",
) )
@pytest.mark.parametrize("edge_feature_dim", [0, 2]) @pytest.mark.parametrize("edge_feature_dim", [0, 2])
def test_forward(edge_feature_dim): @pytest.mark.parametrize("use_velocity", [True, False])
def test_forward(edge_feature_dim, use_velocity):
model = EnEquivariantNetworkBlock( model = EnEquivariantNetworkBlock(
node_feature_dim=x.shape[1], node_feature_dim=x.shape[1],
@@ -85,21 +104,26 @@ def test_forward(edge_feature_dim):
hidden_dim=64, hidden_dim=64,
n_message_layers=2, n_message_layers=2,
n_update_layers=2, n_update_layers=2,
use_velocity=use_velocity,
) )
if edge_feature_dim == 0: # Manage inputs
output_ = model(edge_index=edge_index, x=x, pos=pos) vel = velocity if use_velocity else None
else: edge_attr = edge_attributes if edge_feature_dim > 0 else None
output_ = model(
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
)
# Checks on output shapes
output_ = model(
x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel
)
assert output_[0].shape == x.shape assert output_[0].shape == x.shape
assert output_[1].shape == pos.shape assert output_[1].shape == pos.shape
if vel is not None:
assert output_[2].shape == vel.shape
@pytest.mark.parametrize("edge_feature_dim", [0, 2]) @pytest.mark.parametrize("edge_feature_dim", [0, 2])
def test_backward(edge_feature_dim): @pytest.mark.parametrize("use_velocity", [True, False])
def test_backward(edge_feature_dim, use_velocity):
model = EnEquivariantNetworkBlock( model = EnEquivariantNetworkBlock(
node_feature_dim=x.shape[1], node_feature_dim=x.shape[1],
@@ -108,35 +132,45 @@ def test_backward(edge_feature_dim):
hidden_dim=64, hidden_dim=64,
n_message_layers=2, n_message_layers=2,
n_update_layers=2, n_update_layers=2,
use_velocity=use_velocity,
)
# Manage inputs
vel = velocity.requires_grad_() if use_velocity else None
edge_attr = (
edge_attributes.requires_grad_() if edge_feature_dim > 0 else None
) )
if edge_feature_dim == 0: if edge_feature_dim == 0:
output_ = model( output_ = model(
edge_index=edge_index, edge_index=edge_idx,
x=x.requires_grad_(), x=x.requires_grad_(),
pos=pos.requires_grad_(), pos=pos.requires_grad_(),
vel=vel,
) )
else: else:
output_ = model( output_ = model(
edge_index=edge_index, edge_index=edge_idx,
x=x.requires_grad_(), x=x.requires_grad_(),
pos=pos.requires_grad_(), pos=pos.requires_grad_(),
edge_attr=edge_attr.requires_grad_(), edge_attr=edge_attr,
vel=vel,
) )
loss = torch.mean(output_[0]) # Checks on gradients
loss = sum(torch.mean(output_[i]) for i in range(len(output_)))
loss.backward() loss.backward()
assert x.grad.shape == x.shape assert x.grad.shape == x.shape
assert pos.grad.shape == pos.shape assert pos.grad.shape == pos.shape
if use_velocity:
assert vel.grad.shape == vel.shape
def test_equivariance(): @pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("use_velocity", [True, False])
def test_equivariance(edge_feature_dim, use_velocity):
# Graph to be fully connected and undirected # Random rotation
edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
# Random rotation (det(rotation) should be 1)
rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
if torch.det(rotation) < 0: if torch.det(rotation) < 0:
rotation[:, 0] *= -1 rotation[:, 0] *= -1
@@ -146,20 +180,37 @@ def test_equivariance():
model = EnEquivariantNetworkBlock( model = EnEquivariantNetworkBlock(
node_feature_dim=x.shape[1], node_feature_dim=x.shape[1],
edge_feature_dim=0, edge_feature_dim=edge_feature_dim,
pos_dim=pos.shape[1], pos_dim=pos.shape[1],
hidden_dim=64, hidden_dim=64,
n_message_layers=2, n_message_layers=2,
n_update_layers=2, n_update_layers=2,
use_velocity=use_velocity,
).eval() ).eval()
h1, pos1 = model(edge_index=edge_index, x=x, pos=pos) # Manage inputs
h2, pos2 = model( vel = velocity if use_velocity else None
edge_index=edge_index, x=x, pos=pos @ rotation.T + translation edge_attr = edge_attributes if edge_feature_dim > 0 else None
# Transform inputs (no translation for velocity)
pos_rot = pos @ rotation.T + translation
vel_rot = vel @ rotation.T if use_velocity else vel
# Get model outputs
out1 = model(
x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel
)
out2 = model(
x=x, pos=pos_rot, edge_index=edge_idx, edge_attr=edge_attr, vel=vel_rot
) )
# Transform model output # Unpack outputs
pos1_transformed = (pos1 @ rotation.T) + translation h1, pos1, *other1 = out1
h2, pos2, *other2 = out2
if use_velocity:
vel1, vel2 = other1[0], other2[0]
assert torch.allclose(pos2, pos1_transformed, atol=1e-5) assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
assert torch.allclose(h1, h2, atol=1e-5) assert torch.allclose(h1, h2, atol=1e-5)
if vel is not None:
assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)

View File

@@ -0,0 +1,132 @@
import pytest
import torch
from pina.model.block.message_passing import EquivariantGraphNeuralOperatorBlock
# Data for testing. Shapes: (time, nodes, features)
x = torch.rand(5, 10, 4)
pos = torch.rand(5, 10, 3)
vel = torch.rand(5, 10, 3)
# Edge index and attributes
edge_idx = torch.randint(0, 10, (2, 20))
edge_attributes = torch.randn(20, 2)
@pytest.mark.parametrize("node_feature_dim", [1, 3])
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("pos_dim", [2, 3])
@pytest.mark.parametrize("modes", [1, 5])
def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, modes):
EquivariantGraphNeuralOperatorBlock(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim,
modes=modes,
)
# Should fail if modes is negative
with pytest.raises(AssertionError):
EquivariantGraphNeuralOperatorBlock(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim,
modes=-1,
)
@pytest.mark.parametrize("modes", [1, 5])
def test_forward(modes):
model = EquivariantGraphNeuralOperatorBlock(
node_feature_dim=x.shape[2],
edge_feature_dim=edge_attributes.shape[1],
pos_dim=pos.shape[2],
modes=modes,
)
output_ = model(
x=x,
pos=pos,
vel=vel,
edge_index=edge_idx,
edge_attr=edge_attributes,
)
# Checks on output shapes
assert output_[0].shape == x.shape
assert output_[1].shape == pos.shape
assert output_[2].shape == vel.shape
@pytest.mark.parametrize("modes", [1, 5])
def test_backward(modes):
model = EquivariantGraphNeuralOperatorBlock(
node_feature_dim=x.shape[2],
edge_feature_dim=edge_attributes.shape[1],
pos_dim=pos.shape[2],
modes=modes,
)
output_ = model(
x=x.requires_grad_(),
pos=pos.requires_grad_(),
vel=vel.requires_grad_(),
edge_index=edge_idx,
edge_attr=edge_attributes.requires_grad_(),
)
# Checks on gradients
loss = sum(torch.mean(output_[i]) for i in range(len(output_)))
loss.backward()
assert x.grad.shape == x.shape
assert pos.grad.shape == pos.shape
assert vel.grad.shape == vel.shape
@pytest.mark.parametrize("modes", [1, 5])
def test_equivariance(modes):
# Random rotation
rotation = torch.linalg.qr(torch.rand(pos.shape[2], pos.shape[2])).Q
if torch.det(rotation) < 0:
rotation[:, 0] *= -1
# Random translation
translation = torch.rand(1, pos.shape[2])
model = EquivariantGraphNeuralOperatorBlock(
node_feature_dim=x.shape[2],
edge_feature_dim=edge_attributes.shape[1],
pos_dim=pos.shape[2],
modes=modes,
).eval()
# Transform inputs (no translation for velocity)
pos_rot = pos @ rotation.T + translation
vel_rot = vel @ rotation.T
# Get model outputs
out1 = model(
x=x,
pos=pos,
vel=vel,
edge_index=edge_idx,
edge_attr=edge_attributes,
)
out2 = model(
x=x,
pos=pos_rot,
vel=vel_rot,
edge_index=edge_idx,
edge_attr=edge_attributes,
)
# Unpack outputs
h1, pos1, vel1 = out1
h2, pos2, vel2 = out2
assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)
assert torch.allclose(h1, h2, atol=1e-5)

View File

@@ -0,0 +1,194 @@
import pytest
import torch
import copy
from pina.model import EquivariantGraphNeuralOperator
from pina.graph import Graph
# Utility to create graphs
def make_graph(include_vel=True, use_edge_attr=True):
data = dict(
x=torch.rand(10, 4),
pos=torch.rand(10, 3),
edge_index=torch.randint(0, 10, (2, 20)),
edge_attr=torch.randn(20, 2) if use_edge_attr else None,
)
if include_vel:
data["vel"] = torch.rand(10, 3)
return Graph(**data)
@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 3])
@pytest.mark.parametrize("time_emb_dim", [4, 8])
@pytest.mark.parametrize("max_time_idx", [10, 20])
def test_constructor(n_egno_layers, time_steps, time_emb_dim, max_time_idx):
# Create graph and model
graph = make_graph()
EquivariantGraphNeuralOperator(
n_egno_layers=n_egno_layers,
node_feature_dim=graph.x.shape[1],
edge_feature_dim=graph.edge_attr.shape[1],
pos_dim=graph.pos.shape[1],
modes=5,
time_steps=time_steps,
time_emb_dim=time_emb_dim,
max_time_idx=max_time_idx,
)
# Should fail if n_egno_layers is negative
with pytest.raises(AssertionError):
EquivariantGraphNeuralOperator(
n_egno_layers=-1,
node_feature_dim=graph.x.shape[1],
edge_feature_dim=graph.edge_attr.shape[1],
pos_dim=graph.pos.shape[1],
modes=5,
time_steps=time_steps,
time_emb_dim=time_emb_dim,
max_time_idx=max_time_idx,
)
# Should fail if time_steps is negative
with pytest.raises(AssertionError):
EquivariantGraphNeuralOperator(
n_egno_layers=n_egno_layers,
node_feature_dim=graph.x.shape[1],
edge_feature_dim=graph.edge_attr.shape[1],
pos_dim=graph.pos.shape[1],
modes=5,
time_steps=-1,
time_emb_dim=time_emb_dim,
max_time_idx=max_time_idx,
)
# Should fail if max_time_idx is negative
with pytest.raises(AssertionError):
EquivariantGraphNeuralOperator(
n_egno_layers=n_egno_layers,
node_feature_dim=graph.x.shape[1],
edge_feature_dim=graph.edge_attr.shape[1],
pos_dim=graph.pos.shape[1],
modes=5,
time_steps=time_steps,
time_emb_dim=time_emb_dim,
max_time_idx=-1,
)
# Should fail if time_emb_dim is negative
with pytest.raises(AssertionError):
EquivariantGraphNeuralOperator(
n_egno_layers=n_egno_layers,
node_feature_dim=graph.x.shape[1],
edge_feature_dim=graph.edge_attr.shape[1],
pos_dim=graph.pos.shape[1],
modes=5,
time_steps=time_steps,
time_emb_dim=-1,
max_time_idx=max_time_idx,
)
@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_forward(n_egno_layers, time_steps, modes, use_edge_attr):
# Create graph and model
graph = make_graph(use_edge_attr=use_edge_attr)
model = EquivariantGraphNeuralOperator(
n_egno_layers=n_egno_layers,
node_feature_dim=graph.x.shape[1],
edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
pos_dim=graph.pos.shape[1],
modes=modes,
time_steps=time_steps,
)
# Checks on output shapes
output_ = model(graph)
assert output_.x.shape == (time_steps, *graph.x.shape)
assert output_.pos.shape == (time_steps, *graph.pos.shape)
assert output_.vel.shape == (time_steps, *graph.vel.shape)
# Should fail graph has no vel attribute
with pytest.raises(ValueError):
graph_no_vel = make_graph(include_vel=False)
model(graph_no_vel)
@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_backward(n_egno_layers, time_steps, modes, use_edge_attr):
# Create graph and model
graph = make_graph(use_edge_attr=use_edge_attr)
model = EquivariantGraphNeuralOperator(
n_egno_layers=n_egno_layers,
node_feature_dim=graph.x.shape[1],
edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
pos_dim=graph.pos.shape[1],
modes=modes,
time_steps=time_steps,
)
# Set requires_grad and perform forward pass
graph.x.requires_grad_()
graph.pos.requires_grad_()
graph.vel.requires_grad_()
out = model(graph)
# Checks on gradients
loss = torch.mean(out.x) + torch.mean(out.pos) + torch.mean(out.vel)
loss.backward()
assert graph.x.grad.shape == graph.x.shape
assert graph.pos.grad.shape == graph.pos.shape
assert graph.vel.grad.shape == graph.vel.shape
@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_equivariance(n_egno_layers, time_steps, modes, use_edge_attr):
graph = make_graph(use_edge_attr=use_edge_attr)
model = EquivariantGraphNeuralOperator(
n_egno_layers=n_egno_layers,
node_feature_dim=graph.x.shape[1],
edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
pos_dim=graph.pos.shape[1],
modes=modes,
time_steps=time_steps,
).eval()
# Random rotation
rotation = torch.linalg.qr(
torch.rand(graph.pos.shape[1], graph.pos.shape[1])
).Q
if torch.det(rotation) < 0:
rotation[:, 0] *= -1
# Random translation
translation = torch.rand(1, graph.pos.shape[1])
# Transform graph (no translation for velocity)
graph_rot = copy.deepcopy(graph)
graph_rot.pos = graph.pos @ rotation.T + translation
graph_rot.vel = graph.vel @ rotation.T
# Get model outputs
out1 = model(graph)
out2 = model(graph_rot)
# Unpack outputs
h1, pos1, vel1 = out1.x, out1.pos, out1.vel
h2, pos2, vel2 = out2.x, out2.pos, out2.vel
assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)
assert torch.allclose(h1, h2, atol=1e-5)