Fix rendering graph

This commit is contained in:
FilippoOlivo
2025-03-14 10:59:04 +01:00
parent a3081cc09f
commit feb6ca952a

View File

@@ -63,8 +63,9 @@ class Graph(Data):
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
:param edge_attr: Optional tensor of edge_featured ``(E, F')`` where :param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
``F'`` is the number of edge features ``F'`` is the number of edge features
:type edge_attr: torch.Tensor | LabelTensor
:param bool undirected: Whether to make the graph undirected :param bool undirected: Whether to make the graph undirected
:param kwargs: Additional keyword arguments passed to the :param dict kwargs: Additional keyword arguments passed to the
:class:`~torch_geometric.data.Data` class constructor. :class:`~torch_geometric.data.Data` class constructor.
""" """
# preprocessing # preprocessing
@@ -201,7 +202,7 @@ class Graph(Data):
class GraphBuilder: class GraphBuilder:
""" """
A class that allows the simple definition of Graph instances. A class that allows an easy definition of :class:`Graph` instances.
""" """
def __new__( def __new__(
@@ -217,25 +218,25 @@ class GraphBuilder:
Compute the edge attributes and create a new instance of the Compute the edge attributes and create a new instance of the
:class:`~pina.graph.Graph` class. :class:`~pina.graph.Graph` class.
:param pos: A tensor of shape `(N, D)` representing the positions of `N` :param pos: A tensor of shape ``(N, D)`` representing the positions of
points in `D`-dimensional space. ``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor or LabelTensor :type pos: torch.Tensor or LabelTensor
:param edge_index: A tensor of shape `(2, E)` representing the indices :param edge_index: A tensor of shape ``(2, E)`` representing the indices
of the graph's edges. of the graph's edges.
:type edge_index: torch.Tensor :type edge_index: torch.Tensor
:param x: Optional tensor of node features of shape `(N, F)`, where `F` :param x: Optional tensor of node features of shape ``(N, F)``, where
is the number of features per node. ``F`` is the number of features per node.
:type x: torch.Tensor | LabelTensor, optional :type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape `(E, F)`, :param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
where `F` is the number of features per edge. , where ``F`` is the number of features per edge.
:type edge_attr: torch.Tensor, optional :type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes. :param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`. If provided, overrides ``edge_attr``.
:type custom_edge_func: Callable, optional :type custom_edge_func: Callable, optional
:param kwargs: Additional keyword arguments passed to the :param kwargs: Additional keyword arguments passed to the
:class:`~pina.graph.Graph` class constructor. :class:`~pina.graph.Graph` class constructor.
:return: A :class:`~pina.graph.Graph` instance constructed using the :return: A :class:`~pina.graph.Graph` instance constructed using the
provided information. provided information.
:rtype: Graph :rtype: Graph
""" """
edge_attr = cls._create_edge_attr( edge_attr = cls._create_edge_attr(
@@ -274,6 +275,7 @@ class GraphBuilder:
def _build_edge_attr(pos, edge_index): def _build_edge_attr(pos, edge_index):
""" """
Default function to compute the edge attributes. Default function to compute the edge attributes.
:param pos: Positions of the points. :param pos: Positions of the points.
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: Edge indices. :param torch.Tensor edge_index: Edge indices.
@@ -289,14 +291,15 @@ class GraphBuilder:
class RadiusGraph(GraphBuilder): class RadiusGraph(GraphBuilder):
""" """
A class to build a graph based on a radius. Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points
within the radius.
""" """
def __new__(cls, pos, radius, **kwargs): def __new__(cls, pos, radius, **kwargs):
""" """
Extends the :class:`~pina.graph.GraphBuilder` class to compute Instantiate the :class:`~pina.graph.Graph` class by computing the
edge_index based on a radius. Each point is connected to all the points ``edge_index`` based on the radius provided.
within the radius.
:param pos: A tensor of shape ``(N, D)`` representing the positions of :param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space. ``N`` points in ``D``-dimensional space.
@@ -336,13 +339,14 @@ class RadiusGraph(GraphBuilder):
class KNNGraph(GraphBuilder): class KNNGraph(GraphBuilder):
""" """
A class to build a K-nearest neighbors graph. Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a K-nearest neighbors algorithm.
""" """
def __new__(cls, pos, neighbours, **kwargs): def __new__(cls, pos, neighbours, **kwargs):
""" """
Extends the :class:`~pina.graph.GraphBuilder` class to compute Instantiate the :class:`~pina.graph.Graph` class by computing the
edge_index based on a K-nearest neighbors algorithm. ``edge_index`` based on the K-nearest neighbors algorithm.
:param pos: A tensor of shape ``(N, D)`` representing the positions of :param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space. ``N`` points in ``D``-dimensional space.