Implement Graph Neural Operator #231
Committed by Nicola Demo
parent e63c3d9061
commit 86fe41261b

pina/model/layers/graph_integral_kernel.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import torch
from torch_geometric.nn import MessagePassing


class GraphIntegralLayer(MessagePassing):
    """
    Graph integral kernel layer. Each edge feature vector is mapped to a
    learnable kernel matrix, the resulting messages are aggregated with a
    mean, and the node features are updated with a linear skip connection,
    following the message passing scheme of PyTorch Geometric.
    """
    def __init__(
        self,
        width,
        edges_features,
        n_layers=0,
        layers=None,
        internal_func=None,
        external_func=None
    ):
        """
        Initialize the Graph Integral Layer, inheriting from the
        MessagePassing class of PyTorch Geometric.

        :param width: The width of the hidden representation of the node
            features.
        :type width: int
        :param edges_features: The number of edge features.
        :type edges_features: int
        :param n_layers: The number of layers of the feed-forward neural
            network used to compute the representation of the edge features.
        :type n_layers: int
        :param layers: The sizes of the hidden layers of the feed-forward
            neural network; if ``None``, ``n_layers`` is used instead.
        :type layers: list[int] | None
        :param internal_func: The activation function class used inside the
            feed-forward neural network.
        :type internal_func: torch.nn.Module
        :param external_func: The activation function class applied to the
            output of the layer.
        :type external_func: torch.nn.Module
        """
        from pina.model import FeedForward
        super().__init__(aggr='mean')
        self.width = width
        # Feed-forward network mapping each edge feature vector to a
        # (width x width) kernel matrix.
        self.dense = FeedForward(input_dimensions=edges_features,
                                 output_dimensions=width ** 2,
                                 n_layers=n_layers,
                                 layers=layers,
                                 func=internal_func)
        self.W = torch.nn.Linear(width, width)
        self.func = external_func()

    def message(self, x_j, edge_attr):
        """
        Compute the message passed between the nodes of the graph, overriding
        the default message function of the MessagePassing class.

        :param x_j: The features of the neighboring nodes.
        :type x_j: torch.Tensor
        :param edge_attr: The edge features.
        :type edge_attr: torch.Tensor
        :return: The message passed between the nodes of the graph.
        :rtype: torch.Tensor
        """
        # Build a (width x width) kernel matrix per edge and apply it to the
        # features of the neighboring node.
        x = self.dense(edge_attr).view(-1, self.width, self.width)
        return torch.einsum('bij,bj->bi', x, x_j)

    def update(self, aggr_out, x):
        """
        Update the node features of the graph, overriding the default update
        function of the MessagePassing class.

        :param aggr_out: The aggregated messages.
        :type aggr_out: torch.Tensor
        :param x: The node features.
        :type x: torch.Tensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        # Add a linear transformation of the original node features
        # (skip connection) to the aggregated messages.
        aggr_out = aggr_out + self.W(x)
        return aggr_out

    def forward(self, x, edge_index, edge_attr):
        """
        The forward pass of the Graph Integral Layer.

        :param x: Node features.
        :type x: torch.Tensor
        :param edge_index: Edge index.
        :type edge_index: torch.Tensor
        :param edge_attr: Edge features.
        :type edge_attr: torch.Tensor
        :return: Output of a single iteration over the Graph Integral Layer.
        :rtype: torch.Tensor
        """
        return self.func(
            self.propagate(edge_index, x=x, edge_attr=edge_attr)
        )
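Below is a minimal usage sketch of the new layer. The import path is assumed from the file location in this commit, and the graph sizes, parameter values, activation classes, and random data are illustrative assumptions, not part of the commit.

import torch
# Assumed import path, matching the location of the new file in this commit.
from pina.model.layers.graph_integral_kernel import GraphIntegralLayer

# Hypothetical toy graph: 4 nodes with 8-dimensional hidden features and
# 5 directed edges carrying 3 features each.
x = torch.rand(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3, 0],
                           [1, 2, 3, 0, 2]])
edge_attr = torch.rand(5, 3)

# Activation classes are passed uninstantiated, since the layer calls
# external_func() and forwards internal_func to FeedForward's func argument.
layer = GraphIntegralLayer(width=8,
                           edges_features=3,
                           n_layers=2,
                           internal_func=torch.nn.ReLU,
                           external_func=torch.nn.Tanh)

out = layer(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([4, 8])

In this scheme the feed-forward network turns each edge feature vector into a width-by-width kernel matrix, so the message step acts as a discrete kernel integral operator over the graph, while the update step adds the linear skip connection W(x) before the external activation is applied.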