Fix rendering and codacy

This commit is contained in:
FilippoOlivo
2025-03-14 15:05:16 +01:00
committed by Nicola Demo
parent 05105dd517
commit 001d1fc9cf
8 changed files with 98 additions and 96 deletions

View File

@@ -23,9 +23,8 @@ class Graph(Data):
Create a new instance of the :class:`~pina.graph.Graph` class by
checking the consistency of the input data and storing the attributes.
:param kwargs: Parameters used to initialize the
:param dict kwargs: Parameters used to initialize the
:class:`~pina.graph.Graph` object.
:type kwargs: dict
:return: A new instance of the :class:`~pina.graph.Graph` class.
:rtype: Graph
"""
@@ -56,8 +55,8 @@ class Graph(Data):
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
number of features per node.
:type x: torch.Tensor, LabelTensor
:param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing
the indices of the graph's edges.
:param torch.Tensor edge_index: A tensor of shape ``(2, E)``
representing the indices of the graph's edges.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor
@@ -80,8 +79,7 @@ class Graph(Data):
"""
Check the consistency of the types of the input data.
:param kwargs: Attributes to be checked for consistency.
:type kwargs: dict
:param dict kwargs: Attributes to be checked for consistency.
"""
# default types, specified in cls.__new__, by default they are Nont
# if specified in **kwargs they get override
@@ -134,7 +132,8 @@ class Graph(Data):
Check if the edge attribute tensor is consistent in type and shape
with the edge index.
:param torch.Tensor edge_attr: The edge attribute tensor.
:param edge_attr: The edge attribute tensor.
:type edge_attr: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge attribute tensor is not consistent.
"""
@@ -156,8 +155,11 @@ class Graph(Data):
Check if the input tensor x is consistent with the position tensor
`pos`.
:param torch.Tensor x: The input tensor.
:param torch.Tensor pos: The position tensor.
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param pos: The position tensor.
:type pos: torch.Tensor | LabelTensor
:raises ValueError: If the input tensor is not consistent.
"""
if x is not None:
check_consistency(x, (torch.Tensor, LabelTensor))
@@ -166,9 +168,6 @@ class Graph(Data):
if pos is not None:
if x.size(0) != pos.size(0):
raise ValueError("Inconsistent number of nodes.")
if pos is not None:
if x.size(0) != pos.size(0):
raise ValueError("Inconsistent number of nodes.")
@staticmethod
def _preprocess_edge_index(edge_index, undirected):
@@ -292,7 +291,7 @@ class GraphBuilder:
class RadiusGraph(GraphBuilder):
"""
Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points
``edge_index`` based on a radius. Each point is connected to all the points
within the radius.
"""
@@ -305,11 +304,10 @@ class RadiusGraph(GraphBuilder):
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param float radius: The radius within which points are connected.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
constructors.
:return: A :class:`~pina.graph.Graph` instance containing the input
information and the computed ``edge_index``.
:param dict kwargs: The additional keyword arguments to be passed to
:class:`GraphBuilder` and :class:`Graph` classes.
:return: A :class:`~pina.graph.Graph` instance with the computed
``edge_index``.
:rtype: Graph
"""
edge_index = cls.compute_radius_graph(pos, radius)
@@ -318,16 +316,16 @@ class RadiusGraph(GraphBuilder):
@staticmethod
def compute_radius_graph(points, radius):
"""
Computes ``edge_index`` for a given set of points base on the radius.
Each point is connected to all the points within the radius.
Computes the ``edge_index`` based on the radius. Each point is connected
to all the points within the radius.
:param points: A tensor of shape ``(N, D)`` representing the positions
of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor
:param float radius: The number of nearest neighbors to find for each
point.
:rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the
number of edges, representing the edge indices of the KNN graph.
:param float radius: The radius within which points are connected.
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
representing the edge indices of the graph.
:rtype: torch.Tensor
"""
dist = torch.cdist(points, points, p=2)
return (
@@ -340,7 +338,7 @@ class RadiusGraph(GraphBuilder):
class KNNGraph(GraphBuilder):
"""
Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a K-nearest neighbors algorithm.
``edge_index`` based on a K-nearest neighbors algorithm.
"""
def __new__(cls, pos, neighbours, **kwargs):
@@ -353,12 +351,11 @@ class KNNGraph(GraphBuilder):
:type pos: torch.Tensor | LabelTensor
:param int neighbours: The number of nearest neighbors to consider when
building the graph.
:Keyword Arguments:
The additional keyword arguments to be passed to GraphBuilder
and Graph classes
:param dict kwargs: The additional keyword arguments to be passed to
:class:`GraphBuilder` and :class:`Graph` classes.
:return: A :class:`~pina.graph.Graph` instance containg the
information passed in input and the computed ``edge_index``
:return: A :class:`~pina.graph.Graph` instance with the computed
``edge_index``.
:rtype: Graph
"""
@@ -366,41 +363,45 @@ class KNNGraph(GraphBuilder):
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
@staticmethod
def compute_knn_graph(points, k):
def compute_knn_graph(points, neighbours):
"""
Computes the edge_index based k-nearest neighbors graph algorithm
Computes the ``edge_index`` based on the K-nearest neighbors algorithm.
:param points: A tensor of shape ``(N, D)`` representing the positions
of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor
:param int k: The number of nearest neighbors to find for each point.
:return: A tensor of shape ``(2, E)``, where ``E`` is the number of
edges, representing the edge indices of the KNN graph.
:param int neighbours: The number of nearest neighbors to consider when
building the graph.
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
representing the edge indices of the graph.
:rtype: torch.Tensor
"""
dist = torch.cdist(points, points, p=2)
knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
row = torch.arange(points.size(0)).repeat_interleave(k)
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
:, 1:
]
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
col = knn_indices.flatten()
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
class LabelBatch(Batch):
"""
Add extract function to torch_geometric Batch object
Extends the :class:`~torch_geometric.data.Batch` class to include
:class:`~pina.label_tensor.LabelTensor` objects.
"""
@classmethod
def from_data_list(cls, data_list):
"""
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects.
or :class:`~pina.graph.Graph` objects.
:param data_list: List of :class:`~torch_geometric.data.Data` or
:class:`~pina.graph.Graph` objects.
:type data_list: list[Data] | list[Graph]
:return: A Batch object containing the data in the list
:return: A :class:`Batch` object containing the input data.
:rtype: Batch
"""
# Store the labels of Data/Graph objects (all data have the same labels)