Implement Graph Neural Operator #231
This commit is contained in:
committed by
Nicola Demo
parent
e63c3d9061
commit
86fe41261b
@@ -10,6 +10,7 @@ __all__ = [
|
|||||||
"AveragingNeuralOperator",
|
"AveragingNeuralOperator",
|
||||||
"LowRankNeuralOperator",
|
"LowRankNeuralOperator",
|
||||||
"Spline",
|
"Spline",
|
||||||
|
"GNO"
|
||||||
]
|
]
|
||||||
|
|
||||||
from .feed_forward import FeedForward, ResidualFeedForward
|
from .feed_forward import FeedForward, ResidualFeedForward
|
||||||
@@ -20,3 +21,4 @@ from .base_no import KernelNeuralOperator
|
|||||||
from .avno import AveragingNeuralOperator
|
from .avno import AveragingNeuralOperator
|
||||||
from .lno import LowRankNeuralOperator
|
from .lno import LowRankNeuralOperator
|
||||||
from .spline import Spline
|
from .spline import Spline
|
||||||
|
from .gno import GNO
|
||||||
173
pina/model/gno.py
Normal file
173
pina/model/gno.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
import torch
|
||||||
|
from torch.nn import Tanh
|
||||||
|
from .layers import GraphIntegralLayer
|
||||||
|
from .base_no import KernelNeuralOperator
|
||||||
|
|
||||||
|
|
||||||
|
class GraphNeuralKernel(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
TODO add docstring
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width,
|
||||||
|
edge_features,
|
||||||
|
n_layers=2,
|
||||||
|
internal_n_layers=0,
|
||||||
|
internal_layers=None,
|
||||||
|
internal_func=None,
|
||||||
|
external_func=None,
|
||||||
|
shared_weights=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
The Graph Neural Kernel constructor.
|
||||||
|
|
||||||
|
:param width: The width of the kernel.
|
||||||
|
:type width: int
|
||||||
|
:param edge_features: The number of edge features.
|
||||||
|
:type edge_features: int
|
||||||
|
:param n_layers: The number of kernel layers.
|
||||||
|
:type n_layers: int
|
||||||
|
:param internal_n_layers: The number of layers the FF Neural Network internal to each Kernel Layer.
|
||||||
|
:type internal_n_layers: int
|
||||||
|
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
|
||||||
|
:type internal_layers: list | tuple
|
||||||
|
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
|
||||||
|
:param external_func: The activation function applied to the output of the Graph Integral Layer.
|
||||||
|
:type external_func: torch.nn.Module
|
||||||
|
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if external_func is None:
|
||||||
|
external_func = Tanh
|
||||||
|
if internal_func is None:
|
||||||
|
internal_func = Tanh
|
||||||
|
|
||||||
|
if shared_weights:
|
||||||
|
self.layers = GraphIntegralLayer(
|
||||||
|
width=width,
|
||||||
|
edges_features=edge_features,
|
||||||
|
n_layers=internal_n_layers,
|
||||||
|
layers=internal_layers,
|
||||||
|
internal_func=internal_func,
|
||||||
|
external_func=external_func)
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.forward = self.forward_shared
|
||||||
|
else:
|
||||||
|
self.layers = torch.nn.ModuleList(
|
||||||
|
[GraphIntegralLayer(
|
||||||
|
width=width,
|
||||||
|
edges_features=edge_features,
|
||||||
|
n_layers=internal_n_layers,
|
||||||
|
layers=internal_layers,
|
||||||
|
internal_func=internal_func,
|
||||||
|
external_func=external_func
|
||||||
|
)
|
||||||
|
for _ in range(n_layers)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, edge_index, edge_attr):
|
||||||
|
"""
|
||||||
|
The forward pass of the Graph Neural Kernel used when the weights are not shared.
|
||||||
|
|
||||||
|
:param x: The input batch.
|
||||||
|
:type x: torch.Tensor
|
||||||
|
:param edge_index: The edge index.
|
||||||
|
:type edge_index: torch.Tensor
|
||||||
|
:param edge_attr: The edge attributes.
|
||||||
|
:type edge_attr: torch.Tensor
|
||||||
|
"""
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, edge_index, edge_attr)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward_shared(self, x, edge_index, edge_attr):
|
||||||
|
"""
|
||||||
|
The forward pass of the Graph Neural Kernel used when the weights are shared.
|
||||||
|
|
||||||
|
:param x: The input batch.
|
||||||
|
:type x: torch.Tensor
|
||||||
|
:param edge_index: The edge index.
|
||||||
|
:type edge_index: torch.Tensor
|
||||||
|
:param edge_attr: The edge attributes.
|
||||||
|
:type edge_attr: torch.Tensor
|
||||||
|
"""
|
||||||
|
for _ in range(self.n_layers):
|
||||||
|
x = self.layers(x, edge_index, edge_attr)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GNO(KernelNeuralOperator):
|
||||||
|
"""
|
||||||
|
TODO add docstring
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lifting_operator,
|
||||||
|
projection_operator,
|
||||||
|
edge_features,
|
||||||
|
n_layers=10,
|
||||||
|
internal_n_layers=0,
|
||||||
|
inner_size=None,
|
||||||
|
internal_layers=None,
|
||||||
|
internal_func=None,
|
||||||
|
external_func=None,
|
||||||
|
shared_weights=True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
The Graph Neural Operator constructor.
|
||||||
|
|
||||||
|
:param lifting_operator: The lifting operator mapping the node features to its hidden dimension.
|
||||||
|
:type lifting_operator: torch.nn.Module
|
||||||
|
:param projection_operator: The projection operator mapping the hidden representation of the nodes features to the output function.
|
||||||
|
:type projection_operator: torch.nn.Module
|
||||||
|
:param edge_features: Number of edge features.
|
||||||
|
:type edge_features: int
|
||||||
|
:param n_layers: The number of kernel layers.
|
||||||
|
:type n_layers: int
|
||||||
|
:param internal_n_layers: The number of layers the Feed Forward Neural Network internal to each Kernel Layer.
|
||||||
|
:type internal_n_layers: int
|
||||||
|
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
|
||||||
|
:type internal_layers: list | tuple
|
||||||
|
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
|
||||||
|
:type internal_func: torch.nn.Module
|
||||||
|
:param external_func: The activation function applied to the output of the Graph Integral Kernel.
|
||||||
|
:type external_func: torch.nn.Module
|
||||||
|
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
|
||||||
|
:type shared_weights: bool
|
||||||
|
"""
|
||||||
|
|
||||||
|
if internal_func is None:
|
||||||
|
internal_func = Tanh
|
||||||
|
if external_func is None:
|
||||||
|
external_func = Tanh
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
lifting_operator=lifting_operator,
|
||||||
|
integral_kernels=GraphNeuralKernel(
|
||||||
|
width=lifting_operator.out_features,
|
||||||
|
edge_features=edge_features,
|
||||||
|
internal_n_layers=internal_n_layers,
|
||||||
|
internal_layers=internal_layers,
|
||||||
|
external_func=external_func,
|
||||||
|
internal_func=internal_func,
|
||||||
|
n_layers=n_layers,
|
||||||
|
shared_weights=shared_weights
|
||||||
|
),
|
||||||
|
projection_operator=projection_operator
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
The forward pass of the Graph Neural Operator.
|
||||||
|
|
||||||
|
:param x: The input batch.
|
||||||
|
:type x: torch_geometric.data.Batch
|
||||||
|
"""
|
||||||
|
x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr
|
||||||
|
x = self.lifting_operator(x)
|
||||||
|
x = self.integral_kernels(x, edge_index, edge_attr)
|
||||||
|
x = self.projection_operator(x)
|
||||||
|
return x
|
||||||
@@ -15,6 +15,7 @@ __all__ = [
|
|||||||
"AVNOBlock",
|
"AVNOBlock",
|
||||||
"LowRankBlock",
|
"LowRankBlock",
|
||||||
"RBFBlock",
|
"RBFBlock",
|
||||||
|
"GraphIntegralLayer"
|
||||||
]
|
]
|
||||||
|
|
||||||
from .convolution_2d import ContinuousConvBlock
|
from .convolution_2d import ContinuousConvBlock
|
||||||
@@ -31,3 +32,4 @@ from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding
|
|||||||
from .avno_layer import AVNOBlock
|
from .avno_layer import AVNOBlock
|
||||||
from .lowrank_layer import LowRankBlock
|
from .lowrank_layer import LowRankBlock
|
||||||
from .rbf_layer import RBFBlock
|
from .rbf_layer import RBFBlock
|
||||||
|
from .graph_integral_kernel import GraphIntegralLayer
|
||||||
|
|||||||
82
pina/model/layers/graph_integral_kernel.py
Normal file
82
pina/model/layers/graph_integral_kernel.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import torch
|
||||||
|
from torch_geometric.nn import MessagePassing
|
||||||
|
|
||||||
|
|
||||||
|
class GraphIntegralLayer(MessagePassing):
|
||||||
|
"""
|
||||||
|
TODO: Add documentation
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width,
|
||||||
|
edges_features,
|
||||||
|
n_layers=0,
|
||||||
|
layers=None,
|
||||||
|
internal_func=None,
|
||||||
|
external_func=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric.
|
||||||
|
|
||||||
|
:param width: The width of the hidden representation of the nodes features
|
||||||
|
:type width: int
|
||||||
|
:param edges_features: The number of edge features.
|
||||||
|
:type edges_features: int
|
||||||
|
:param n_layers: The number of layers in the Feed Forward Neural Network used to compute the representation of the edges features.
|
||||||
|
:type n_layers: int
|
||||||
|
"""
|
||||||
|
from pina.model import FeedForward
|
||||||
|
super(GraphIntegralLayer, self).__init__(aggr='mean')
|
||||||
|
self.width = width
|
||||||
|
self.dense = FeedForward(input_dimensions=edges_features,
|
||||||
|
output_dimensions=width ** 2,
|
||||||
|
n_layers=n_layers,
|
||||||
|
layers=layers,
|
||||||
|
func=internal_func)
|
||||||
|
self.W = torch.nn.Linear(width, width)
|
||||||
|
self.func = external_func()
|
||||||
|
|
||||||
|
def message(self, x_j, edge_attr):
|
||||||
|
"""
|
||||||
|
This function computes the message passed between the nodes of the graph. Overwrite the default message function defined in the MessagePassing class.
|
||||||
|
|
||||||
|
:param x_j: The node features of the neighboring.
|
||||||
|
:type x_j: torch.Tensor
|
||||||
|
:param edge_attr: The edge features.
|
||||||
|
:type edge_attr: torch.Tensor
|
||||||
|
:return: The message passed between the nodes of the graph.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
x = self.dense(edge_attr).view(-1, self.width, self.width)
|
||||||
|
return torch.einsum('bij,bj->bi', x, x_j)
|
||||||
|
|
||||||
|
def update(self, aggr_out, x):
|
||||||
|
"""
|
||||||
|
This function updates the node features of the graph. Overwrite the default update function defined in the MessagePassing class.
|
||||||
|
|
||||||
|
:param aggr_out: The aggregated messages.
|
||||||
|
:type aggr_out: torch.Tensor
|
||||||
|
:param x: The node features.
|
||||||
|
:type x: torch.Tensor
|
||||||
|
:return: The updated node features.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
aggr_out = aggr_out + self.W(x)
|
||||||
|
return aggr_out
|
||||||
|
|
||||||
|
def forward(self, x, edge_index, edge_attr):
|
||||||
|
"""
|
||||||
|
The forward pass of the Graph Integral Layer.
|
||||||
|
|
||||||
|
:param x: Node features.
|
||||||
|
:type x: torch.Tensor
|
||||||
|
:param edge_index: Edge index.
|
||||||
|
:type edge_index: torch.Tensor
|
||||||
|
:param edge_attr: Edge features.
|
||||||
|
:type edge_attr: torch.Tensor
|
||||||
|
:return: Output of a single iteration over the Graph Integral Layer.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
return self.func(
|
||||||
|
self.propagate(edge_index, x=x, edge_attr=edge_attr)
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user