diff --git a/pina/graph.py b/pina/graph.py index 74dc91f..3cd1322 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -63,8 +63,9 @@ class Graph(Data): :type pos: torch.Tensor | LabelTensor :param edge_attr: Optional tensor of edge_featured ``(E, F')`` where ``F'`` is the number of edge features + :type edge_attr: torch.Tensor | LabelTensor :param bool undirected: Whether to make the graph undirected - :param kwargs: Additional keyword arguments passed to the + :param dict kwargs: Additional keyword arguments passed to the :class:`~torch_geometric.data.Data` class constructor. """ # preprocessing @@ -201,7 +202,7 @@ class Graph(Data): class GraphBuilder: """ - A class that allows the simple definition of Graph instances. + A class that allows an easy definition of :class:`Graph` instances. """ def __new__( @@ -217,25 +218,25 @@ class GraphBuilder: Compute the edge attributes and create a new instance of the :class:`~pina.graph.Graph` class. - :param pos: A tensor of shape `(N, D)` representing the positions of `N` - points in `D`-dimensional space. + :param pos: A tensor of shape ``(N, D)`` representing the positions of + ``N`` points in ``D``-dimensional space. :type pos: torch.Tensor or LabelTensor - :param edge_index: A tensor of shape `(2, E)` representing the indices + :param edge_index: A tensor of shape ``(2, E)`` representing the indices of the graph's edges. :type edge_index: torch.Tensor - :param x: Optional tensor of node features of shape `(N, F)`, where `F` - is the number of features per node. + :param x: Optional tensor of node features of shape ``(N, F)``, where + ``F`` is the number of features per node. :type x: torch.Tensor | LabelTensor, optional - :param edge_attr: Optional tensor of edge attributes of shape `(E, F)`, - where `F` is the number of features per edge. + :param edge_attr: Optional tensor of edge attributes of shape ``(E, F)`` + , where ``F`` is the number of features per edge. :type edge_attr: torch.Tensor, optional :param custom_edge_func: A custom function to compute edge attributes. - If provided, overrides `edge_attr`. + If provided, overrides ``edge_attr``. :type custom_edge_func: Callable, optional :param kwargs: Additional keyword arguments passed to the :class:`~pina.graph.Graph` class constructor. :return: A :class:`~pina.graph.Graph` instance constructed using the - provided information. + provided information. :rtype: Graph """ edge_attr = cls._create_edge_attr( @@ -274,6 +275,7 @@ class GraphBuilder: def _build_edge_attr(pos, edge_index): """ Default function to compute the edge attributes. + :param pos: Positions of the points. :type pos: torch.Tensor | LabelTensor :param torch.Tensor edge_index: Edge indices. @@ -289,14 +291,15 @@ class GraphBuilder: class RadiusGraph(GraphBuilder): """ - A class to build a graph based on a radius. + Extends the :class:`~pina.graph.GraphBuilder` class to compute + edge_index based on a radius. Each point is connected to all the points + within the radius. """ def __new__(cls, pos, radius, **kwargs): """ - Extends the :class:`~pina.graph.GraphBuilder` class to compute - edge_index based on a radius. Each point is connected to all the points - within the radius. + Instantiate the :class:`~pina.graph.Graph` class by computing the + ``edge_index`` based on the radius provided. :param pos: A tensor of shape ``(N, D)`` representing the positions of ``N`` points in ``D``-dimensional space. @@ -336,13 +339,14 @@ class RadiusGraph(GraphBuilder): class KNNGraph(GraphBuilder): """ - A class to build a K-nearest neighbors graph. + Extends the :class:`~pina.graph.GraphBuilder` class to compute + edge_index based on a K-nearest neighbors algorithm. """ def __new__(cls, pos, neighbours, **kwargs): """ - Extends the :class:`~pina.graph.GraphBuilder` class to compute - edge_index based on a K-nearest neighbors algorithm. + Instantiate the :class:`~pina.graph.Graph` class by computing the + ``edge_index`` based on the K-nearest neighbors algorithm. :param pos: A tensor of shape ``(N, D)`` representing the positions of ``N`` points in ``D``-dimensional space.