add egno (#602)
Co-authored-by: GiovanniCanali <giovanni.canali98@yahoo.it>
@@ -105,6 +105,7 @@ Models
 GraphNeuralOperator <model/graph_neural_operator.rst>
 GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
 PirateNet <model/pirate_network.rst>
+EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>

 Blocks
 -------------
@@ -134,6 +135,7 @@ Message Passing
 E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst>
 Interaction Network Block <model/block/message_passing/interaction_network_block.rst>
 Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst>
+EquivariantGraphNeuralOperatorBlock <model/block/message_passing/equivariant_graph_neural_operator_block.rst>

 Reduction and Embeddings
model/block/message_passing/equivariant_graph_neural_operator_block.rst (new file)

@@ -0,0 +1,7 @@
EquivariantGraphNeuralOperatorBlock
=====================================
.. currentmodule:: pina.model.block.message_passing.equivariant_graph_neural_operator_block

.. autoclass:: EquivariantGraphNeuralOperatorBlock
    :members:
    :show-inheritance:
model/equivariant_graph_neural_operator.rst (new file)

@@ -0,0 +1,7 @@
EquivariantGraphNeuralOperator
=================================
.. currentmodule:: pina.model.equivariant_graph_neural_operator

.. autoclass:: EquivariantGraphNeuralOperator
    :members:
    :show-inheritance:
pina/model/__init__.py

@@ -14,6 +14,7 @@ __all__ = [
     "Spline",
     "GraphNeuralOperator",
     "PirateNet",
+    "EquivariantGraphNeuralOperator",
 ]

 from .feed_forward import FeedForward, ResidualFeedForward
@@ -26,3 +27,4 @@ from .low_rank_neural_operator import LowRankNeuralOperator
 from .spline import Spline
 from .graph_neural_operator import GraphNeuralOperator
 from .pirate_network import PirateNet
+from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
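With these exports in place, both the model and its building block are importable from the usual package roots; a minimal sketch (assuming a PINA build that includes this PR):

    from pina.model import EquivariantGraphNeuralOperator
    from pina.model.block.message_passing import EquivariantGraphNeuralOperatorBlock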
pina/model/block/message_passing/__init__.py

@@ -5,9 +5,13 @@ __all__ = [
     "DeepTensorNetworkBlock",
     "EnEquivariantNetworkBlock",
     "RadialFieldNetworkBlock",
+    "EquivariantGraphNeuralOperatorBlock",
 ]

 from .interaction_network_block import InteractionNetworkBlock
 from .deep_tensor_network_block import DeepTensorNetworkBlock
 from .en_equivariant_network_block import EnEquivariantNetworkBlock
 from .radial_field_network_block import RadialFieldNetworkBlock
+from .equivariant_graph_neural_operator_block import (
+    EquivariantGraphNeuralOperatorBlock,
+)
pina/model/block/message_passing/en_equivariant_network_block.py

@@ -3,7 +3,7 @@
 import torch
 from torch_geometric.nn import MessagePassing
 from torch_geometric.utils import degree
-from ....utils import check_positive_integer
+from ....utils import check_positive_integer, check_consistency
 from ....model import FeedForward
@@ -27,6 +27,12 @@ class EnEquivariantNetworkBlock(MessagePassing):
     positions are updated by adding the incoming messages divided by the
     degree of the recipient node.

+    When velocity features are used, node velocities are passed through a small
+    MLP to compute updates, which are then combined with the aggregated position
+    messages. The node positions are updated both by the normalized position
+    messages and by the updated velocities, ensuring equivariance while
+    incorporating dynamic information.
+
     .. seealso::

         **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
@@ -40,6 +46,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
         node_feature_dim,
         edge_feature_dim,
         pos_dim,
+        use_velocity=False,
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
@@ -54,6 +61,8 @@ class EnEquivariantNetworkBlock(MessagePassing):
         :param int node_feature_dim: The dimension of the node features.
         :param int edge_feature_dim: The dimension of the edge features.
         :param int pos_dim: The dimension of the position features.
+        :param bool use_velocity: Whether to use velocity features in the
+            message passing. Default is False.
         :param int hidden_dim: The dimension of the hidden features.
             Default is 64.
         :param int n_message_layers: The number of layers in the message
@@ -80,6 +89,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
         :raises AssertionError: If `hidden_dim` is not a positive integer.
         :raises AssertionError: If `n_message_layers` is not a positive integer.
         :raises AssertionError: If `n_update_layers` is not a positive integer.
+        :raises ValueError: If `use_velocity` is not a boolean.
         """
         super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
@@ -90,6 +100,10 @@ class EnEquivariantNetworkBlock(MessagePassing):
         check_positive_integer(hidden_dim, strict=True)
         check_positive_integer(n_message_layers, strict=True)
         check_positive_integer(n_update_layers, strict=True)
+        check_consistency(use_velocity, bool)
+
+        # Initialization
+        self.use_velocity = use_velocity

         # Layer for computing the message
         self.message_net = FeedForward(
@@ -119,7 +133,17 @@ class EnEquivariantNetworkBlock(MessagePassing):
             func=activation,
         )

-    def forward(self, x, pos, edge_index, edge_attr=None):
+        # If velocity is used, instantiate layer for velocity updates
+        if self.use_velocity:
+            self.update_vel_net = FeedForward(
+                input_dimensions=node_feature_dim,
+                output_dimensions=1,
+                inner_size=hidden_dim,
+                n_layers=n_update_layers,
+                func=activation,
+            )
+
+    def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
         """
         Forward pass of the block, triggering the message-passing routine.
@@ -130,11 +154,19 @@ class EnEquivariantNetworkBlock(MessagePassing):
         :param torch.Tensor edge_index: The edge indices.
         :param edge_attr: The edge attributes. Default is None.
         :type edge_attr: torch.Tensor | LabelTensor
+        :param vel: The velocity of the nodes. Default is None.
+        :type vel: torch.Tensor | LabelTensor
         :return: The updated node features and node positions.
         :rtype: tuple(torch.Tensor, torch.Tensor)
+        :raises ValueError: If ``use_velocity`` is True and ``vel`` is None.
         """
+        if self.use_velocity and vel is None:
+            raise ValueError(
+                "Velocity features are enabled, but no velocity is passed."
+            )
+
         return self.propagate(
-            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
+            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
         )

     def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
@@ -202,10 +234,9 @@ class EnEquivariantNetworkBlock(MessagePassing):

         return agg_message, agg_m_ij

-    def update(self, aggregated_inputs, x, pos, edge_index):
+    def update(self, aggregated_inputs, x, pos, edge_index, vel):
         """
-        Update the node features and the node coordinates with the received
-        messages.
+        Update node features, positions, and optionally velocities.

         :param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
         :param x: The node features.
@@ -213,17 +244,26 @@
         :param pos: The euclidean coordinates of the nodes.
         :type pos: torch.Tensor | LabelTensor
         :param torch.Tensor edge_index: The edge indices.
+        :param vel: The velocity of the nodes.
+        :type vel: torch.Tensor | LabelTensor
         :return: The updated node features and node positions.
-        :rtype: tuple(torch.Tensor, torch.Tensor)
+        :rtype: tuple(torch.Tensor, torch.Tensor) |
+            tuple(torch.Tensor, torch.Tensor, torch.Tensor)
         """
         # aggregated_inputs is tuple (agg_message, agg_m_ij)
         agg_message, agg_m_ij = aggregated_inputs

+        # Degree for normalization of position updates
+        c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
+
+        # If velocity is used, update it and use it to update positions
+        if self.use_velocity:
+            vel = self.update_vel_net(x) * vel
+
         # Update node features with aggregated m_ij
         x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))

-        # Degree for normalization of position updates
-        c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
-        pos = pos + agg_message / c
+        # Update positions with aggregated messages m_ij and velocities
+        pos = pos + agg_message / c + (vel if self.use_velocity else 0)

-        return x, pos
+        return (x, pos, vel) if self.use_velocity else (x, pos)
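The net effect of these changes is that the block can optionally thread velocities through the message-passing round trip. A minimal usage sketch (shapes mirror the updated tests below; the variable names are illustrative only):

    import torch
    from pina.model.block.message_passing import EnEquivariantNetworkBlock

    block = EnEquivariantNetworkBlock(
        node_feature_dim=4, edge_feature_dim=0, pos_dim=3, use_velocity=True
    )
    x, pos, vel = torch.rand(10, 4), torch.rand(10, 3), torch.rand(10, 3)
    edge_index = torch.randint(0, 10, (2, 20))

    # With use_velocity=True the forward pass returns a triple; calling it
    # without vel would raise a ValueError.
    h, new_pos, new_vel = block(x=x, pos=pos, edge_index=edge_index, vel=vel)

With use_velocity=False the call signature is unchanged and the block still returns the usual (features, positions) pair.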
pina/model/block/message_passing/equivariant_graph_neural_operator_block.py (new file)

@@ -0,0 +1,188 @@
"""Module for the Equivariant Graph Neural Operator block."""

import torch
from ....utils import check_positive_integer
from .en_equivariant_network_block import EnEquivariantNetworkBlock


class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
    """
    A single block of the Equivariant Graph Neural Operator (EGNO).

    This block combines a temporal convolution with an equivariant graph neural
    network (EGNN) layer. It preserves equivariance while modeling complex
    interactions between nodes in a graph over time.

    .. seealso::

        **Original reference**
        Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli,
        K., Leskovec, J., Ermon, S., Anandkumar, A. (2024).
        *Equivariant Graph Neural Operator for Modeling 3D Dynamics*.
        `arXiv preprint arXiv:2401.11037 <https://arxiv.org/abs/2401.11037>`_
    """

    def __init__(
        self,
        node_feature_dim,
        edge_feature_dim,
        pos_dim,
        modes,
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
        activation=torch.nn.SiLU,
        aggr="add",
        node_dim=-2,
        flow="source_to_target",
    ):
        """
        Initialization of the :class:`EquivariantGraphNeuralOperatorBlock`
        class.

        :param int node_feature_dim: The dimension of the node features.
        :param int edge_feature_dim: The dimension of the edge features.
        :param int pos_dim: The dimension of the position features.
        :param int modes: The number of Fourier modes to use in the temporal
            convolution.
        :param int hidden_dim: The dimension of the hidden features.
            Default is 64.
        :param int n_message_layers: The number of layers in the message
            network. Default is 2.
        :param int n_update_layers: The number of layers in the update network.
            Default is 2.
        :param torch.nn.Module activation: The activation function.
            Default is :class:`torch.nn.SiLU`.
        :param str aggr: The aggregation scheme to use for message passing.
            Available options are "add", "mean", "min", "max", "mul".
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "add".
        :param int node_dim: The axis along which to propagate. Default is -2.
        :param str flow: The direction of message passing. Available options
            are "source_to_target" and "target_to_source". The
            "source_to_target" flow means that messages are sent from the
            source node to the target node, while the "target_to_source" flow
            means that messages are sent from the target node to the source
            node. See :class:`torch_geometric.nn.MessagePassing` for more
            details. Default is "source_to_target".
        :raises AssertionError: If ``modes`` is not a positive integer.
        """
        super().__init__()

        # Check consistency
        check_positive_integer(modes, strict=True)

        # Initialization
        self.modes = modes

        # Temporal convolution weights - real and imaginary parts
        self.weight_scalar_r = torch.nn.Parameter(
            torch.rand(node_feature_dim, node_feature_dim, modes)
        )
        self.weight_scalar_i = torch.nn.Parameter(
            torch.rand(node_feature_dim, node_feature_dim, modes)
        )
        self.weight_vector_r = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)
        self.weight_vector_i = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)

        # EGNN block
        self.egnn = EnEquivariantNetworkBlock(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            pos_dim=pos_dim,
            use_velocity=True,
            hidden_dim=hidden_dim,
            n_message_layers=n_message_layers,
            n_update_layers=n_update_layers,
            activation=activation,
            aggr=aggr,
            node_dim=node_dim,
            flow=flow,
        )

    def forward(self, x, pos, vel, edge_index, edge_attr=None):
        """
        Forward pass of the Equivariant Graph Neural Operator block.

        :param x: The node feature tensor of shape
            ``[time_steps, num_nodes, node_feature_dim]``.
        :type x: torch.Tensor | LabelTensor
        :param pos: The node position tensor (Euclidean coordinates) of shape
            ``[time_steps, num_nodes, pos_dim]``.
        :type pos: torch.Tensor | LabelTensor
        :param vel: The node velocity tensor of shape
            ``[time_steps, num_nodes, pos_dim]``.
        :type vel: torch.Tensor | LabelTensor
        :param edge_index: The edge connectivity of shape ``[2, num_edges]``.
        :type edge_index: torch.Tensor
        :param edge_attr: The edge feature tensor of shape
            ``[time_steps, num_edges, edge_feature_dim]``. Default is None.
        :type edge_attr: torch.Tensor | LabelTensor, optional
        :return: The updated node features, positions, and velocities, each
            with the same shape as the inputs.
        :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        # Prepare features
        center = pos.mean(dim=1, keepdim=True)
        vector = torch.stack((pos - center, vel), dim=-1)

        # Compute temporal convolution
        x = x + self._convolution(
            x, "mni, iom -> mno", self.weight_scalar_r, self.weight_scalar_i
        )
        vector = vector + self._convolution(
            vector,
            "mndi, iom -> mndo",
            self.weight_vector_r,
            self.weight_vector_i,
        )

        # Split position and velocity
        pos, vel = vector.unbind(dim=-1)
        pos = pos + center

        # Reshape to (time * nodes, feature) for egnn
        x = x.reshape(-1, x.shape[-1])
        pos = pos.reshape(-1, pos.shape[-1])
        vel = vel.reshape(-1, vel.shape[-1])
        if edge_attr is not None:
            edge_attr = edge_attr.reshape(-1, edge_attr.shape[-1])

        x, pos, vel = self.egnn(
            x=x,
            pos=pos,
            edge_index=edge_index,
            edge_attr=edge_attr,
            vel=vel,
        )

        # Reshape back to (time, nodes, feature)
        x = x.reshape(center.shape[0], -1, x.shape[-1])
        pos = pos.reshape(center.shape[0], -1, pos.shape[-1])
        vel = vel.reshape(center.shape[0], -1, vel.shape[-1])

        return x, pos, vel

    def _convolution(self, x, einsum_idx, real, img):
        """
        Compute the temporal convolution.

        :param torch.Tensor x: The input features.
        :param str einsum_idx: The indices for the einsum operation.
        :param torch.Tensor real: The real part of the convolution weights.
        :param torch.Tensor img: The imaginary part of the convolution weights.
        :return: The convolved features.
        :rtype: torch.Tensor
        """
        # Number of modes to use
        modes = min(self.modes, (x.shape[0] // 2) + 1)

        # Build complex weights
        weights = torch.complex(real[..., :modes], img[..., :modes])

        # Convolution in Fourier space
        fourier = torch.fft.rfftn(x, dim=[0])[:modes]
        out = torch.einsum(einsum_idx, fourier, weights)

        return torch.fft.irfftn(out, s=x.shape[0], dim=0)
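As a quick orientation for the layout conventions above, a usage sketch of the block with an explicit leading time axis (values are illustrative; shapes follow the tests added in this PR):

    import torch
    from pina.model.block.message_passing import EquivariantGraphNeuralOperatorBlock

    block = EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=4, edge_feature_dim=2, pos_dim=3, modes=3
    )

    # Inputs are stacked over time: (time_steps, num_nodes, feature_dim).
    x, pos, vel = torch.rand(5, 10, 4), torch.rand(5, 10, 3), torch.rand(5, 10, 3)
    edge_index = torch.randint(0, 10, (2, 20))
    edge_attr = torch.randn(20, 2)

    # The temporal convolution mixes the time axis in Fourier space (at most
    # time_steps // 2 + 1 modes are usable), then the EGNN layer runs on the
    # flattened (time * nodes) graph.
    h, new_pos, new_vel = block(
        x=x, pos=pos, vel=vel, edge_index=edge_index, edge_attr=edge_attr
    )
    assert h.shape == x.shape and new_pos.shape == pos.shape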
pina/model/equivariant_graph_neural_operator.py (new file)
@@ -0,0 +1,219 @@
"""Module for the Equivariant Graph Neural Operator model."""

import torch
from ..utils import check_positive_integer
from .block.message_passing import EquivariantGraphNeuralOperatorBlock


class EquivariantGraphNeuralOperator(torch.nn.Module):
    """
    Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.

    EGNO is a graph-based neural operator that preserves equivariance with
    respect to 3D transformations while modeling temporal and spatial
    interactions between nodes. It combines:

    1. Temporal convolution in the Fourier domain to capture long-range
       temporal dependencies efficiently.
    2. Equivariant Graph Neural Network (EGNN) layers to model interactions
       between nodes while respecting geometric symmetries.

    This design allows EGNO to learn complex spatiotemporal dynamics of
    physical systems, molecules, or particles while enforcing physically
    meaningful constraints.

    .. seealso::

        **Original reference**
        Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli,
        K., Leskovec, J., Ermon, S., Anandkumar, A. (2024).
        *Equivariant Graph Neural Operator for Modeling 3D Dynamics*.
        `arXiv preprint arXiv:2401.11037 <https://arxiv.org/abs/2401.11037>`_
    """

    def __init__(
        self,
        n_egno_layers,
        node_feature_dim,
        edge_feature_dim,
        pos_dim,
        modes,
        time_steps=2,
        hidden_dim=64,
        time_emb_dim=16,
        max_time_idx=10000,
        n_message_layers=2,
        n_update_layers=2,
        activation=torch.nn.SiLU,
        aggr="add",
        node_dim=-2,
        flow="source_to_target",
    ):
        """
        Initialization of the :class:`EquivariantGraphNeuralOperator` class.

        :param int n_egno_layers: The number of EGNO layers.
        :param int node_feature_dim: The dimension of the node features in
            each EGNO layer.
        :param int edge_feature_dim: The dimension of the edge features in
            each EGNO layer.
        :param int pos_dim: The dimension of the position features in each
            EGNO layer.
        :param int modes: The number of Fourier modes to use in the temporal
            convolution.
        :param int time_steps: The number of time steps to consider in the
            temporal convolution. Default is 2.
        :param int hidden_dim: The dimension of the hidden features in each
            EGNO layer. Default is 64.
        :param int time_emb_dim: The dimension of the sinusoidal time
            embeddings. Default is 16.
        :param int max_time_idx: The maximum time index for the sinusoidal
            embeddings. Default is 10000.
        :param int n_message_layers: The number of layers in the message
            network of each EGNO layer. Default is 2.
        :param int n_update_layers: The number of layers in the update network
            of each EGNO layer. Default is 2.
        :param torch.nn.Module activation: The activation function.
            Default is :class:`torch.nn.SiLU`.
        :param str aggr: The aggregation scheme to use for message passing.
            Available options are "add", "mean", "min", "max", "mul".
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "add".
        :param int node_dim: The axis along which to propagate. Default is -2.
        :param str flow: The direction of message passing. Available options
            are "source_to_target" and "target_to_source". The
            "source_to_target" flow means that messages are sent from the
            source node to the target node, while the "target_to_source" flow
            means that messages are sent from the target node to the source
            node. See :class:`torch_geometric.nn.MessagePassing` for more
            details. Default is "source_to_target".
        :raises AssertionError: If ``n_egno_layers`` is not a positive integer.
        :raises AssertionError: If ``time_emb_dim`` is not a positive integer.
        :raises AssertionError: If ``max_time_idx`` is not a positive integer.
        :raises AssertionError: If ``time_steps`` is not a positive integer.
        """
        super().__init__()

        # Check consistency
        check_positive_integer(n_egno_layers, strict=True)
        check_positive_integer(time_emb_dim, strict=True)
        check_positive_integer(max_time_idx, strict=True)
        check_positive_integer(time_steps, strict=True)

        # Initialize parameters
        self.time_steps = time_steps
        self.time_emb_dim = time_emb_dim
        self.max_time_idx = max_time_idx

        # Initialize EGNO layers
        self.egno_layers = torch.nn.ModuleList()
        for _ in range(n_egno_layers):
            self.egno_layers.append(
                EquivariantGraphNeuralOperatorBlock(
                    node_feature_dim=node_feature_dim,
                    edge_feature_dim=edge_feature_dim,
                    pos_dim=pos_dim,
                    modes=modes,
                    hidden_dim=hidden_dim,
                    n_message_layers=n_message_layers,
                    n_update_layers=n_update_layers,
                    activation=activation,
                    aggr=aggr,
                    node_dim=node_dim,
                    flow=flow,
                )
            )

        # Linear layer to adjust the scalar feature dimension
        self.linear = torch.nn.Linear(
            node_feature_dim + time_emb_dim, node_feature_dim
        )

    def forward(self, graph):
        """
        Forward pass of the :class:`EquivariantGraphNeuralOperator` class.

        :param graph: The input graph object with the following attributes:

            - 'x': Node features, shape ``[num_nodes, node_feature_dim]``.
            - 'pos': Node positions, shape ``[num_nodes, pos_dim]``.
            - 'vel': Node velocities, shape ``[num_nodes, pos_dim]``.
            - 'edge_index': Graph connectivity, shape ``[2, num_edges]``.
            - 'edge_attr': Edge attrs, shape ``[num_edges, edge_feature_dim]``.

        :type graph: Data | Graph
        :return: The output graph object with updated node features,
            positions, and velocities. The output graph adds to 'x', 'pos',
            'vel', and 'edge_attr' the time dimension, resulting in shapes:

            - 'x': ``[time_steps, num_nodes, node_feature_dim]``
            - 'pos': ``[time_steps, num_nodes, pos_dim]``
            - 'vel': ``[time_steps, num_nodes, pos_dim]``
            - 'edge_attr': ``[time_steps, num_edges, edge_feature_dim]``

        :rtype: Data | Graph
        :raises ValueError: If the input graph does not have a 'vel' attribute.
        """
        # Check that the graph has the required attributes
        if "vel" not in graph:
            raise ValueError("The input graph must have a 'vel' attribute.")

        # Compute the temporal embedding
        emb = self._embedding(torch.arange(self.time_steps)).to(graph.x.device)
        emb = emb.unsqueeze(1).repeat(1, graph.x.shape[0], 1)

        # Expand dimensions
        x = graph.x.unsqueeze(0).repeat(self.time_steps, 1, 1)
        x = self.linear(torch.cat((x, emb), dim=-1))
        pos = graph.pos.unsqueeze(0).repeat(self.time_steps, 1, 1)
        vel = graph.vel.unsqueeze(0).repeat(self.time_steps, 1, 1)

        # Manage edge index
        offset = torch.arange(self.time_steps).reshape(-1, 1)
        offset = offset.to(graph.x.device) * graph.x.shape[0]
        src = graph.edge_index[0].unsqueeze(0) + offset
        dst = graph.edge_index[1].unsqueeze(0) + offset
        edge_index = torch.stack([src, dst], dim=0).reshape(2, -1)

        # Manage edge attributes
        if graph.edge_attr is not None:
            edge_attr = graph.edge_attr.unsqueeze(0)
            edge_attr = edge_attr.repeat(self.time_steps, 1, 1)
        else:
            edge_attr = None

        # Iteratively apply EGNO layers
        for layer in self.egno_layers:
            x, pos, vel = layer(
                x=x,
                pos=pos,
                vel=vel,
                edge_index=edge_index,
                edge_attr=edge_attr,
            )

        # Build new graph
        new_graph = graph.clone()
        new_graph.x, new_graph.pos, new_graph.vel = x, pos, vel
        if edge_attr is not None:
            new_graph.edge_attr = edge_attr

        return new_graph

    def _embedding(self, time):
        """
        Generate sinusoidal temporal embeddings.

        :param torch.Tensor time: The time instances.
        :return: The sinusoidal embedding tensor.
        :rtype: torch.Tensor
        """
        # Compute the sinusoidal embeddings
        half_dim = self.time_emb_dim // 2
        logs = torch.log(torch.as_tensor(self.max_time_idx)) / (half_dim - 1)
        freqs = torch.exp(-torch.arange(half_dim) * logs)
        args = torch.as_tensor(time)[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

        # Apply padding if the embedding dimension is odd
        if self.time_emb_dim % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1), mode="constant")

        return emb
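End to end, the model takes a single graph snapshot and rolls it out over time_steps; a minimal sketch built from the same Graph construction the tests below use (all sizes illustrative):

    import torch
    from pina.graph import Graph
    from pina.model import EquivariantGraphNeuralOperator

    graph = Graph(
        x=torch.rand(10, 4),
        pos=torch.rand(10, 3),
        vel=torch.rand(10, 3),
        edge_index=torch.randint(0, 10, (2, 20)),
        edge_attr=torch.randn(20, 2),
    )
    model = EquivariantGraphNeuralOperator(
        n_egno_layers=2,
        node_feature_dim=4,
        edge_feature_dim=2,
        pos_dim=3,
        modes=3,
        time_steps=4,
    )

    # The output graph gains a leading time dimension:
    # out.x is (4, 10, 4); out.pos and out.vel are (4, 10, 3).
    out = model(graph)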
tests/test_messagepassing/test_en_equivariant_network_block.py

@@ -5,19 +5,22 @@ from pina.model.block.message_passing import EnEquivariantNetworkBlock
 # Data for testing
 x = torch.rand(10, 4)
 pos = torch.rand(10, 3)
-edge_index = torch.randint(0, 10, (2, 20))
-edge_attr = torch.randn(20, 2)
+velocity = torch.rand(10, 3)
+edge_idx = torch.randint(0, 10, (2, 20))
+edge_attributes = torch.randn(20, 2)


 @pytest.mark.parametrize("node_feature_dim", [1, 3])
 @pytest.mark.parametrize("edge_feature_dim", [0, 2])
 @pytest.mark.parametrize("pos_dim", [2, 3])
-def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
+@pytest.mark.parametrize("use_velocity", [True, False])
+def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, use_velocity):

     EnEquivariantNetworkBlock(
         node_feature_dim=node_feature_dim,
         edge_feature_dim=edge_feature_dim,
         pos_dim=pos_dim,
+        use_velocity=use_velocity,
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
@@ -29,6 +32,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
         node_feature_dim=-1,
         edge_feature_dim=edge_feature_dim,
         pos_dim=pos_dim,
+        use_velocity=use_velocity,
     )

     # Should fail if edge_feature_dim is negative
@@ -37,6 +41,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
         node_feature_dim=node_feature_dim,
         edge_feature_dim=-1,
         pos_dim=pos_dim,
+        use_velocity=use_velocity,
     )

     # Should fail if pos_dim is negative
@@ -45,6 +50,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
         node_feature_dim=node_feature_dim,
         edge_feature_dim=edge_feature_dim,
         pos_dim=-1,
+        use_velocity=use_velocity,
     )

     # Should fail if hidden_dim is negative
@@ -54,6 +60,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
         edge_feature_dim=edge_feature_dim,
         pos_dim=pos_dim,
         hidden_dim=-1,
+        use_velocity=use_velocity,
     )

     # Should fail if n_message_layers is negative
@@ -63,6 +70,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
         edge_feature_dim=edge_feature_dim,
         pos_dim=pos_dim,
         n_message_layers=-1,
+        use_velocity=use_velocity,
     )

     # Should fail if n_update_layers is negative
@@ -72,11 +80,22 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
         edge_feature_dim=edge_feature_dim,
         pos_dim=pos_dim,
         n_update_layers=-1,
+        use_velocity=use_velocity,
     )

+    # Should fail if use_velocity is not boolean
+    with pytest.raises(ValueError):
+        EnEquivariantNetworkBlock(
+            node_feature_dim=node_feature_dim,
+            edge_feature_dim=edge_feature_dim,
+            pos_dim=pos_dim,
+            use_velocity="False",
+        )
+

 @pytest.mark.parametrize("edge_feature_dim", [0, 2])
-def test_forward(edge_feature_dim):
+@pytest.mark.parametrize("use_velocity", [True, False])
+def test_forward(edge_feature_dim, use_velocity):

     model = EnEquivariantNetworkBlock(
         node_feature_dim=x.shape[1],
@@ -85,21 +104,26 @@ def test_forward(edge_feature_dim):
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
+        use_velocity=use_velocity,
     )

-    if edge_feature_dim == 0:
-        output_ = model(edge_index=edge_index, x=x, pos=pos)
-    else:
-        output_ = model(
-            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
-        )
+    # Manage inputs
+    vel = velocity if use_velocity else None
+    edge_attr = edge_attributes if edge_feature_dim > 0 else None
+
+    # Checks on output shapes
+    output_ = model(
+        x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel
+    )
     assert output_[0].shape == x.shape
     assert output_[1].shape == pos.shape
+    if vel is not None:
+        assert output_[2].shape == vel.shape


 @pytest.mark.parametrize("edge_feature_dim", [0, 2])
-def test_backward(edge_feature_dim):
+@pytest.mark.parametrize("use_velocity", [True, False])
+def test_backward(edge_feature_dim, use_velocity):

     model = EnEquivariantNetworkBlock(
         node_feature_dim=x.shape[1],
@@ -108,35 +132,45 @@ def test_backward(edge_feature_dim):
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
+        use_velocity=use_velocity,
     )

+    # Manage inputs
+    vel = velocity.requires_grad_() if use_velocity else None
+    edge_attr = (
+        edge_attributes.requires_grad_() if edge_feature_dim > 0 else None
+    )
+
     if edge_feature_dim == 0:
         output_ = model(
-            edge_index=edge_index,
+            edge_index=edge_idx,
             x=x.requires_grad_(),
             pos=pos.requires_grad_(),
+            vel=vel,
         )
     else:
         output_ = model(
-            edge_index=edge_index,
+            edge_index=edge_idx,
             x=x.requires_grad_(),
             pos=pos.requires_grad_(),
-            edge_attr=edge_attr.requires_grad_(),
+            edge_attr=edge_attr,
+            vel=vel,
         )

-    loss = torch.mean(output_[0])
+    # Checks on gradients
+    loss = sum(torch.mean(output_[i]) for i in range(len(output_)))
     loss.backward()
     assert x.grad.shape == x.shape
     assert pos.grad.shape == pos.shape
+    if use_velocity:
+        assert vel.grad.shape == vel.shape


-def test_equivariance():
+@pytest.mark.parametrize("edge_feature_dim", [0, 2])
+@pytest.mark.parametrize("use_velocity", [True, False])
+def test_equivariance(edge_feature_dim, use_velocity):

     # Graph to be fully connected and undirected
     edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
     edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

-    # Random rotation (det(rotation) should be 1)
+    # Random rotation
     rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
     if torch.det(rotation) < 0:
         rotation[:, 0] *= -1
@@ -146,20 +180,37 @@ def test_equivariance():

     model = EnEquivariantNetworkBlock(
         node_feature_dim=x.shape[1],
-        edge_feature_dim=0,
+        edge_feature_dim=edge_feature_dim,
         pos_dim=pos.shape[1],
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
+        use_velocity=use_velocity,
     ).eval()

-    h1, pos1 = model(edge_index=edge_index, x=x, pos=pos)
-    h2, pos2 = model(
-        edge_index=edge_index, x=x, pos=pos @ rotation.T + translation
-    )
+    # Manage inputs
+    vel = velocity if use_velocity else None
+    edge_attr = edge_attributes if edge_feature_dim > 0 else None
+
+    # Transform inputs (no translation for velocity)
+    pos_rot = pos @ rotation.T + translation
+    vel_rot = vel @ rotation.T if use_velocity else vel
+
+    # Get model outputs
+    out1 = model(
+        x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel
+    )
+    out2 = model(
+        x=x, pos=pos_rot, edge_index=edge_idx, edge_attr=edge_attr, vel=vel_rot
+    )

-    # Transform model output
-    pos1_transformed = (pos1 @ rotation.T) + translation
+    # Unpack outputs
+    h1, pos1, *other1 = out1
+    h2, pos2, *other2 = out2
+    if use_velocity:
+        vel1, vel2 = other1[0], other2[0]

-    assert torch.allclose(pos2, pos1_transformed, atol=1e-5)
+    assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
     assert torch.allclose(h1, h2, atol=1e-5)
+    if vel is not None:
+        assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)
tests/test_messagepassing/test_equivariant_operator_block.py (new file)
@@ -0,0 +1,132 @@
import pytest
import torch
from pina.model.block.message_passing import EquivariantGraphNeuralOperatorBlock

# Data for testing. Shapes: (time, nodes, features)
x = torch.rand(5, 10, 4)
pos = torch.rand(5, 10, 3)
vel = torch.rand(5, 10, 3)

# Edge index and attributes
edge_idx = torch.randint(0, 10, (2, 20))
edge_attributes = torch.randn(20, 2)


@pytest.mark.parametrize("node_feature_dim", [1, 3])
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("pos_dim", [2, 3])
@pytest.mark.parametrize("modes", [1, 5])
def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, modes):

    EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=node_feature_dim,
        edge_feature_dim=edge_feature_dim,
        pos_dim=pos_dim,
        modes=modes,
    )

    # Should fail if modes is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperatorBlock(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            pos_dim=pos_dim,
            modes=-1,
        )


@pytest.mark.parametrize("modes", [1, 5])
def test_forward(modes):

    model = EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=x.shape[2],
        edge_feature_dim=edge_attributes.shape[1],
        pos_dim=pos.shape[2],
        modes=modes,
    )

    output_ = model(
        x=x,
        pos=pos,
        vel=vel,
        edge_index=edge_idx,
        edge_attr=edge_attributes,
    )

    # Checks on output shapes
    assert output_[0].shape == x.shape
    assert output_[1].shape == pos.shape
    assert output_[2].shape == vel.shape


@pytest.mark.parametrize("modes", [1, 5])
def test_backward(modes):

    model = EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=x.shape[2],
        edge_feature_dim=edge_attributes.shape[1],
        pos_dim=pos.shape[2],
        modes=modes,
    )

    output_ = model(
        x=x.requires_grad_(),
        pos=pos.requires_grad_(),
        vel=vel.requires_grad_(),
        edge_index=edge_idx,
        edge_attr=edge_attributes.requires_grad_(),
    )

    # Checks on gradients
    loss = sum(torch.mean(output_[i]) for i in range(len(output_)))
    loss.backward()
    assert x.grad.shape == x.shape
    assert pos.grad.shape == pos.shape
    assert vel.grad.shape == vel.shape


@pytest.mark.parametrize("modes", [1, 5])
def test_equivariance(modes):

    # Random rotation
    rotation = torch.linalg.qr(torch.rand(pos.shape[2], pos.shape[2])).Q
    if torch.det(rotation) < 0:
        rotation[:, 0] *= -1

    # Random translation
    translation = torch.rand(1, pos.shape[2])

    model = EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=x.shape[2],
        edge_feature_dim=edge_attributes.shape[1],
        pos_dim=pos.shape[2],
        modes=modes,
    ).eval()

    # Transform inputs (no translation for velocity)
    pos_rot = pos @ rotation.T + translation
    vel_rot = vel @ rotation.T

    # Get model outputs
    out1 = model(
        x=x,
        pos=pos,
        vel=vel,
        edge_index=edge_idx,
        edge_attr=edge_attributes,
    )
    out2 = model(
        x=x,
        pos=pos_rot,
        vel=vel_rot,
        edge_index=edge_idx,
        edge_attr=edge_attributes,
    )

    # Unpack outputs
    h1, pos1, vel1 = out1
    h2, pos2, vel2 = out2

    assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
    assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)
    assert torch.allclose(h1, h2, atol=1e-5)
tests/test_model/test_equivariant_graph_neural_operator.py (new file)
@@ -0,0 +1,194 @@
import pytest
import torch
import copy
from pina.model import EquivariantGraphNeuralOperator
from pina.graph import Graph


# Utility to create graphs
def make_graph(include_vel=True, use_edge_attr=True):
    data = dict(
        x=torch.rand(10, 4),
        pos=torch.rand(10, 3),
        edge_index=torch.randint(0, 10, (2, 20)),
        edge_attr=torch.randn(20, 2) if use_edge_attr else None,
    )
    if include_vel:
        data["vel"] = torch.rand(10, 3)
    return Graph(**data)


@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 3])
@pytest.mark.parametrize("time_emb_dim", [4, 8])
@pytest.mark.parametrize("max_time_idx", [10, 20])
def test_constructor(n_egno_layers, time_steps, time_emb_dim, max_time_idx):

    # Create graph and model
    graph = make_graph()
    EquivariantGraphNeuralOperator(
        n_egno_layers=n_egno_layers,
        node_feature_dim=graph.x.shape[1],
        edge_feature_dim=graph.edge_attr.shape[1],
        pos_dim=graph.pos.shape[1],
        modes=5,
        time_steps=time_steps,
        time_emb_dim=time_emb_dim,
        max_time_idx=max_time_idx,
    )

    # Should fail if n_egno_layers is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperator(
            n_egno_layers=-1,
            node_feature_dim=graph.x.shape[1],
            edge_feature_dim=graph.edge_attr.shape[1],
            pos_dim=graph.pos.shape[1],
            modes=5,
            time_steps=time_steps,
            time_emb_dim=time_emb_dim,
            max_time_idx=max_time_idx,
        )

    # Should fail if time_steps is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperator(
            n_egno_layers=n_egno_layers,
            node_feature_dim=graph.x.shape[1],
            edge_feature_dim=graph.edge_attr.shape[1],
            pos_dim=graph.pos.shape[1],
            modes=5,
            time_steps=-1,
            time_emb_dim=time_emb_dim,
            max_time_idx=max_time_idx,
        )

    # Should fail if max_time_idx is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperator(
            n_egno_layers=n_egno_layers,
            node_feature_dim=graph.x.shape[1],
            edge_feature_dim=graph.edge_attr.shape[1],
            pos_dim=graph.pos.shape[1],
            modes=5,
            time_steps=time_steps,
            time_emb_dim=time_emb_dim,
            max_time_idx=-1,
        )

    # Should fail if time_emb_dim is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperator(
            n_egno_layers=n_egno_layers,
            node_feature_dim=graph.x.shape[1],
            edge_feature_dim=graph.edge_attr.shape[1],
            pos_dim=graph.pos.shape[1],
            modes=5,
            time_steps=time_steps,
            time_emb_dim=-1,
            max_time_idx=max_time_idx,
        )


@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_forward(n_egno_layers, time_steps, modes, use_edge_attr):

    # Create graph and model
    graph = make_graph(use_edge_attr=use_edge_attr)
    model = EquivariantGraphNeuralOperator(
        n_egno_layers=n_egno_layers,
        node_feature_dim=graph.x.shape[1],
        edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
        pos_dim=graph.pos.shape[1],
        modes=modes,
        time_steps=time_steps,
    )

    # Checks on output shapes
    output_ = model(graph)
    assert output_.x.shape == (time_steps, *graph.x.shape)
    assert output_.pos.shape == (time_steps, *graph.pos.shape)
    assert output_.vel.shape == (time_steps, *graph.vel.shape)

    # Should fail if the graph has no vel attribute
    with pytest.raises(ValueError):
        graph_no_vel = make_graph(include_vel=False)
        model(graph_no_vel)


@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_backward(n_egno_layers, time_steps, modes, use_edge_attr):

    # Create graph and model
    graph = make_graph(use_edge_attr=use_edge_attr)
    model = EquivariantGraphNeuralOperator(
        n_egno_layers=n_egno_layers,
        node_feature_dim=graph.x.shape[1],
        edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
        pos_dim=graph.pos.shape[1],
        modes=modes,
        time_steps=time_steps,
    )

    # Set requires_grad and perform forward pass
    graph.x.requires_grad_()
    graph.pos.requires_grad_()
    graph.vel.requires_grad_()
    out = model(graph)

    # Checks on gradients
    loss = torch.mean(out.x) + torch.mean(out.pos) + torch.mean(out.vel)
    loss.backward()
    assert graph.x.grad.shape == graph.x.shape
    assert graph.pos.grad.shape == graph.pos.shape
    assert graph.vel.grad.shape == graph.vel.shape


@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_equivariance(n_egno_layers, time_steps, modes, use_edge_attr):

    graph = make_graph(use_edge_attr=use_edge_attr)
    model = EquivariantGraphNeuralOperator(
        n_egno_layers=n_egno_layers,
        node_feature_dim=graph.x.shape[1],
        edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
        pos_dim=graph.pos.shape[1],
        modes=modes,
        time_steps=time_steps,
    ).eval()

    # Random rotation
    rotation = torch.linalg.qr(
        torch.rand(graph.pos.shape[1], graph.pos.shape[1])
    ).Q
    if torch.det(rotation) < 0:
        rotation[:, 0] *= -1

    # Random translation
    translation = torch.rand(1, graph.pos.shape[1])

    # Transform graph (no translation for velocity)
    graph_rot = copy.deepcopy(graph)
    graph_rot.pos = graph.pos @ rotation.T + translation
    graph_rot.vel = graph.vel @ rotation.T

    # Get model outputs
    out1 = model(graph)
    out2 = model(graph_rot)

    # Unpack outputs
    h1, pos1, vel1 = out1.x, out1.pos, out1.vel
    h2, pos2, vel2 = out2.x, out2.pos, out2.vel

    assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
    assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)
    assert torch.allclose(h1, h2, atol=1e-5)