This commit is contained in:
FilippoOlivo
2025-03-12 12:03:47 +01:00
committed by Nicola Demo
parent 59e6ee595c
commit ae796ce34c
6 changed files with 128 additions and 97 deletions

View File

@@ -1,5 +1,5 @@
"""
This module provides an interface to build torch_geometric.data.Data objects.
Module to build Graph objects and perform operations on them.
"""
import torch
@@ -11,7 +11,8 @@ from .utils import check_consistency, is_function
class Graph(Data):
"""
A class to build torch_geometric.data.Data objects.
Extends :class:`~torch_geometric.data.Data` class to include additional
checks and functionlities.
"""
def __new__(
@@ -19,13 +20,16 @@ class Graph(Data):
**kwargs,
):
"""
Instantiates a new instance of the Graph class, performing type
consistency checks.
Create a new instance of the :class:`pina.graph.Graph` class by checking
the consistency of the input data and storing the attributes.
:param kwargs: Parameters to construct the Graph object.
:return: A new instance of the Graph class.
:param kwargs: Parameters used to initialize the
:class:`pina.graph.Graph` object.
:type kwargs: dict
:return: A new instance of the :class:`pina.graph.Graph` class.
:rtype: Graph
"""
# create class instance
instance = Data.__new__(cls)
@@ -45,27 +49,28 @@ class Graph(Data):
**kwargs,
):
"""
Initialize the Graph object by setting the node features, edge index,
Initialize the object by setting the node features, edge index,
edge attributes, and positions. The edge index is preprocessed to make
the graph undirected if required. For more details, see the
:meth: `torch_geometric.data.Data`
:meth:`torch_geometric.data.Data`
:param x: Optional tensor of node features (N, F) where F is the number
of features per node.
:param x: Optional tensor of node features `(N, F)` where `F` is the
number of features per node.
:type x: torch.Tensor, LabelTensor
:param torch.Tensor edge_index: A tensor of shape (2, E) representing
:param torch.Tensor edge_index: A tensor of shape `(2, E)` representing
the indices of the graph's edges.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param edge_attr: Optional tensor of edge_featured (E, F') where F' is
the number of edge features
:param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'`
is the number of edge features
:param bool undirected: Whether to make the graph undirected
:param kwargs: Additional keyword arguments passed to the
`torch_geometric.data.Data` class constructor. If the argument
is a `torch.Tensor` or `LabelTensor`, it is included in the Data
object as a graph parameter.
is a `torch.Tensor` or `LabelTensor`, it is included in the graph
object.
"""
# preprocessing
self._preprocess_edge_index(edge_index, undirected)
@@ -107,6 +112,8 @@ class Graph(Data):
Check if the position tensor is consistent.
:param torch.Tensor pos: The position tensor.
:raises ValueError: If the position tensor is not consistent.
"""
if pos is not None:
@@ -120,6 +127,8 @@ class Graph(Data):
Check if the edge index is consistent.
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge index tensor is not consistent.
"""
check_consistency(edge_index, (torch.Tensor, LabelTensor))
@@ -136,6 +145,9 @@ class Graph(Data):
:param torch.Tensor edge_attr: The edge attribute tensor.
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge attribute tensor is not consistent.
"""
if edge_attr is not None:
@@ -217,27 +229,28 @@ class GraphBuilder:
**kwargs,
):
"""
Compute the edge attributes and create a new instance of the Graph
class.
Compute the edge attributes and create a new instance of the
:class:`pina.graph.Graph` class.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:type pos: torch.Tensor or LabelTensor
:param edge_index: A tensor of shape (2, E) representing the indices of
the graph's edges.
:param edge_index: A tensor of shape `(2, E)` representing the indices
of the graph's edges.
:type edge_index: torch.Tensor
:param x: Optional tensor of node features of shape (N, F), where F is
the number of features per node.
:param x: Optional tensor of node features of shape `(N, F)`, where `F`
is the number of features per node.
:type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
where F is the number of features per edge.
:param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
where `F` is the number of features per edge.
:type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional
:param kwargs: Additional keyword arguments passed to the Graph class
constructor.
:return: A Graph instance constructed using the provided information.
:param kwargs: Additional keyword arguments passed to the
:class:`pina.graph.Graph` class constructor.
:return: A :class:`pina.graph.Graph` instance constructed using the
provided information.
:rtype: Graph
"""
edge_attr = cls._create_edge_attr(
@@ -271,42 +284,46 @@ class GraphBuilder:
class RadiusGraph(GraphBuilder):
"""
A class to build a radius graph.
A class to build a graph based on a radius.
"""
def __new__(cls, pos, radius, **kwargs):
"""
Extends the `GraphBuilder` class to compute edge_index based on a
radius. Each point is connected to all the points within the radius.
Extends the :class:`pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points
within the radius.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:type pos: torch.Tensor or LabelTensor
:param radius: The radius within which points are connected.
:type radius: float
:param kwargs: Additional keyword arguments to be passed to the
`GraphBuilder` and `Graph` constructors.
:return: A `Graph` instance containing the input information and the
computed edge_index.
:class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph`
constructors.
:return: A :class:`pina.graph.Graph` instance containing the input
information and the computed edge_index.
:rtype: Graph
"""
edge_index = cls.compute_radius_graph(pos, radius)
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
@staticmethod
def compute_radius_graph(points, radius):
"""
Computes edge_index for a given set of points base on the radius.
Computes `edge_index` for a given set of points base on the radius.
Each point is connected to all the points within the radius.
:param points: A tensor of shape (N, D) representing the positions of
:param points: A tensor of shape `(N, D)` representing the positions of
N points in D-dimensional space.
:type points: torch.Tensor | LabelTensor
:param float radius: The number of nearest neighbors to find for each
point.
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
edges, representing the edge indices of the KNN graph.
:rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number
of edges, representing the edge indices of the KNN graph.
"""
dist = torch.cdist(points, points, p=2)
return (
torch.nonzero(dist <= radius, as_tuple=False)
@@ -317,13 +334,13 @@ class RadiusGraph(GraphBuilder):
class KNNGraph(GraphBuilder):
"""
A class to build a KNN graph.
A class to build a K-nearest neighbors graph.
"""
def __new__(cls, pos, neighbours, **kwargs):
"""
Creates a new instance of the Graph class using k-nearest neighbors
algorithm to define the edges.
Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index
based on a K-nearest neighbors algorithm.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
@@ -334,8 +351,8 @@ class KNNGraph(GraphBuilder):
The additional keyword arguments to be passed to GraphBuilder
and Graph classes
:return: Graph instance containg the information passed in input and
the computed edge_index
:return: A :class:`pina.graph.Graph` instance containg the
information passed in input and the computed edge_index
:rtype: Graph
"""
@@ -371,7 +388,8 @@ class LabelBatch(Batch):
@classmethod
def from_data_list(cls, data_list):
"""
Create a Batch object from a list of Data objects.
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects.
:param data_list: List of Data/Graph objects
:type data_list: list[Data] | list[Graph]