Fix rendering graph
This commit is contained in:
committed by
Nicola Demo
parent
10ccae3a33
commit
bc62ef9120
@@ -63,8 +63,9 @@ class Graph(Data):
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
|
||||
``F'`` is the number of edge features
|
||||
:type edge_attr: torch.Tensor | LabelTensor
|
||||
:param bool undirected: Whether to make the graph undirected
|
||||
:param kwargs: Additional keyword arguments passed to the
|
||||
:param dict kwargs: Additional keyword arguments passed to the
|
||||
:class:`~torch_geometric.data.Data` class constructor.
|
||||
"""
|
||||
# preprocessing
|
||||
@@ -201,7 +202,7 @@ class Graph(Data):
|
||||
|
||||
class GraphBuilder:
|
||||
"""
|
||||
A class that allows the simple definition of Graph instances.
|
||||
A class that allows an easy definition of :class:`Graph` instances.
|
||||
"""
|
||||
|
||||
def __new__(
|
||||
@@ -217,25 +218,25 @@ class GraphBuilder:
|
||||
Compute the edge attributes and create a new instance of the
|
||||
:class:`~pina.graph.Graph` class.
|
||||
|
||||
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||
points in `D`-dimensional space.
|
||||
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
||||
``N`` points in ``D``-dimensional space.
|
||||
:type pos: torch.Tensor or LabelTensor
|
||||
:param edge_index: A tensor of shape `(2, E)` representing the indices
|
||||
:param edge_index: A tensor of shape ``(2, E)`` representing the indices
|
||||
of the graph's edges.
|
||||
:type edge_index: torch.Tensor
|
||||
:param x: Optional tensor of node features of shape `(N, F)`, where `F`
|
||||
is the number of features per node.
|
||||
:param x: Optional tensor of node features of shape ``(N, F)``, where
|
||||
``F`` is the number of features per node.
|
||||
:type x: torch.Tensor | LabelTensor, optional
|
||||
:param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
|
||||
where `F` is the number of features per edge.
|
||||
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
|
||||
, where ``F`` is the number of features per edge.
|
||||
:type edge_attr: torch.Tensor, optional
|
||||
:param custom_edge_func: A custom function to compute edge attributes.
|
||||
If provided, overrides `edge_attr`.
|
||||
If provided, overrides ``edge_attr``.
|
||||
:type custom_edge_func: Callable, optional
|
||||
:param kwargs: Additional keyword arguments passed to the
|
||||
:class:`~pina.graph.Graph` class constructor.
|
||||
:return: A :class:`~pina.graph.Graph` instance constructed using the
|
||||
provided information.
|
||||
provided information.
|
||||
:rtype: Graph
|
||||
"""
|
||||
edge_attr = cls._create_edge_attr(
|
||||
@@ -274,6 +275,7 @@ class GraphBuilder:
|
||||
def _build_edge_attr(pos, edge_index):
|
||||
"""
|
||||
Default function to compute the edge attributes.
|
||||
|
||||
:param pos: Positions of the points.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: Edge indices.
|
||||
@@ -289,14 +291,15 @@ class GraphBuilder:
|
||||
|
||||
class RadiusGraph(GraphBuilder):
|
||||
"""
|
||||
A class to build a graph based on a radius.
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a radius. Each point is connected to all the points
|
||||
within the radius.
|
||||
"""
|
||||
|
||||
def __new__(cls, pos, radius, **kwargs):
|
||||
"""
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a radius. Each point is connected to all the points
|
||||
within the radius.
|
||||
Instantiate the :class:`~pina.graph.Graph` class by computing the
|
||||
``edge_index`` based on the radius provided.
|
||||
|
||||
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
||||
``N`` points in ``D``-dimensional space.
|
||||
@@ -336,13 +339,14 @@ class RadiusGraph(GraphBuilder):
|
||||
|
||||
class KNNGraph(GraphBuilder):
|
||||
"""
|
||||
A class to build a K-nearest neighbors graph.
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a K-nearest neighbors algorithm.
|
||||
"""
|
||||
|
||||
def __new__(cls, pos, neighbours, **kwargs):
|
||||
"""
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a K-nearest neighbors algorithm.
|
||||
Instantiate the :class:`~pina.graph.Graph` class by computing the
|
||||
``edge_index`` based on the K-nearest neighbors algorithm.
|
||||
|
||||
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
||||
``N`` points in ``D``-dimensional space.
|
||||
|
||||
Reference in New Issue
Block a user