Tmp fixes

This commit is contained in:
FilippoOlivo
2025-03-14 10:20:56 +01:00
committed by Nicola Demo
parent c164db874b
commit 10ccae3a33
8 changed files with 74 additions and 69 deletions

View File

@@ -29,7 +29,6 @@ class Graph(Data):
:return: A new instance of the :class:`~pina.graph.Graph` class.
:rtype: Graph
"""
# create class instance
instance = Data.__new__(cls)
@@ -54,23 +53,20 @@ class Graph(Data):
the graph undirected if required. For more details, see the
:meth:`torch_geometric.data.Data`
:param x: Optional tensor of node features `(N, F)` where `F` is the
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
number of features per node.
:type x: torch.Tensor, LabelTensor
:param torch.Tensor edge_index: A tensor of shape `(2, E)` representing
:param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing
the indices of the graph's edges.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'`
is the number of edge features
:param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
``F'`` is the number of edge features
:param bool undirected: Whether to make the graph undirected
:param kwargs: Additional keyword arguments passed to the
`torch_geometric.data.Data` class constructor. If the argument
is a `torch.Tensor` or `LabelTensor`, it is included in the graph
object.
:class:`~torch_geometric.data.Data` class constructor.
"""
# preprocessing
self._preprocess_edge_index(edge_index, undirected)
@@ -86,7 +82,6 @@ class Graph(Data):
:param kwargs: Attributes to be checked for consistency.
:type kwargs: dict
"""
# default types, specified in cls.__new__, by default they are Nont
# if specified in **kwargs they get override
x, pos, edge_index, edge_attr = None, None, None, None
@@ -110,12 +105,9 @@ class Graph(Data):
def _check_pos_consistency(pos):
"""
Check if the position tensor is consistent.
:param torch.Tensor pos: The position tensor.
:raises ValueError: If the position tensor is not consistent.
"""
if pos is not None:
check_consistency(pos, (torch.Tensor, LabelTensor))
if pos.ndim != 2:
@@ -127,10 +119,8 @@ class Graph(Data):
Check if the edge index is consistent.
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge index tensor is not consistent.
"""
check_consistency(edge_index, (torch.Tensor, LabelTensor))
if edge_index.ndim != 2:
raise ValueError("edge_index must be a 2D tensor.")
@@ -145,11 +135,8 @@ class Graph(Data):
:param torch.Tensor edge_attr: The edge attribute tensor.
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge attribute tensor is not consistent.
"""
if edge_attr is not None:
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
if edge_attr.ndim != 2:
@@ -171,7 +158,6 @@ class Graph(Data):
:param torch.Tensor x: The input tensor.
:param torch.Tensor pos: The position tensor.
"""
if x is not None:
check_consistency(x, (torch.Tensor, LabelTensor))
if x.ndim != 2:
@@ -193,7 +179,6 @@ class Graph(Data):
:return: The preprocessed edge index.
:rtype: torch.Tensor
"""
if undirected:
edge_index = to_undirected(edge_index)
return edge_index
@@ -246,7 +231,7 @@ class GraphBuilder:
:type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional
:type custom_edge_func: Callable, optional
:param kwargs: Additional keyword arguments passed to the
:class:`~pina.graph.Graph` class constructor.
:return: A :class:`~pina.graph.Graph` instance constructed using the
@@ -266,6 +251,18 @@ class GraphBuilder:
@staticmethod
def _create_edge_attr(pos, edge_index, edge_attr, func):
"""
Create the edge attributes based on the input parameters.
:param pos: Positions of the points.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: Edge indices.
:param bool edge_attr: Whether to compute the edge attributes.
:param Callable func: Function to compute the edge attributes.
:raises ValueError: If ``func`` is not a function.
:return: The edge attributes.
:rtype: torch.Tensor | LabelTensor | None
"""
check_consistency(edge_attr, bool)
if edge_attr:
if is_function(func):
@@ -275,6 +272,14 @@ class GraphBuilder:
@staticmethod
def _build_edge_attr(pos, edge_index):
"""
Default function to compute the edge attributes.
:param pos: Positions of the points.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: Edge indices.
:return: The edge attributes.
:rtype: torch.Tensor
"""
return (
(pos[edge_index[0]] - pos[edge_index[1]])
.abs()
@@ -293,37 +298,34 @@ class RadiusGraph(GraphBuilder):
edge_index based on a radius. Each point is connected to all the points
within the radius.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:type pos: torch.Tensor or LabelTensor
:param radius: The radius within which points are connected.
:type radius: float
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param float radius: The radius within which points are connected.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
constructors.
:return: A :class:`~pina.graph.Graph` instance containing the input
information and the computed edge_index.
information and the computed ``edge_index``.
:rtype: Graph
"""
edge_index = cls.compute_radius_graph(pos, radius)
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
@staticmethod
def compute_radius_graph(points, radius):
"""
Computes `edge_index` for a given set of points base on the radius.
Computes ``edge_index`` for a given set of points base on the radius.
Each point is connected to all the points within the radius.
:param points: A tensor of shape `(N, D)` representing the positions of
N points in D-dimensional space.
:param points: A tensor of shape ``(N, D)`` representing the positions
of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor
:param float radius: The number of nearest neighbors to find for each
point.
:rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number
of edges, representing the edge indices of the KNN graph.
:rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the
number of edges, representing the edge indices of the KNN graph.
"""
dist = torch.cdist(points, points, p=2)
return (
torch.nonzero(dist <= radius, as_tuple=False)
@@ -342,8 +344,8 @@ class KNNGraph(GraphBuilder):
Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a K-nearest neighbors algorithm.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param int neighbours: The number of nearest neighbors to consider when
building the graph.
@@ -352,7 +354,7 @@ class KNNGraph(GraphBuilder):
and Graph classes
:return: A :class:`~pina.graph.Graph` instance containg the
information passed in input and the computed edge_index
information passed in input and the computed ``edge_index``
:rtype: Graph
"""
@@ -364,11 +366,11 @@ class KNNGraph(GraphBuilder):
"""
Computes the edge_index based k-nearest neighbors graph algorithm
:param points: A tensor of shape (N, D) representing the positions of
N points in D-dimensional space.
:param points: A tensor of shape ``(N, D)`` representing the positions
of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor
:param int k: The number of nearest neighbors to find for each point.
:return: A tensor of shape (2, E), where E is the number of
:return: A tensor of shape ``(2, E)``, where ``E`` is the number of
edges, representing the edge indices of the KNN graph.
:rtype: torch.Tensor
"""
@@ -391,7 +393,8 @@ class LabelBatch(Batch):
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects.
:param data_list: List of Data/Graph objects
:param data_list: List of :class:`~torch_geometric.data.Data` or
:class:`~pina.graph.Graph` objects.
:type data_list: list[Data] | list[Graph]
:return: A Batch object containing the data in the list
:rtype: Batch