Fix rendering and codacy
This commit is contained in:
committed by
Nicola Demo
parent
05105dd517
commit
001d1fc9cf
@@ -23,9 +23,8 @@ class Graph(Data):
|
||||
Create a new instance of the :class:`~pina.graph.Graph` class by
|
||||
checking the consistency of the input data and storing the attributes.
|
||||
|
||||
:param kwargs: Parameters used to initialize the
|
||||
:param dict kwargs: Parameters used to initialize the
|
||||
:class:`~pina.graph.Graph` object.
|
||||
:type kwargs: dict
|
||||
:return: A new instance of the :class:`~pina.graph.Graph` class.
|
||||
:rtype: Graph
|
||||
"""
|
||||
@@ -56,8 +55,8 @@ class Graph(Data):
|
||||
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
|
||||
number of features per node.
|
||||
:type x: torch.Tensor, LabelTensor
|
||||
:param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing
|
||||
the indices of the graph's edges.
|
||||
:param torch.Tensor edge_index: A tensor of shape ``(2, E)``
|
||||
representing the indices of the graph's edges.
|
||||
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
||||
``N`` points in ``D``-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
@@ -80,8 +79,7 @@ class Graph(Data):
|
||||
"""
|
||||
Check the consistency of the types of the input data.
|
||||
|
||||
:param kwargs: Attributes to be checked for consistency.
|
||||
:type kwargs: dict
|
||||
:param dict kwargs: Attributes to be checked for consistency.
|
||||
"""
|
||||
# default types, specified in cls.__new__, by default they are Nont
|
||||
# if specified in **kwargs they get override
|
||||
@@ -134,7 +132,8 @@ class Graph(Data):
|
||||
Check if the edge attribute tensor is consistent in type and shape
|
||||
with the edge index.
|
||||
|
||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||
:param edge_attr: The edge attribute tensor.
|
||||
:type edge_attr: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
:raises ValueError: If the edge attribute tensor is not consistent.
|
||||
"""
|
||||
@@ -156,8 +155,11 @@ class Graph(Data):
|
||||
Check if the input tensor x is consistent with the position tensor
|
||||
`pos`.
|
||||
|
||||
:param torch.Tensor x: The input tensor.
|
||||
:param torch.Tensor pos: The position tensor.
|
||||
:param x: The input tensor.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:param pos: The position tensor.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:raises ValueError: If the input tensor is not consistent.
|
||||
"""
|
||||
if x is not None:
|
||||
check_consistency(x, (torch.Tensor, LabelTensor))
|
||||
@@ -166,9 +168,6 @@ class Graph(Data):
|
||||
if pos is not None:
|
||||
if x.size(0) != pos.size(0):
|
||||
raise ValueError("Inconsistent number of nodes.")
|
||||
if pos is not None:
|
||||
if x.size(0) != pos.size(0):
|
||||
raise ValueError("Inconsistent number of nodes.")
|
||||
|
||||
@staticmethod
|
||||
def _preprocess_edge_index(edge_index, undirected):
|
||||
@@ -292,7 +291,7 @@ class GraphBuilder:
|
||||
class RadiusGraph(GraphBuilder):
|
||||
"""
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a radius. Each point is connected to all the points
|
||||
``edge_index`` based on a radius. Each point is connected to all the points
|
||||
within the radius.
|
||||
"""
|
||||
|
||||
@@ -305,11 +304,10 @@ class RadiusGraph(GraphBuilder):
|
||||
``N`` points in ``D``-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param float radius: The radius within which points are connected.
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
|
||||
constructors.
|
||||
:return: A :class:`~pina.graph.Graph` instance containing the input
|
||||
information and the computed ``edge_index``.
|
||||
:param dict kwargs: The additional keyword arguments to be passed to
|
||||
:class:`GraphBuilder` and :class:`Graph` classes.
|
||||
:return: A :class:`~pina.graph.Graph` instance with the computed
|
||||
``edge_index``.
|
||||
:rtype: Graph
|
||||
"""
|
||||
edge_index = cls.compute_radius_graph(pos, radius)
|
||||
@@ -318,16 +316,16 @@ class RadiusGraph(GraphBuilder):
|
||||
@staticmethod
|
||||
def compute_radius_graph(points, radius):
|
||||
"""
|
||||
Computes ``edge_index`` for a given set of points base on the radius.
|
||||
Each point is connected to all the points within the radius.
|
||||
Computes the ``edge_index`` based on the radius. Each point is connected
|
||||
to all the points within the radius.
|
||||
|
||||
:param points: A tensor of shape ``(N, D)`` representing the positions
|
||||
of ``N`` points in ``D``-dimensional space.
|
||||
:type points: torch.Tensor | LabelTensor
|
||||
:param float radius: The number of nearest neighbors to find for each
|
||||
point.
|
||||
:rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the
|
||||
number of edges, representing the edge indices of the KNN graph.
|
||||
:param float radius: The radius within which points are connected.
|
||||
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
|
||||
representing the edge indices of the graph.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
dist = torch.cdist(points, points, p=2)
|
||||
return (
|
||||
@@ -340,7 +338,7 @@ class RadiusGraph(GraphBuilder):
|
||||
class KNNGraph(GraphBuilder):
|
||||
"""
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a K-nearest neighbors algorithm.
|
||||
``edge_index`` based on a K-nearest neighbors algorithm.
|
||||
"""
|
||||
|
||||
def __new__(cls, pos, neighbours, **kwargs):
|
||||
@@ -353,12 +351,11 @@ class KNNGraph(GraphBuilder):
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param int neighbours: The number of nearest neighbors to consider when
|
||||
building the graph.
|
||||
:Keyword Arguments:
|
||||
The additional keyword arguments to be passed to GraphBuilder
|
||||
and Graph classes
|
||||
:param dict kwargs: The additional keyword arguments to be passed to
|
||||
:class:`GraphBuilder` and :class:`Graph` classes.
|
||||
|
||||
:return: A :class:`~pina.graph.Graph` instance containg the
|
||||
information passed in input and the computed ``edge_index``
|
||||
:return: A :class:`~pina.graph.Graph` instance with the computed
|
||||
``edge_index``.
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
@@ -366,41 +363,45 @@ class KNNGraph(GraphBuilder):
|
||||
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def compute_knn_graph(points, k):
|
||||
def compute_knn_graph(points, neighbours):
|
||||
"""
|
||||
Computes the edge_index based k-nearest neighbors graph algorithm
|
||||
Computes the ``edge_index`` based on the K-nearest neighbors algorithm.
|
||||
|
||||
:param points: A tensor of shape ``(N, D)`` representing the positions
|
||||
of ``N`` points in ``D``-dimensional space.
|
||||
:type points: torch.Tensor | LabelTensor
|
||||
:param int k: The number of nearest neighbors to find for each point.
|
||||
:return: A tensor of shape ``(2, E)``, where ``E`` is the number of
|
||||
edges, representing the edge indices of the KNN graph.
|
||||
:param int neighbours: The number of nearest neighbors to consider when
|
||||
building the graph.
|
||||
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
|
||||
representing the edge indices of the graph.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
dist = torch.cdist(points, points, p=2)
|
||||
knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
|
||||
row = torch.arange(points.size(0)).repeat_interleave(k)
|
||||
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
|
||||
:, 1:
|
||||
]
|
||||
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
|
||||
col = knn_indices.flatten()
|
||||
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
|
||||
|
||||
|
||||
class LabelBatch(Batch):
|
||||
"""
|
||||
Add extract function to torch_geometric Batch object
|
||||
Extends the :class:`~torch_geometric.data.Batch` class to include
|
||||
:class:`~pina.label_tensor.LabelTensor` objects.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_data_list(cls, data_list):
|
||||
"""
|
||||
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||
objects.
|
||||
or :class:`~pina.graph.Graph` objects.
|
||||
|
||||
:param data_list: List of :class:`~torch_geometric.data.Data` or
|
||||
:class:`~pina.graph.Graph` objects.
|
||||
:type data_list: list[Data] | list[Graph]
|
||||
:return: A Batch object containing the data in the list
|
||||
:return: A :class:`Batch` object containing the input data.
|
||||
:rtype: Batch
|
||||
"""
|
||||
# Store the labels of Data/Graph objects (all data have the same labels)
|
||||
|
||||
Reference in New Issue
Block a user