fix doc model part 2

This commit is contained in:
giovanni
2025-03-14 16:07:08 +01:00
committed by FilippoOlivo
parent 194f5d24c4
commit 28d24f3f41
18 changed files with 887 additions and 851 deletions

View File

@@ -1,5 +1,5 @@
"""
Module containing the Graph Integral Layer class.
Module for the Graph Neural Operator Block class.
"""
import torch
@@ -8,7 +8,7 @@ from torch_geometric.nn import MessagePassing
class GNOBlock(MessagePassing):
"""
Graph Neural Operator (GNO) Block using PyG MessagePassing.
The inner block of the Graph Neural Operator, based on Message Passing.
"""
def __init__(
@@ -22,11 +22,22 @@ class GNOBlock(MessagePassing):
external_func=None,
):
"""
Initialize the GNOBlock.
Initialization of the :class:`GNOBlock` class.
:param width: Hidden dimension of node features.
:param edges_features: Number of edge features.
:param n_layers: Number of layers in edge transformation MLP.
:param int width: The width of the kernel.
:param int edge_features: The number of edge features.
:param int n_layers: The number of kernel layers. Default is ``2``.
:param layers: A list specifying the number of neurons for each layer
of the neural network. If not ``None``, it overrides the
``inner_size`` and ``n_layers``parameters. Default is ``None``.
:type layers: list[int] | tuple[int]
:param int inner_size: The size of the inner layer. Default is ``None``.
:param torch.nn.Module internal_func: The activation function applied to
the output of each layer. If ``None``, it uses the
:class:`torch.nn.Tanh` activation. Default is ``None``.
:param torch.nn.Module external_func: The activation function applied to
the output of the block. If ``None``, it uses the
:class:`torch.nn.Tanh`. activation. Default is ``None``.
"""
from ...model.feed_forward import FeedForward
@@ -51,12 +62,13 @@ class GNOBlock(MessagePassing):
def message_and_aggregate(self, edge_index, x, edge_attr):
"""
Combines message and aggregation.
Combine messages and perform aggregation.
:param edge_index: COO format edge indices.
:param x: Node feature matrix [num_nodes, width].
:param edge_attr: Edge features [num_edges, edge_dim].
:return: Aggregated messages.
:param torch.Tensor edge_index: The edge index.
:param torch.Tensor x: The node feature matrix.
:param torch.Tensor edge_attr: The edge features.
:return: The aggregated messages.
:rtype: torch.Tensor
"""
# Edge features are transformed into a matrix of shape
# [num_edges, width, width]
@@ -68,27 +80,33 @@ class GNOBlock(MessagePassing):
def edge_update(self, edge_attr):
"""
Updates edge features.
Update edge features.
:param torch.Tensor edge_attr: The edge features.
:return: The updated edge features.
:rtype: torch.Tensor
"""
return edge_attr
def update(self, aggr_out, x):
"""
Updates node features.
Update node features.
:param aggr_out: Aggregated messages.
:param x: Node feature matrix.
:return: Updated node features.
:param torch.Tensor aggr_out: The aggregated messages.
:param torch.Tensor x: The node feature matrix.
:return: The updated node features.
:rtype: torch.Tensor
"""
return aggr_out + self.W(x)
def forward(self, x, edge_index, edge_attr):
"""
Forward pass of the GNOBlock.
Forward pass of the block.
:param x: Node features.
:param edge_index: Edge indices.
:param edge_attr: Edge features.
:return: Updated node features.
:param torch.Tensor x: The node features.
:param torch.Tensor edge_index: The edge indeces.
:param torch.Tensor edge_attr: The edge features.
:return: The updated node features.
:rtype: torch.Tensor
"""
return self.func(self.propagate(edge_index, x=x, edge_attr=edge_attr))