From 2108c76d140c11f06e2fe880633d379fa10fdddd Mon Sep 17 00:00:00 2001 From: avisquid <115588530+avisquid@users.noreply.github.com> Date: Fri, 3 Oct 2025 14:37:56 -0400 Subject: [PATCH] add egno (#602) Co-authored-by: GiovanniCanali --- docs/source/_rst/_code.rst | 2 + ...quivariant_graph_neural_operator_block.rst | 7 + .../equivariant_graph_neural_operator.rst | 7 + pina/model/__init__.py | 2 + pina/model/block/message_passing/__init__.py | 4 + .../en_equivariant_network_block.py | 62 ++++- ...equivariant_graph_neural_operator_block.py | 188 +++++++++++++++ .../equivariant_graph_neural_operator.py | 219 ++++++++++++++++++ .../test_equivariant_network_block.py | 107 ++++++--- .../test_equivariant_operator_block.py | 132 +++++++++++ .../test_equivariant_graph_neural_operator.py | 194 ++++++++++++++++ 11 files changed, 885 insertions(+), 39 deletions(-) create mode 100644 docs/source/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst create mode 100644 docs/source/_rst/model/equivariant_graph_neural_operator.rst create mode 100644 pina/model/block/message_passing/equivariant_graph_neural_operator_block.py create mode 100644 pina/model/equivariant_graph_neural_operator.py create mode 100644 tests/test_messagepassing/test_equivariant_operator_block.py create mode 100644 tests/test_model/test_equivariant_graph_neural_operator.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 160eb35..25f0e30 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -105,6 +105,7 @@ Models GraphNeuralOperator GraphNeuralKernel PirateNet + EquivariantGraphNeuralOperator Blocks ------------- @@ -134,6 +135,7 @@ Message Passing E(n) Equivariant Network Block Interaction Network Block Radial Field Network Block + EquivariantGraphNeuralOperatorBlock Reduction and Embeddings diff --git a/docs/source/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst b/docs/source/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst new file mode 100644 index 0000000..8d047f8 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst @@ -0,0 +1,7 @@ +EquivariantGraphNeuralOperatorBlock +===================================== +.. currentmodule:: pina.model.block.message_passing.equivariant_graph_neural_operator_block + +.. autoclass:: EquivariantGraphNeuralOperatorBlock + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/model/equivariant_graph_neural_operator.rst b/docs/source/_rst/model/equivariant_graph_neural_operator.rst new file mode 100644 index 0000000..a11edcc --- /dev/null +++ b/docs/source/_rst/model/equivariant_graph_neural_operator.rst @@ -0,0 +1,7 @@ +EquivariantGraphNeuralOperator +================================= +.. currentmodule:: pina.model.equivariant_graph_neural_operator + +.. autoclass:: EquivariantGraphNeuralOperator + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 5e34048..ee343e5 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -14,6 +14,7 @@ __all__ = [ "Spline", "GraphNeuralOperator", "PirateNet", + "EquivariantGraphNeuralOperator", ] from .feed_forward import FeedForward, ResidualFeedForward @@ -26,3 +27,4 @@ from .low_rank_neural_operator import LowRankNeuralOperator from .spline import Spline from .graph_neural_operator import GraphNeuralOperator from .pirate_network import PirateNet +from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator diff --git a/pina/model/block/message_passing/__init__.py b/pina/model/block/message_passing/__init__.py index 0d43288..202e1fd 100644 --- a/pina/model/block/message_passing/__init__.py +++ b/pina/model/block/message_passing/__init__.py @@ -5,9 +5,13 @@ __all__ = [ "DeepTensorNetworkBlock", "EnEquivariantNetworkBlock", "RadialFieldNetworkBlock", + "EquivariantGraphNeuralOperatorBlock", ] from .interaction_network_block import InteractionNetworkBlock from .deep_tensor_network_block import DeepTensorNetworkBlock from .en_equivariant_network_block import EnEquivariantNetworkBlock from .radial_field_network_block import RadialFieldNetworkBlock +from .equivariant_graph_neural_operator_block import ( + EquivariantGraphNeuralOperatorBlock, +) diff --git a/pina/model/block/message_passing/en_equivariant_network_block.py b/pina/model/block/message_passing/en_equivariant_network_block.py index 904c1c6..b8057b0 100644 --- a/pina/model/block/message_passing/en_equivariant_network_block.py +++ b/pina/model/block/message_passing/en_equivariant_network_block.py @@ -3,7 +3,7 @@ import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import degree -from ....utils import check_positive_integer +from ....utils import check_positive_integer, check_consistency from ....model import FeedForward @@ -27,6 +27,12 @@ class EnEquivariantNetworkBlock(MessagePassing): positions are updated by adding the incoming messages divided by the degree of the recipient node. + When velocity features are used, node velocities are passed through a small + MLP to compute updates, which are then combined with the aggregated position + messages. The node positions are updated both by the normalized position + messages and by the updated velocities, ensuring equivariance while + incorporating dynamic information. + .. seealso:: **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M. @@ -40,6 +46,7 @@ class EnEquivariantNetworkBlock(MessagePassing): node_feature_dim, edge_feature_dim, pos_dim, + use_velocity=False, hidden_dim=64, n_message_layers=2, n_update_layers=2, @@ -54,6 +61,8 @@ class EnEquivariantNetworkBlock(MessagePassing): :param int node_feature_dim: The dimension of the node features. :param int edge_feature_dim: The dimension of the edge features. :param int pos_dim: The dimension of the position features. + :param bool use_velocity: Whether to use velocity features in the + message passing. Default is False. :param int hidden_dim: The dimension of the hidden features. Default is 64. :param int n_message_layers: The number of layers in the message @@ -80,6 +89,7 @@ class EnEquivariantNetworkBlock(MessagePassing): :raises AssertionError: If `hidden_dim` is not a positive integer. :raises AssertionError: If `n_message_layers` is not a positive integer. :raises AssertionError: If `n_update_layers` is not a positive integer. + :raises AssertionError: If `use_velocity` is not a boolean. """ super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) @@ -90,6 +100,10 @@ class EnEquivariantNetworkBlock(MessagePassing): check_positive_integer(hidden_dim, strict=True) check_positive_integer(n_message_layers, strict=True) check_positive_integer(n_update_layers, strict=True) + check_consistency(use_velocity, bool) + + # Initialization + self.use_velocity = use_velocity # Layer for computing the message self.message_net = FeedForward( @@ -119,7 +133,17 @@ class EnEquivariantNetworkBlock(MessagePassing): func=activation, ) - def forward(self, x, pos, edge_index, edge_attr=None): + # If velocity is used, instantiate layer for velocity updates + if self.use_velocity: + self.update_vel_net = FeedForward( + input_dimensions=node_feature_dim, + output_dimensions=1, + inner_size=hidden_dim, + n_layers=n_update_layers, + func=activation, + ) + + def forward(self, x, pos, edge_index, edge_attr=None, vel=None): """ Forward pass of the block, triggering the message-passing routine. @@ -130,11 +154,19 @@ class EnEquivariantNetworkBlock(MessagePassing): :param torch.Tensor edge_index: The edge indices. :param edge_attr: The edge attributes. Default is None. :type edge_attr: torch.Tensor | LabelTensor + :param vel: The velocity of the nodes. Default is None. + :type vel: torch.Tensor | LabelTensor :return: The updated node features and node positions. :rtype: tuple(torch.Tensor, torch.Tensor) + :raises: ValueError: If ``use_velocity`` is True and ``vel`` is None. """ + if self.use_velocity and vel is None: + raise ValueError( + "Velocity features are enabled, but no velocity is passed." + ) + return self.propagate( - edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr + edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel ) def message(self, x_i, x_j, pos_i, pos_j, edge_attr): @@ -202,10 +234,9 @@ class EnEquivariantNetworkBlock(MessagePassing): return agg_message, agg_m_ij - def update(self, aggregated_inputs, x, pos, edge_index): + def update(self, aggregated_inputs, x, pos, edge_index, vel): """ - Update the node features and the node coordinates with the received - messages. + Update node features, positions, and optionally velocities. :param tuple(torch.Tensor) aggregated_inputs: The messages to be passed. :param x: The node features. @@ -213,17 +244,26 @@ class EnEquivariantNetworkBlock(MessagePassing): :param pos: The euclidean coordinates of the nodes. :type pos: torch.Tensor | LabelTensor :param torch.Tensor edge_index: The edge indices. + :param vel: The velocity of the nodes. + :type vel: torch.Tensor | LabelTensor :return: The updated node features and node positions. - :rtype: tuple(torch.Tensor, torch.Tensor) + :rtype: tuple(torch.Tensor, torch.Tensor) | + tuple(torch.Tensor, torch.Tensor, torch.Tensor) """ # aggregated_inputs is tuple (agg_message, agg_m_ij) agg_message, agg_m_ij = aggregated_inputs + # Degree for normalization of position updates + c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1) + + # If velocity is used, update it and use it to update positions + if self.use_velocity: + vel = self.update_vel_net(x) * vel + # Update node features with aggregated m_ij x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1)) - # Degree for normalization of position updates - c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1) - pos = pos + agg_message / c + # Update positions with aggregated messages m_ij and velocities + pos = pos + agg_message / c + (vel if self.use_velocity else 0) - return x, pos + return (x, pos, vel) if self.use_velocity else (x, pos) diff --git a/pina/model/block/message_passing/equivariant_graph_neural_operator_block.py b/pina/model/block/message_passing/equivariant_graph_neural_operator_block.py new file mode 100644 index 0000000..f6c7392 --- /dev/null +++ b/pina/model/block/message_passing/equivariant_graph_neural_operator_block.py @@ -0,0 +1,188 @@ +"""Module for the Equivariant Graph Neural Operator block.""" + +import torch +from ....utils import check_positive_integer +from .en_equivariant_network_block import EnEquivariantNetworkBlock + + +class EquivariantGraphNeuralOperatorBlock(torch.nn.Module): + """ + A single block of the Equivariant Graph Neural Operator (EGNO). + + This block combines a temporal convolution with an equivariant graph neural + network (EGNN) layer. It preserves equivariance while modeling complex + interactions between nodes in a graph over time. + + .. seealso:: + + **Original reference** + Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli, + K., Leskovec, J., Ermon, S., Anandkumar, A. (2024). + *Equivariant Graph Neural Operator for Modeling 3D Dynamics* + DOI: `arXiv preprint arXiv:2401.11037. + `_ + """ + + def __init__( + self, + node_feature_dim, + edge_feature_dim, + pos_dim, + modes, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + activation=torch.nn.SiLU, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`EquivariantGraphNeuralOperatorBlock` + class. + + :param int node_feature_dim: The dimension of the node features. + :param int edge_feature_dim: The dimension of the edge features. + :param int pos_dim: The dimension of the position features. + :param int modes: The number of Fourier modes to use in the temporal + convolution. + :param int hidden_dim: The dimension of the hidden features. + Default is 64. + :param int n_message_layers: The number of layers in the message + network. Default is 2. + :param int n_update_layers: The number of layers in the update network. + Default is 2. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.SiLU`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". + The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises AssertionError: If ``modes`` is not a positive integer. + """ + super().__init__() + + # Check consistency + check_positive_integer(modes, strict=True) + + # Initialization + self.modes = modes + + # Temporal convolution weights - real and imaginary parts + self.weight_scalar_r = torch.nn.Parameter( + torch.rand(node_feature_dim, node_feature_dim, modes) + ) + self.weight_scalar_i = torch.nn.Parameter( + torch.rand(node_feature_dim, node_feature_dim, modes) + ) + self.weight_vector_r = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1) + self.weight_vector_i = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1) + + # EGNN block + self.egnn = EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + use_velocity=True, + hidden_dim=hidden_dim, + n_message_layers=n_message_layers, + n_update_layers=n_update_layers, + activation=activation, + aggr=aggr, + node_dim=node_dim, + flow=flow, + ) + + def forward(self, x, pos, vel, edge_index, edge_attr=None): + """ + Forward pass of the Equivariant Graph Neural Operator block. + + :param x: The node feature tensor of shape + ``[time_steps, num_nodes, node_feature_dim]``. + :type x: torch.Tensor | LabelTensor + :param pos: The node position tensor (Euclidean coordinates) of shape + ``[time_steps, num_nodes, pos_dim]``. + :type pos: torch.Tensor | LabelTensor + :param vel: The node velocity tensor of shape + ``[time_steps, num_nodes, pos_dim]``. + :type vel: torch.Tensor | LabelTensor + :param edge_index: The edge connectivity of shape ``[2, num_edges]``. + :type edge_index: torch.Tensor + :param edge_attr: The edge feature tensor of shape + ``[time_steps, num_edges, edge_feature_dim]``. Default is None. + :type edge_attr: torch.Tensor | LabelTensor, optional + :return: The updated node features, positions, and velocities, each with + the same shape as the inputs. + :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor] + """ + # Prepare features + center = pos.mean(dim=1, keepdim=True) + vector = torch.stack((pos - center, vel), dim=-1) + + # Compute temporal convolution + x = x + self._convolution( + x, "mni, iom -> mno", self.weight_scalar_r, self.weight_scalar_i + ) + vector = vector + self._convolution( + vector, + "mndi, iom -> mndo", + self.weight_vector_r, + self.weight_vector_i, + ) + + # Split position and velocity + pos, vel = vector.unbind(dim=-1) + pos = pos + center + + # Reshape to (time * nodes, feature) for egnn + x = x.reshape(-1, x.shape[-1]) + pos = pos.reshape(-1, pos.shape[-1]) + vel = vel.reshape(-1, vel.shape[-1]) + if edge_attr is not None: + edge_attr = edge_attr.reshape(-1, edge_attr.shape[-1]) + + x, pos, vel = self.egnn( + x=x, + pos=pos, + edge_index=edge_index, + edge_attr=edge_attr, + vel=vel, + ) + + # Reshape back to (time, nodes, feature) + x = x.reshape(center.shape[0], -1, x.shape[-1]) + pos = pos.reshape(center.shape[0], -1, pos.shape[-1]) + vel = vel.reshape(center.shape[0], -1, vel.shape[-1]) + + return x, pos, vel + + def _convolution(self, x, einsum_idx, real, img): + """ + Compute the temporal convolution. + + :param torch.Tensor x: The input features. + :param str einsum_idx: The indices for the einsum operation. + :param torch.Tensor real: The real part of the convolution weights. + :param torch.Tensor img: The imaginary part of the convolution weights. + :return: The convolved features. + :rtype: torch.Tensor + """ + # Number of modes to use + modes = min(self.modes, (x.shape[0] // 2) + 1) + + # Build complex weights + weights = torch.complex(real[..., :modes], img[..., :modes]) + + # Convolution in Fourier space + fourier = torch.fft.rfftn(x, dim=[0])[:modes] + out = torch.einsum(einsum_idx, fourier, weights) + + return torch.fft.irfftn(out, s=x.shape[0], dim=0) diff --git a/pina/model/equivariant_graph_neural_operator.py b/pina/model/equivariant_graph_neural_operator.py new file mode 100644 index 0000000..6b33df6 --- /dev/null +++ b/pina/model/equivariant_graph_neural_operator.py @@ -0,0 +1,219 @@ +"""Module for the Equivariant Graph Neural Operator model.""" + +import torch +from ..utils import check_positive_integer +from .block.message_passing import EquivariantGraphNeuralOperatorBlock + + +class EquivariantGraphNeuralOperator(torch.nn.Module): + """ + Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics. + + EGNO is a graph-based neural operator that preserves equivariance with + respect to 3D transformations while modeling temporal and spatial + interactions between nodes. It combines: + + 1. Temporal convolution in the Fourier domain to capture long-range + temporal dependencies efficiently. + 2. Equivariant Graph Neural Network (EGNN) layers to model interactions + between nodes while respecting geometric symmetries. + + This design allows EGNO to learn complex spatiotemporal dynamics of + physical systems, molecules, or particles while enforcing physically + meaningful constraints. + + .. seealso:: + + **Original reference** + Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli, + K., Leskovec, J., Ermon, S., Anandkumar, A. (2024). + *Equivariant Graph Neural Operator for Modeling 3D Dynamics* + DOI: `arXiv preprint arXiv:2401.11037. + `_ + """ + + def __init__( + self, + n_egno_layers, + node_feature_dim, + edge_feature_dim, + pos_dim, + modes, + time_steps=2, + hidden_dim=64, + time_emb_dim=16, + max_time_idx=10000, + n_message_layers=2, + n_update_layers=2, + activation=torch.nn.SiLU, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`EquivariantGraphNeuralOperator` class. + + :param int n_egno_layers: The number of EGNO layers. + :param int node_feature_dim: The dimension of the node features in each + EGNO layer. + :param int edge_feature_dim: The dimension of the edge features in each + EGNO layer. + :param int pos_dim: The dimension of the position features in each + EGNO layer. + :param int modes: The number of Fourier modes to use in the temporal + convolution. + :param int time_steps: The number of time steps to consider in the + temporal convolution. Default is 2. + :param int hidden_dim: The dimension of the hidden features in each EGNO + layer. Default is 64. + :param int time_emb_dim: The dimension of the sinusoidal time + embeddings. Default is 16. + :param int max_time_idx: The maximum time index for the sinusoidal + embeddings. Default is 10000. + :param int n_message_layers: The number of layers in the message + network of each EGNO layer. Default is 2. + :param int n_update_layers: The number of layers in the update network + of each EGNO layer. Default is 2. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.SiLU`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". + The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises AssertionError: If ``n_egno_layers`` is not a positive integer. + :raises AssertionError: If ``time_emb_dim`` is not a positive integer. + :raises AssertionError: If ``max_time_idx`` is not a positive integer. + :raises AssertionError: If ``time_steps`` is not a positive integer. + """ + super().__init__() + + # Check consistency + check_positive_integer(n_egno_layers, strict=True) + check_positive_integer(time_emb_dim, strict=True) + check_positive_integer(max_time_idx, strict=True) + check_positive_integer(time_steps, strict=True) + + # Initialize parameters + self.time_steps = time_steps + self.time_emb_dim = time_emb_dim + self.max_time_idx = max_time_idx + + # Initialize EGNO layers + self.egno_layers = torch.nn.ModuleList() + for _ in range(n_egno_layers): + self.egno_layers.append( + EquivariantGraphNeuralOperatorBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + modes=modes, + hidden_dim=hidden_dim, + n_message_layers=n_message_layers, + n_update_layers=n_update_layers, + activation=activation, + aggr=aggr, + node_dim=node_dim, + flow=flow, + ) + ) + + # Linear layer to adjust the scalar feature dimension + self.linear = torch.nn.Linear( + node_feature_dim + time_emb_dim, node_feature_dim + ) + + def forward(self, graph): + """ + Forward pass of the :class:`EquivariantGraphNeuralOperator` class. + + :param graph: The input graph object with the following attributes: + - 'x': Node features, shape ``[num_nodes, node_feature_dim]``. + - 'pos': Node positions, shape ``[num_nodes, pos_dim]``. + - 'vel': Node velocities, shape ``[num_nodes, pos_dim]``. + - 'edge_index': Graph connectivity, shape ``[2, num_edges]``. + - 'edge_attr': Edge attrs, shape ``[num_edges, edge_feature_dim]``. + :type graph: Data | Graph + :return: The output graph object with updated node features, + positions, and velocities. The output graph adds to 'x', 'pos', + 'vel', and 'edge_attr' the time dimension, resulting in shapes: + - 'x': ``[time_steps, num_nodes, node_feature_dim]`` + - 'pos': ``[time_steps, num_nodes, pos_dim]`` + - 'vel': ``[time_steps, num_nodes, pos_dim]`` + - 'edge_attr': ``[time_steps, num_edges, edge_feature_dim]`` + :rtype: Data | Graph + :raises ValueError: If the input graph does not have a 'vel' attribute. + """ + # Check that the graph has the required attributes + if "vel" not in graph: + raise ValueError("The input graph must have a 'vel' attribute.") + + # Compute the temporal embedding + emb = self._embedding(torch.arange(self.time_steps)).to(graph.x.device) + emb = emb.unsqueeze(1).repeat(1, graph.x.shape[0], 1) + + # Expand dimensions + x = graph.x.unsqueeze(0).repeat(self.time_steps, 1, 1) + x = self.linear(torch.cat((x, emb), dim=-1)) + pos = graph.pos.unsqueeze(0).repeat(self.time_steps, 1, 1) + vel = graph.vel.unsqueeze(0).repeat(self.time_steps, 1, 1) + + # Manage edge index + offset = torch.arange(self.time_steps).reshape(-1, 1) + offset = offset.to(graph.x.device) * graph.x.shape[0] + src = graph.edge_index[0].unsqueeze(0) + offset + dst = graph.edge_index[1].unsqueeze(0) + offset + edge_index = torch.stack([src, dst], dim=0).reshape(2, -1) + + # Manage edge attributes + if graph.edge_attr is not None: + edge_attr = graph.edge_attr.unsqueeze(0) + edge_attr = edge_attr.repeat(self.time_steps, 1, 1) + else: + edge_attr = None + + # Iteratively apply EGNO layers + for layer in self.egno_layers: + x, pos, vel = layer( + x=x, + pos=pos, + vel=vel, + edge_index=edge_index, + edge_attr=edge_attr, + ) + + # Build new graph + new_graph = graph.clone() + new_graph.x, new_graph.pos, new_graph.vel = x, pos, vel + if edge_attr is not None: + new_graph.edge_attr = edge_attr + + return new_graph + + def _embedding(self, time): + """ + Generate sinusoidal temporal embeddings. + + :param torch.Tensor time: The time instances. + :return: The sinusoidal embedding tensor. + :rtype: torch.Tensor + """ + # Compute the sinusoidal embeddings + half_dim = self.time_emb_dim // 2 + logs = torch.log(torch.as_tensor(self.max_time_idx)) / (half_dim - 1) + freqs = torch.exp(-torch.arange(half_dim) * logs) + args = torch.as_tensor(time)[:, None] * freqs[None, :] + emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + + # Apply padding if the embedding dimension is odd + if self.time_emb_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1), mode="constant") + + return emb diff --git a/tests/test_messagepassing/test_equivariant_network_block.py b/tests/test_messagepassing/test_equivariant_network_block.py index eea000a..0143440 100644 --- a/tests/test_messagepassing/test_equivariant_network_block.py +++ b/tests/test_messagepassing/test_equivariant_network_block.py @@ -5,19 +5,22 @@ from pina.model.block.message_passing import EnEquivariantNetworkBlock # Data for testing x = torch.rand(10, 4) pos = torch.rand(10, 3) -edge_index = torch.randint(0, 10, (2, 20)) -edge_attr = torch.randn(20, 2) +velocity = torch.rand(10, 3) +edge_idx = torch.randint(0, 10, (2, 20)) +edge_attributes = torch.randn(20, 2) @pytest.mark.parametrize("node_feature_dim", [1, 3]) @pytest.mark.parametrize("edge_feature_dim", [0, 2]) @pytest.mark.parametrize("pos_dim", [2, 3]) -def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): +@pytest.mark.parametrize("use_velocity", [True, False]) +def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, use_velocity): EnEquivariantNetworkBlock( node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, pos_dim=pos_dim, + use_velocity=use_velocity, hidden_dim=64, n_message_layers=2, n_update_layers=2, @@ -29,6 +32,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): node_feature_dim=-1, edge_feature_dim=edge_feature_dim, pos_dim=pos_dim, + use_velocity=use_velocity, ) # Should fail if edge_feature_dim is negative @@ -37,6 +41,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): node_feature_dim=node_feature_dim, edge_feature_dim=-1, pos_dim=pos_dim, + use_velocity=use_velocity, ) # Should fail if pos_dim is negative @@ -45,6 +50,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, pos_dim=-1, + use_velocity=use_velocity, ) # Should fail if hidden_dim is negative @@ -54,6 +60,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): edge_feature_dim=edge_feature_dim, pos_dim=pos_dim, hidden_dim=-1, + use_velocity=use_velocity, ) # Should fail if n_message_layers is negative @@ -63,6 +70,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): edge_feature_dim=edge_feature_dim, pos_dim=pos_dim, n_message_layers=-1, + use_velocity=use_velocity, ) # Should fail if n_update_layers is negative @@ -72,11 +80,22 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): edge_feature_dim=edge_feature_dim, pos_dim=pos_dim, n_update_layers=-1, + use_velocity=use_velocity, + ) + + # Should fail if use_velocity is not boolean + with pytest.raises(ValueError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + use_velocity="False", ) @pytest.mark.parametrize("edge_feature_dim", [0, 2]) -def test_forward(edge_feature_dim): +@pytest.mark.parametrize("use_velocity", [True, False]) +def test_forward(edge_feature_dim, use_velocity): model = EnEquivariantNetworkBlock( node_feature_dim=x.shape[1], @@ -85,21 +104,26 @@ def test_forward(edge_feature_dim): hidden_dim=64, n_message_layers=2, n_update_layers=2, + use_velocity=use_velocity, ) - if edge_feature_dim == 0: - output_ = model(edge_index=edge_index, x=x, pos=pos) - else: - output_ = model( - edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr - ) + # Manage inputs + vel = velocity if use_velocity else None + edge_attr = edge_attributes if edge_feature_dim > 0 else None + # Checks on output shapes + output_ = model( + x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel + ) assert output_[0].shape == x.shape assert output_[1].shape == pos.shape + if vel is not None: + assert output_[2].shape == vel.shape @pytest.mark.parametrize("edge_feature_dim", [0, 2]) -def test_backward(edge_feature_dim): +@pytest.mark.parametrize("use_velocity", [True, False]) +def test_backward(edge_feature_dim, use_velocity): model = EnEquivariantNetworkBlock( node_feature_dim=x.shape[1], @@ -108,35 +132,45 @@ def test_backward(edge_feature_dim): hidden_dim=64, n_message_layers=2, n_update_layers=2, + use_velocity=use_velocity, + ) + + # Manage inputs + vel = velocity.requires_grad_() if use_velocity else None + edge_attr = ( + edge_attributes.requires_grad_() if edge_feature_dim > 0 else None ) if edge_feature_dim == 0: output_ = model( - edge_index=edge_index, + edge_index=edge_idx, x=x.requires_grad_(), pos=pos.requires_grad_(), + vel=vel, ) else: output_ = model( - edge_index=edge_index, + edge_index=edge_idx, x=x.requires_grad_(), pos=pos.requires_grad_(), - edge_attr=edge_attr.requires_grad_(), + edge_attr=edge_attr, + vel=vel, ) - loss = torch.mean(output_[0]) + # Checks on gradients + loss = sum(torch.mean(output_[i]) for i in range(len(output_))) loss.backward() assert x.grad.shape == x.shape assert pos.grad.shape == pos.shape + if use_velocity: + assert vel.grad.shape == vel.shape -def test_equivariance(): +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +@pytest.mark.parametrize("use_velocity", [True, False]) +def test_equivariance(edge_feature_dim, use_velocity): - # Graph to be fully connected and undirected - edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T - edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) - - # Random rotation (det(rotation) should be 1) + # Random rotation rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q if torch.det(rotation) < 0: rotation[:, 0] *= -1 @@ -146,20 +180,37 @@ def test_equivariance(): model = EnEquivariantNetworkBlock( node_feature_dim=x.shape[1], - edge_feature_dim=0, + edge_feature_dim=edge_feature_dim, pos_dim=pos.shape[1], hidden_dim=64, n_message_layers=2, n_update_layers=2, + use_velocity=use_velocity, ).eval() - h1, pos1 = model(edge_index=edge_index, x=x, pos=pos) - h2, pos2 = model( - edge_index=edge_index, x=x, pos=pos @ rotation.T + translation + # Manage inputs + vel = velocity if use_velocity else None + edge_attr = edge_attributes if edge_feature_dim > 0 else None + + # Transform inputs (no translation for velocity) + pos_rot = pos @ rotation.T + translation + vel_rot = vel @ rotation.T if use_velocity else vel + + # Get model outputs + out1 = model( + x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel + ) + out2 = model( + x=x, pos=pos_rot, edge_index=edge_idx, edge_attr=edge_attr, vel=vel_rot ) - # Transform model output - pos1_transformed = (pos1 @ rotation.T) + translation + # Unpack outputs + h1, pos1, *other1 = out1 + h2, pos2, *other2 = out2 + if use_velocity: + vel1, vel2 = other1[0], other2[0] - assert torch.allclose(pos2, pos1_transformed, atol=1e-5) + assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5) assert torch.allclose(h1, h2, atol=1e-5) + if vel is not None: + assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5) diff --git a/tests/test_messagepassing/test_equivariant_operator_block.py b/tests/test_messagepassing/test_equivariant_operator_block.py new file mode 100644 index 0000000..ad4f050 --- /dev/null +++ b/tests/test_messagepassing/test_equivariant_operator_block.py @@ -0,0 +1,132 @@ +import pytest +import torch +from pina.model.block.message_passing import EquivariantGraphNeuralOperatorBlock + +# Data for testing. Shapes: (time, nodes, features) +x = torch.rand(5, 10, 4) +pos = torch.rand(5, 10, 3) +vel = torch.rand(5, 10, 3) + +# Edge index and attributes +edge_idx = torch.randint(0, 10, (2, 20)) +edge_attributes = torch.randn(20, 2) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +@pytest.mark.parametrize("pos_dim", [2, 3]) +@pytest.mark.parametrize("modes", [1, 5]) +def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, modes): + + EquivariantGraphNeuralOperatorBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + modes=modes, + ) + + # Should fail if modes is negative + with pytest.raises(AssertionError): + EquivariantGraphNeuralOperatorBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + modes=-1, + ) + + +@pytest.mark.parametrize("modes", [1, 5]) +def test_forward(modes): + + model = EquivariantGraphNeuralOperatorBlock( + node_feature_dim=x.shape[2], + edge_feature_dim=edge_attributes.shape[1], + pos_dim=pos.shape[2], + modes=modes, + ) + + output_ = model( + x=x, + pos=pos, + vel=vel, + edge_index=edge_idx, + edge_attr=edge_attributes, + ) + + # Checks on output shapes + assert output_[0].shape == x.shape + assert output_[1].shape == pos.shape + assert output_[2].shape == vel.shape + + +@pytest.mark.parametrize("modes", [1, 5]) +def test_backward(modes): + + model = EquivariantGraphNeuralOperatorBlock( + node_feature_dim=x.shape[2], + edge_feature_dim=edge_attributes.shape[1], + pos_dim=pos.shape[2], + modes=modes, + ) + + output_ = model( + x=x.requires_grad_(), + pos=pos.requires_grad_(), + vel=vel.requires_grad_(), + edge_index=edge_idx, + edge_attr=edge_attributes.requires_grad_(), + ) + + # Checks on gradients + loss = sum(torch.mean(output_[i]) for i in range(len(output_))) + loss.backward() + assert x.grad.shape == x.shape + assert pos.grad.shape == pos.shape + assert vel.grad.shape == vel.shape + + +@pytest.mark.parametrize("modes", [1, 5]) +def test_equivariance(modes): + + # Random rotation + rotation = torch.linalg.qr(torch.rand(pos.shape[2], pos.shape[2])).Q + if torch.det(rotation) < 0: + rotation[:, 0] *= -1 + + # Random translation + translation = torch.rand(1, pos.shape[2]) + + model = EquivariantGraphNeuralOperatorBlock( + node_feature_dim=x.shape[2], + edge_feature_dim=edge_attributes.shape[1], + pos_dim=pos.shape[2], + modes=modes, + ).eval() + + # Transform inputs (no translation for velocity) + pos_rot = pos @ rotation.T + translation + vel_rot = vel @ rotation.T + + # Get model outputs + out1 = model( + x=x, + pos=pos, + vel=vel, + edge_index=edge_idx, + edge_attr=edge_attributes, + ) + out2 = model( + x=x, + pos=pos_rot, + vel=vel_rot, + edge_index=edge_idx, + edge_attr=edge_attributes, + ) + + # Unpack outputs + h1, pos1, vel1 = out1 + h2, pos2, vel2 = out2 + + assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5) + assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5) + assert torch.allclose(h1, h2, atol=1e-5) diff --git a/tests/test_model/test_equivariant_graph_neural_operator.py b/tests/test_model/test_equivariant_graph_neural_operator.py new file mode 100644 index 0000000..c4c0484 --- /dev/null +++ b/tests/test_model/test_equivariant_graph_neural_operator.py @@ -0,0 +1,194 @@ +import pytest +import torch +import copy +from pina.model import EquivariantGraphNeuralOperator +from pina.graph import Graph + + +# Utility to create graphs +def make_graph(include_vel=True, use_edge_attr=True): + data = dict( + x=torch.rand(10, 4), + pos=torch.rand(10, 3), + edge_index=torch.randint(0, 10, (2, 20)), + edge_attr=torch.randn(20, 2) if use_edge_attr else None, + ) + if include_vel: + data["vel"] = torch.rand(10, 3) + return Graph(**data) + + +@pytest.mark.parametrize("n_egno_layers", [1, 3]) +@pytest.mark.parametrize("time_steps", [1, 3]) +@pytest.mark.parametrize("time_emb_dim", [4, 8]) +@pytest.mark.parametrize("max_time_idx", [10, 20]) +def test_constructor(n_egno_layers, time_steps, time_emb_dim, max_time_idx): + + # Create graph and model + graph = make_graph() + EquivariantGraphNeuralOperator( + n_egno_layers=n_egno_layers, + node_feature_dim=graph.x.shape[1], + edge_feature_dim=graph.edge_attr.shape[1], + pos_dim=graph.pos.shape[1], + modes=5, + time_steps=time_steps, + time_emb_dim=time_emb_dim, + max_time_idx=max_time_idx, + ) + + # Should fail if n_egno_layers is negative + with pytest.raises(AssertionError): + EquivariantGraphNeuralOperator( + n_egno_layers=-1, + node_feature_dim=graph.x.shape[1], + edge_feature_dim=graph.edge_attr.shape[1], + pos_dim=graph.pos.shape[1], + modes=5, + time_steps=time_steps, + time_emb_dim=time_emb_dim, + max_time_idx=max_time_idx, + ) + + # Should fail if time_steps is negative + with pytest.raises(AssertionError): + EquivariantGraphNeuralOperator( + n_egno_layers=n_egno_layers, + node_feature_dim=graph.x.shape[1], + edge_feature_dim=graph.edge_attr.shape[1], + pos_dim=graph.pos.shape[1], + modes=5, + time_steps=-1, + time_emb_dim=time_emb_dim, + max_time_idx=max_time_idx, + ) + + # Should fail if max_time_idx is negative + with pytest.raises(AssertionError): + EquivariantGraphNeuralOperator( + n_egno_layers=n_egno_layers, + node_feature_dim=graph.x.shape[1], + edge_feature_dim=graph.edge_attr.shape[1], + pos_dim=graph.pos.shape[1], + modes=5, + time_steps=time_steps, + time_emb_dim=time_emb_dim, + max_time_idx=-1, + ) + + # Should fail if time_emb_dim is negative + with pytest.raises(AssertionError): + EquivariantGraphNeuralOperator( + n_egno_layers=n_egno_layers, + node_feature_dim=graph.x.shape[1], + edge_feature_dim=graph.edge_attr.shape[1], + pos_dim=graph.pos.shape[1], + modes=5, + time_steps=time_steps, + time_emb_dim=-1, + max_time_idx=max_time_idx, + ) + + +@pytest.mark.parametrize("n_egno_layers", [1, 3]) +@pytest.mark.parametrize("time_steps", [1, 5]) +@pytest.mark.parametrize("modes", [1, 3, 10]) +@pytest.mark.parametrize("use_edge_attr", [True, False]) +def test_forward(n_egno_layers, time_steps, modes, use_edge_attr): + + # Create graph and model + graph = make_graph(use_edge_attr=use_edge_attr) + model = EquivariantGraphNeuralOperator( + n_egno_layers=n_egno_layers, + node_feature_dim=graph.x.shape[1], + edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0, + pos_dim=graph.pos.shape[1], + modes=modes, + time_steps=time_steps, + ) + + # Checks on output shapes + output_ = model(graph) + assert output_.x.shape == (time_steps, *graph.x.shape) + assert output_.pos.shape == (time_steps, *graph.pos.shape) + assert output_.vel.shape == (time_steps, *graph.vel.shape) + + # Should fail graph has no vel attribute + with pytest.raises(ValueError): + graph_no_vel = make_graph(include_vel=False) + model(graph_no_vel) + + +@pytest.mark.parametrize("n_egno_layers", [1, 3]) +@pytest.mark.parametrize("time_steps", [1, 5]) +@pytest.mark.parametrize("modes", [1, 3, 10]) +@pytest.mark.parametrize("use_edge_attr", [True, False]) +def test_backward(n_egno_layers, time_steps, modes, use_edge_attr): + + # Create graph and model + graph = make_graph(use_edge_attr=use_edge_attr) + model = EquivariantGraphNeuralOperator( + n_egno_layers=n_egno_layers, + node_feature_dim=graph.x.shape[1], + edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0, + pos_dim=graph.pos.shape[1], + modes=modes, + time_steps=time_steps, + ) + + # Set requires_grad and perform forward pass + graph.x.requires_grad_() + graph.pos.requires_grad_() + graph.vel.requires_grad_() + out = model(graph) + + # Checks on gradients + loss = torch.mean(out.x) + torch.mean(out.pos) + torch.mean(out.vel) + loss.backward() + assert graph.x.grad.shape == graph.x.shape + assert graph.pos.grad.shape == graph.pos.shape + assert graph.vel.grad.shape == graph.vel.shape + + +@pytest.mark.parametrize("n_egno_layers", [1, 3]) +@pytest.mark.parametrize("time_steps", [1, 5]) +@pytest.mark.parametrize("modes", [1, 3, 10]) +@pytest.mark.parametrize("use_edge_attr", [True, False]) +def test_equivariance(n_egno_layers, time_steps, modes, use_edge_attr): + + graph = make_graph(use_edge_attr=use_edge_attr) + model = EquivariantGraphNeuralOperator( + n_egno_layers=n_egno_layers, + node_feature_dim=graph.x.shape[1], + edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0, + pos_dim=graph.pos.shape[1], + modes=modes, + time_steps=time_steps, + ).eval() + + # Random rotation + rotation = torch.linalg.qr( + torch.rand(graph.pos.shape[1], graph.pos.shape[1]) + ).Q + if torch.det(rotation) < 0: + rotation[:, 0] *= -1 + + # Random translation + translation = torch.rand(1, graph.pos.shape[1]) + + # Transform graph (no translation for velocity) + graph_rot = copy.deepcopy(graph) + graph_rot.pos = graph.pos @ rotation.T + translation + graph_rot.vel = graph.vel @ rotation.T + + # Get model outputs + out1 = model(graph) + out2 = model(graph_rot) + + # Unpack outputs + h1, pos1, vel1 = out1.x, out1.pos, out1.vel + h2, pos2, vel2 = out2.x, out2.pos, out2.vel + + assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5) + assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5) + assert torch.allclose(h1, h2, atol=1e-5)