fix doc model part 2
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Module containing the Graph Integral Layer class.
|
||||
Module for the Graph Neural Operator Block class.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -8,7 +8,7 @@ from torch_geometric.nn import MessagePassing
|
||||
|
||||
class GNOBlock(MessagePassing):
|
||||
"""
|
||||
Graph Neural Operator (GNO) Block using PyG MessagePassing.
|
||||
The inner block of the Graph Neural Operator, based on Message Passing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -22,11 +22,22 @@ class GNOBlock(MessagePassing):
|
||||
external_func=None,
|
||||
):
|
||||
"""
|
||||
Initialize the GNOBlock.
|
||||
Initialization of the :class:`GNOBlock` class.
|
||||
|
||||
:param width: Hidden dimension of node features.
|
||||
:param edges_features: Number of edge features.
|
||||
:param n_layers: Number of layers in edge transformation MLP.
|
||||
:param int width: The width of the kernel.
|
||||
:param int edge_features: The number of edge features.
|
||||
:param int n_layers: The number of kernel layers. Default is ``2``.
|
||||
:param layers: A list specifying the number of neurons for each layer
|
||||
of the neural network. If not ``None``, it overrides the
|
||||
``inner_size`` and ``n_layers``parameters. Default is ``None``.
|
||||
:type layers: list[int] | tuple[int]
|
||||
:param int inner_size: The size of the inner layer. Default is ``None``.
|
||||
:param torch.nn.Module internal_func: The activation function applied to
|
||||
the output of each layer. If ``None``, it uses the
|
||||
:class:`torch.nn.Tanh` activation. Default is ``None``.
|
||||
:param torch.nn.Module external_func: The activation function applied to
|
||||
the output of the block. If ``None``, it uses the
|
||||
:class:`torch.nn.Tanh`. activation. Default is ``None``.
|
||||
"""
|
||||
|
||||
from ...model.feed_forward import FeedForward
|
||||
@@ -51,12 +62,13 @@ class GNOBlock(MessagePassing):
|
||||
|
||||
def message_and_aggregate(self, edge_index, x, edge_attr):
|
||||
"""
|
||||
Combines message and aggregation.
|
||||
Combine messages and perform aggregation.
|
||||
|
||||
:param edge_index: COO format edge indices.
|
||||
:param x: Node feature matrix [num_nodes, width].
|
||||
:param edge_attr: Edge features [num_edges, edge_dim].
|
||||
:return: Aggregated messages.
|
||||
:param torch.Tensor edge_index: The edge index.
|
||||
:param torch.Tensor x: The node feature matrix.
|
||||
:param torch.Tensor edge_attr: The edge features.
|
||||
:return: The aggregated messages.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# Edge features are transformed into a matrix of shape
|
||||
# [num_edges, width, width]
|
||||
@@ -68,27 +80,33 @@ class GNOBlock(MessagePassing):
|
||||
|
||||
def edge_update(self, edge_attr):
|
||||
"""
|
||||
Updates edge features.
|
||||
Update edge features.
|
||||
|
||||
:param torch.Tensor edge_attr: The edge features.
|
||||
:return: The updated edge features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return edge_attr
|
||||
|
||||
def update(self, aggr_out, x):
|
||||
"""
|
||||
Updates node features.
|
||||
Update node features.
|
||||
|
||||
:param aggr_out: Aggregated messages.
|
||||
:param x: Node feature matrix.
|
||||
:return: Updated node features.
|
||||
:param torch.Tensor aggr_out: The aggregated messages.
|
||||
:param torch.Tensor x: The node feature matrix.
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return aggr_out + self.W(x)
|
||||
|
||||
def forward(self, x, edge_index, edge_attr):
|
||||
"""
|
||||
Forward pass of the GNOBlock.
|
||||
Forward pass of the block.
|
||||
|
||||
:param x: Node features.
|
||||
:param edge_index: Edge indices.
|
||||
:param edge_attr: Edge features.
|
||||
:return: Updated node features.
|
||||
:param torch.Tensor x: The node features.
|
||||
:param torch.Tensor edge_index: The edge indeces.
|
||||
:param torch.Tensor edge_attr: The edge features.
|
||||
:return: The updated node features.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.func(self.propagate(edge_index, x=x, edge_attr=edge_attr))
|
||||
|
||||
Reference in New Issue
Block a user