Fix rendering graph

This commit is contained in:
FilippoOlivo
2025-03-14 10:59:04 +01:00
committed by Nicola Demo
parent 10ccae3a33
commit bc62ef9120

View File

@@ -63,8 +63,9 @@ class Graph(Data):
:type pos: torch.Tensor | LabelTensor
:param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
``F'`` is the number of edge features
:type edge_attr: torch.Tensor | LabelTensor
:param bool undirected: Whether to make the graph undirected
:param kwargs: Additional keyword arguments passed to the
:param dict kwargs: Additional keyword arguments passed to the
:class:`~torch_geometric.data.Data` class constructor.
"""
# preprocessing
@@ -201,7 +202,7 @@ class Graph(Data):
class GraphBuilder:
"""
A class that allows the simple definition of Graph instances.
A class that allows an easy definition of :class:`Graph` instances.
"""
def __new__(
@@ -217,25 +218,25 @@ class GraphBuilder:
Compute the edge attributes and create a new instance of the
:class:`~pina.graph.Graph` class.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor or LabelTensor
:param edge_index: A tensor of shape `(2, E)` representing the indices
:param edge_index: A tensor of shape ``(2, E)`` representing the indices
of the graph's edges.
:type edge_index: torch.Tensor
:param x: Optional tensor of node features of shape `(N, F)`, where `F`
is the number of features per node.
:param x: Optional tensor of node features of shape ``(N, F)``, where
``F`` is the number of features per node.
:type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
where `F` is the number of features per edge.
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
, where ``F`` is the number of features per edge.
:type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`.
If provided, overrides ``edge_attr``.
:type custom_edge_func: Callable, optional
:param kwargs: Additional keyword arguments passed to the
:class:`~pina.graph.Graph` class constructor.
:return: A :class:`~pina.graph.Graph` instance constructed using the
provided information.
provided information.
:rtype: Graph
"""
edge_attr = cls._create_edge_attr(
@@ -274,6 +275,7 @@ class GraphBuilder:
def _build_edge_attr(pos, edge_index):
"""
Default function to compute the edge attributes.
:param pos: Positions of the points.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: Edge indices.
@@ -289,14 +291,15 @@ class GraphBuilder:
class RadiusGraph(GraphBuilder):
"""
A class to build a graph based on a radius.
Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points
within the radius.
"""
def __new__(cls, pos, radius, **kwargs):
"""
Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points
within the radius.
Instantiate the :class:`~pina.graph.Graph` class by computing the
``edge_index`` based on the radius provided.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
@@ -336,13 +339,14 @@ class RadiusGraph(GraphBuilder):
class KNNGraph(GraphBuilder):
"""
A class to build a K-nearest neighbors graph.
Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a K-nearest neighbors algorithm.
"""
def __new__(cls, pos, neighbours, **kwargs):
"""
Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a K-nearest neighbors algorithm.
Instantiate the :class:`~pina.graph.Graph` class by computing the
``edge_index`` based on the K-nearest neighbors algorithm.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.