diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 3224d0a..0db0acd 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -10,6 +10,7 @@ __all__ = [ "AveragingNeuralOperator", "LowRankNeuralOperator", "Spline", + "GNO" ] from .feed_forward import FeedForward, ResidualFeedForward @@ -20,3 +21,4 @@ from .base_no import KernelNeuralOperator from .avno import AveragingNeuralOperator from .lno import LowRankNeuralOperator from .spline import Spline +from .gno import GNO \ No newline at end of file diff --git a/pina/model/gno.py b/pina/model/gno.py new file mode 100644 index 0000000..991ca39 --- /dev/null +++ b/pina/model/gno.py @@ -0,0 +1,173 @@ +import torch +from torch.nn import Tanh +from .layers import GraphIntegralLayer +from .base_no import KernelNeuralOperator + + +class GraphNeuralKernel(torch.nn.Module): + """ + TODO add docstring + """ + + def __init__( + self, + width, + edge_features, + n_layers=2, + internal_n_layers=0, + internal_layers=None, + internal_func=None, + external_func=None, + shared_weights=False + ): + """ + The Graph Neural Kernel constructor. + + :param width: The width of the kernel. + :type width: int + :param edge_features: The number of edge features. + :type edge_features: int + :param n_layers: The number of kernel layers. + :type n_layers: int + :param internal_n_layers: The number of layers the FF Neural Network internal to each Kernel Layer. + :type internal_n_layers: int + :param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer. + :type internal_layers: list | tuple + :param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer. + :param external_func: The activation function applied to the output of the Graph Integral Layer. + :type external_func: torch.nn.Module + :param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared. + """ + super().__init__() + if external_func is None: + external_func = Tanh + if internal_func is None: + internal_func = Tanh + + if shared_weights: + self.layers = GraphIntegralLayer( + width=width, + edges_features=edge_features, + n_layers=internal_n_layers, + layers=internal_layers, + internal_func=internal_func, + external_func=external_func) + self.n_layers = n_layers + self.forward = self.forward_shared + else: + self.layers = torch.nn.ModuleList( + [GraphIntegralLayer( + width=width, + edges_features=edge_features, + n_layers=internal_n_layers, + layers=internal_layers, + internal_func=internal_func, + external_func=external_func + ) + for _ in range(n_layers)] + ) + + def forward(self, x, edge_index, edge_attr): + """ + The forward pass of the Graph Neural Kernel used when the weights are not shared. + + :param x: The input batch. + :type x: torch.Tensor + :param edge_index: The edge index. + :type edge_index: torch.Tensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor + """ + for layer in self.layers: + x = layer(x, edge_index, edge_attr) + return x + + def forward_shared(self, x, edge_index, edge_attr): + """ + The forward pass of the Graph Neural Kernel used when the weights are shared. + + :param x: The input batch. + :type x: torch.Tensor + :param edge_index: The edge index. + :type edge_index: torch.Tensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor + """ + for _ in range(self.n_layers): + x = self.layers(x, edge_index, edge_attr) + return x + + +class GNO(KernelNeuralOperator): + """ + TODO add docstring + """ + + def __init__( + self, + lifting_operator, + projection_operator, + edge_features, + n_layers=10, + internal_n_layers=0, + inner_size=None, + internal_layers=None, + internal_func=None, + external_func=None, + shared_weights=True + ): + """ + The Graph Neural Operator constructor. + + :param lifting_operator: The lifting operator mapping the node features to its hidden dimension. + :type lifting_operator: torch.nn.Module + :param projection_operator: The projection operator mapping the hidden representation of the nodes features to the output function. + :type projection_operator: torch.nn.Module + :param edge_features: Number of edge features. + :type edge_features: int + :param n_layers: The number of kernel layers. + :type n_layers: int + :param internal_n_layers: The number of layers the Feed Forward Neural Network internal to each Kernel Layer. + :type internal_n_layers: int + :param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer. + :type internal_layers: list | tuple + :param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer. + :type internal_func: torch.nn.Module + :param external_func: The activation function applied to the output of the Graph Integral Kernel. + :type external_func: torch.nn.Module + :param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared. + :type shared_weights: bool + """ + + if internal_func is None: + internal_func = Tanh + if external_func is None: + external_func = Tanh + + super().__init__( + lifting_operator=lifting_operator, + integral_kernels=GraphNeuralKernel( + width=lifting_operator.out_features, + edge_features=edge_features, + internal_n_layers=internal_n_layers, + internal_layers=internal_layers, + external_func=external_func, + internal_func=internal_func, + n_layers=n_layers, + shared_weights=shared_weights + ), + projection_operator=projection_operator + ) + + def forward(self, x): + """ + The forward pass of the Graph Neural Operator. + + :param x: The input batch. + :type x: torch_geometric.data.Batch + """ + x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr + x = self.lifting_operator(x) + x = self.integral_kernels(x, edge_index, edge_attr) + x = self.projection_operator(x) + return x diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 5108522..50827dc 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -15,6 +15,7 @@ __all__ = [ "AVNOBlock", "LowRankBlock", "RBFBlock", + "GraphIntegralLayer" ] from .convolution_2d import ContinuousConvBlock @@ -31,3 +32,4 @@ from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding from .avno_layer import AVNOBlock from .lowrank_layer import LowRankBlock from .rbf_layer import RBFBlock +from .graph_integral_kernel import GraphIntegralLayer diff --git a/pina/model/layers/graph_integral_kernel.py b/pina/model/layers/graph_integral_kernel.py new file mode 100644 index 0000000..713b0d7 --- /dev/null +++ b/pina/model/layers/graph_integral_kernel.py @@ -0,0 +1,82 @@ +import torch +from torch_geometric.nn import MessagePassing + + +class GraphIntegralLayer(MessagePassing): + """ + TODO: Add documentation + """ + def __init__( + self, + width, + edges_features, + n_layers=0, + layers=None, + internal_func=None, + external_func=None + ): + """ + Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric. + + :param width: The width of the hidden representation of the nodes features + :type width: int + :param edges_features: The number of edge features. + :type edges_features: int + :param n_layers: The number of layers in the Feed Forward Neural Network used to compute the representation of the edges features. + :type n_layers: int + """ + from pina.model import FeedForward + super(GraphIntegralLayer, self).__init__(aggr='mean') + self.width = width + self.dense = FeedForward(input_dimensions=edges_features, + output_dimensions=width ** 2, + n_layers=n_layers, + layers=layers, + func=internal_func) + self.W = torch.nn.Linear(width, width) + self.func = external_func() + + def message(self, x_j, edge_attr): + """ + This function computes the message passed between the nodes of the graph. Overwrite the default message function defined in the MessagePassing class. + + :param x_j: The node features of the neighboring. + :type x_j: torch.Tensor + :param edge_attr: The edge features. + :type edge_attr: torch.Tensor + :return: The message passed between the nodes of the graph. + :rtype: torch.Tensor + """ + x = self.dense(edge_attr).view(-1, self.width, self.width) + return torch.einsum('bij,bj->bi', x, x_j) + + def update(self, aggr_out, x): + """ + This function updates the node features of the graph. Overwrite the default update function defined in the MessagePassing class. + + :param aggr_out: The aggregated messages. + :type aggr_out: torch.Tensor + :param x: The node features. + :type x: torch.Tensor + :return: The updated node features. + :rtype: torch.Tensor + """ + aggr_out = aggr_out + self.W(x) + return aggr_out + + def forward(self, x, edge_index, edge_attr): + """ + The forward pass of the Graph Integral Layer. + + :param x: Node features. + :type x: torch.Tensor + :param edge_index: Edge index. + :type edge_index: torch.Tensor + :param edge_attr: Edge features. + :type edge_attr: torch.Tensor + :return: Output of a single iteration over the Graph Integral Layer. + :rtype: torch.Tensor + """ + return self.func( + self.propagate(edge_index, x=x, edge_attr=edge_attr) + )