Implement Graph Neural Operator #231

FilippoOlivo
2025-02-04 18:11:06 +01:00
committed by Nicola Demo
parent e63c3d9061
commit 86fe41261b
4 changed files with 259 additions and 0 deletions


@@ -10,6 +10,7 @@ __all__ = [
     "AveragingNeuralOperator",
     "LowRankNeuralOperator",
     "Spline",
+    "GNO"
 ]
 from .feed_forward import FeedForward, ResidualFeedForward
@@ -20,3 +21,4 @@ from .base_no import KernelNeuralOperator
 from .avno import AveragingNeuralOperator
 from .lno import LowRankNeuralOperator
 from .spline import Spline
+from .gno import GNO
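These hunks expose the new model at package level. A minimal sketch of the resulting import path, assuming pina is installed with this commit applied:

# Illustrative only: the new package-level export.
from pina.model import GNO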

pina/model/gno.py Normal file

@@ -0,0 +1,173 @@
import torch
from torch.nn import Tanh
from .layers import GraphIntegralLayer
from .base_no import KernelNeuralOperator
class GraphNeuralKernel(torch.nn.Module):
"""
TODO add docstring
"""
def __init__(
self,
width,
edge_features,
n_layers=2,
internal_n_layers=0,
internal_layers=None,
internal_func=None,
external_func=None,
shared_weights=False
):
"""
The Graph Neural Kernel constructor.
:param width: The width of the kernel.
:type width: int
:param edge_features: The number of edge features.
:type edge_features: int
:param n_layers: The number of kernel layers.
:type n_layers: int
:param internal_n_layers: The number of layers the FF Neural Network internal to each Kernel Layer.
:type internal_n_layers: int
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
:type internal_layers: list | tuple
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
:param external_func: The activation function applied to the output of the Graph Integral Layer.
:type external_func: torch.nn.Module
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
"""
super().__init__()
if external_func is None:
external_func = Tanh
if internal_func is None:
internal_func = Tanh
if shared_weights:
self.layers = GraphIntegralLayer(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
layers=internal_layers,
internal_func=internal_func,
external_func=external_func)
self.n_layers = n_layers
self.forward = self.forward_shared
else:
self.layers = torch.nn.ModuleList(
[GraphIntegralLayer(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
layers=internal_layers,
internal_func=internal_func,
external_func=external_func
)
for _ in range(n_layers)]
)
def forward(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Neural Kernel used when the weights are not shared.
:param x: The input batch.
:type x: torch.Tensor
:param edge_index: The edge index.
:type edge_index: torch.Tensor
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor
"""
for layer in self.layers:
x = layer(x, edge_index, edge_attr)
return x
def forward_shared(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Neural Kernel used when the weights are shared.
:param x: The input batch.
:type x: torch.Tensor
:param edge_index: The edge index.
:type edge_index: torch.Tensor
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor
"""
for _ in range(self.n_layers):
x = self.layers(x, edge_index, edge_attr)
return x
class GNO(KernelNeuralOperator):
"""
TODO add docstring
"""
def __init__(
self,
lifting_operator,
projection_operator,
edge_features,
n_layers=10,
internal_n_layers=0,
inner_size=None,
internal_layers=None,
internal_func=None,
external_func=None,
shared_weights=True
):
"""
The Graph Neural Operator constructor.
:param lifting_operator: The lifting operator mapping the node features to its hidden dimension.
:type lifting_operator: torch.nn.Module
:param projection_operator: The projection operator mapping the hidden representation of the nodes features to the output function.
:type projection_operator: torch.nn.Module
:param edge_features: Number of edge features.
:type edge_features: int
:param n_layers: The number of kernel layers.
:type n_layers: int
:param internal_n_layers: The number of layers the Feed Forward Neural Network internal to each Kernel Layer.
:type internal_n_layers: int
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
:type internal_layers: list | tuple
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
:type internal_func: torch.nn.Module
:param external_func: The activation function applied to the output of the Graph Integral Kernel.
:type external_func: torch.nn.Module
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
:type shared_weights: bool
"""
if internal_func is None:
internal_func = Tanh
if external_func is None:
external_func = Tanh
super().__init__(
lifting_operator=lifting_operator,
integral_kernels=GraphNeuralKernel(
width=lifting_operator.out_features,
edge_features=edge_features,
internal_n_layers=internal_n_layers,
internal_layers=internal_layers,
external_func=external_func,
internal_func=internal_func,
n_layers=n_layers,
shared_weights=shared_weights
),
projection_operator=projection_operator
)
def forward(self, x):
"""
The forward pass of the Graph Neural Operator.
:param x: The input batch.
:type x: torch_geometric.data.Batch
"""
x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr
x = self.lifting_operator(x)
x = self.integral_kernels(x, edge_index, edge_attr)
x = self.projection_operator(x)
return x
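For context, here is a minimal usage sketch of the new model. It is illustrative only and not part of this commit: the lifting and projection operators, all sizes, and the toy torch_geometric graph are assumptions.

# Illustrative sketch (not part of this commit): build a small GNO and run it
# on a toy graph. All sizes and the toy data below are arbitrary assumptions.
import torch
from torch_geometric.data import Batch, Data

from pina.model import GNO

lifting = torch.nn.Linear(3, 16)      # 3 node features -> hidden width 16
projection = torch.nn.Linear(16, 1)   # hidden width 16 -> one output per node

model = GNO(
    lifting_operator=lifting,
    projection_operator=projection,
    edge_features=2,                  # must match the edge_attr dimension below
    n_layers=4,
    shared_weights=True,
)

# Toy graph: 5 nodes, 4 directed edges, 2 features per edge.
data = Data(
    x=torch.rand(5, 3),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
    edge_attr=torch.rand(4, 2),
)
batch = Batch.from_data_list([data])

out = model(batch)                    # shape (5, 1): one value per node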


@@ -15,6 +15,7 @@ __all__ = [
     "AVNOBlock",
     "LowRankBlock",
     "RBFBlock",
+    "GraphIntegralLayer"
 ]
 from .convolution_2d import ContinuousConvBlock
@@ -31,3 +32,4 @@ from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding
 from .avno_layer import AVNOBlock
 from .lowrank_layer import LowRankBlock
 from .rbf_layer import RBFBlock
+from .graph_integral_kernel import GraphIntegralLayer


@@ -0,0 +1,82 @@
import torch
from torch_geometric.nn import MessagePassing
class GraphIntegralLayer(MessagePassing):
"""
TODO: Add documentation
"""
def __init__(
self,
width,
edges_features,
n_layers=0,
layers=None,
internal_func=None,
external_func=None
):
"""
Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric.
:param width: The width of the hidden representation of the nodes features
:type width: int
:param edges_features: The number of edge features.
:type edges_features: int
:param n_layers: The number of layers in the Feed Forward Neural Network used to compute the representation of the edges features.
:type n_layers: int
"""
        # FeedForward is imported locally, inside the constructor, presumably
        # to avoid a circular import between pina.model and its layers.
        from pina.model import FeedForward
super(GraphIntegralLayer, self).__init__(aggr='mean')
self.width = width
self.dense = FeedForward(input_dimensions=edges_features,
output_dimensions=width ** 2,
n_layers=n_layers,
layers=layers,
func=internal_func)
self.W = torch.nn.Linear(width, width)
self.func = external_func()
def message(self, x_j, edge_attr):
"""
This function computes the message passed between the nodes of the graph. Overwrite the default message function defined in the MessagePassing class.
:param x_j: The node features of the neighboring.
:type x_j: torch.Tensor
:param edge_attr: The edge features.
:type edge_attr: torch.Tensor
:return: The message passed between the nodes of the graph.
:rtype: torch.Tensor
"""
x = self.dense(edge_attr).view(-1, self.width, self.width)
return torch.einsum('bij,bj->bi', x, x_j)
def update(self, aggr_out, x):
"""
This function updates the node features of the graph. Overwrite the default update function defined in the MessagePassing class.
:param aggr_out: The aggregated messages.
:type aggr_out: torch.Tensor
:param x: The node features.
:type x: torch.Tensor
:return: The updated node features.
:rtype: torch.Tensor
"""
aggr_out = aggr_out + self.W(x)
return aggr_out
def forward(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Integral Layer.
:param x: Node features.
:type x: torch.Tensor
:param edge_index: Edge index.
:type edge_index: torch.Tensor
:param edge_attr: Edge features.
:type edge_attr: torch.Tensor
        :return: The output of a single pass of the Graph Integral Layer.
:rtype: torch.Tensor
"""
return self.func(
self.propagate(edge_index, x=x, edge_attr=edge_attr)
)
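For reference, a standalone sketch of what a single layer computes; the shapes and the toy graph below are assumptions, used only to illustrate the message/update logic above.

# Illustrative sketch (not part of this commit): one GraphIntegralLayer call.
# The internal feed-forward network maps each edge feature vector to a
# (width x width) matrix acting on the neighboring node features x_j; the
# messages are mean-aggregated per node, added to W(x), and passed through
# the external activation.
import torch
from torch.nn import Tanh

from pina.model.layers import GraphIntegralLayer

layer = GraphIntegralLayer(
    width=8,
    edges_features=2,
    internal_func=Tanh,   # activations passed explicitly, as GNO itself does
    external_func=Tanh,
)

x = torch.rand(5, 8)                                     # 5 nodes, width 8
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges
edge_attr = torch.rand(4, 2)                             # 2 features per edge

out = layer(x, edge_index, edge_attr)                    # shape: (5, 8)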