Documentation and docstring graph and data

This commit is contained in:
FilippoOlivo
2025-03-10 15:57:15 +01:00
committed by Nicola Demo
parent 6ce0bafc2b
commit 635e3b3a75
3 changed files with 342 additions and 83 deletions

View File

@@ -19,6 +19,9 @@ class Graph(Data):
**kwargs,
):
"""
Instantiates a new instance of the Graph class, performing type
consistency checks.
:param kwargs: Parameters to construct the Graph object.
:return: A new instance of the Graph class.
:rtype: Graph
@@ -42,7 +45,10 @@ class Graph(Data):
**kwargs,
):
"""
Initialize the Graph object.
Initialize the Graph object by setting the node features, edge index,
edge attributes, and positions. The edge index is preprocessed to make
the graph undirected if required. For more details, see the
:meth: `torch_geometric.data.Data`
:param x: Optional tensor of node features (N, F) where F is the number
of features per node.
@@ -69,6 +75,13 @@ class Graph(Data):
)
def _check_type_consistency(self, **kwargs):
"""
Check the consistency of the types of the input data.
:param kwargs: Attributes to be checked for consistency.
:type kwargs: dict
"""
# default types, specified in cls.__new__, by default they are Nont
# if specified in **kwargs they get override
x, pos, edge_index, edge_attr = None, None, None, None
@@ -92,8 +105,10 @@ class Graph(Data):
def _check_pos_consistency(pos):
"""
Check if the position tensor is consistent.
:param torch.Tensor pos: The position tensor.
"""
if pos is not None:
check_consistency(pos, (torch.Tensor, LabelTensor))
if pos.ndim != 2:
@@ -103,8 +118,10 @@ class Graph(Data):
def _check_edge_index_consistency(edge_index):
"""
Check if the edge index is consistent.
:param torch.Tensor edge_index: The edge index tensor.
"""
check_consistency(edge_index, (torch.Tensor, LabelTensor))
if edge_index.ndim != 2:
raise ValueError("edge_index must be a 2D tensor.")
@@ -114,11 +131,13 @@ class Graph(Data):
@staticmethod
def _check_edge_attr_consistency(edge_attr, edge_index):
"""
Check if the edge attr is consistent.
:param torch.Tensor edge_attr: The edge attribute tensor.
Check if the edge attribute tensor is consistent in type and shape
with the edge index.
:param torch.Tensor edge_attr: The edge attribute tensor.
:param torch.Tensor edge_index: The edge index tensor.
"""
if edge_attr is not None:
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
if edge_attr.ndim != 2:
@@ -134,10 +153,13 @@ class Graph(Data):
@staticmethod
def _check_x_consistency(x, pos=None):
"""
Check if the input tensor x is consistent with the position tensor pos.
Check if the input tensor x is consistent with the position tensor
`pos`.
:param torch.Tensor x: The input tensor.
:param torch.Tensor pos: The position tensor.
"""
if x is not None:
check_consistency(x, (torch.Tensor, LabelTensor))
if x.ndim != 2:
@@ -152,22 +174,24 @@ class Graph(Data):
@staticmethod
def _preprocess_edge_index(edge_index, undirected):
"""
Preprocess the edge index.
Preprocess the edge index to make the graph undirected (if required).
:param torch.Tensor edge_index: The edge index.
:param bool undirected: Whether the graph is undirected.
:return: The preprocessed edge index.
:rtype: torch.Tensor
"""
if undirected:
edge_index = to_undirected(edge_index)
return edge_index
def extract(self, labels, attr="x"):
"""
Perform extraction of labels on node features (x)
Perform extraction of labels from the attribute specified by `attr`.
:param labels: Labels to extract
:type labels: list[str] | tuple[str] | str
:type labels: list[str] | tuple[str] | str | dict
:return: Batch object with extraction performed on x
:rtype: PinaBatch
"""
@@ -193,21 +217,24 @@ class GraphBuilder:
**kwargs,
):
"""
Creates a new instance of the Graph class.
Compute the edge attributes and create a new instance of the Graph
class.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:type pos: torch.Tensor | LabelTensor
:type pos: torch.Tensor or LabelTensor
:param edge_index: A tensor of shape (2, E) representing the indices of
the graph's edges.
:type edge_index: torch.Tensor
:param x: Optional tensor of node features (N, F) where F is the number
of features per node.
:type x: torch.Tensor, LabelTensor
:param bool edge_attr: Optional edge attributes (E, F) where F is the
number of features per edge.
:param callable custom_edge_func: A custom function to compute edge
attributes.
:param x: Optional tensor of node features of shape (N, F), where F is
the number of features per node.
:type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
where F is the number of features per edge.
:type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional
:param kwargs: Additional keyword arguments passed to the Graph class
constructor.
:return: A Graph instance constructed using the provided information.
@@ -249,18 +276,18 @@ class RadiusGraph(GraphBuilder):
def __new__(cls, pos, radius, **kwargs):
"""
Creates a new instance of the Graph class using a radius-based graph
construction.
Extends the `GraphBuilder` class to compute edge_index based on a
radius. Each point is connected to all the points within the radius.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param float radius: The radius within which points are connected.
:Keyword Arguments:
The additional keyword arguments to be passed to GraphBuilder
and Graph classes
:return: Graph instance containg the information passed in input and
the computed edge_index
:type pos: torch.Tensor or LabelTensor
:param radius: The radius within which points are connected.
:type radius: float
:param kwargs: Additional keyword arguments to be passed to the
`GraphBuilder` and `Graph` constructors.
:return: A `Graph` instance containing the input information and the
computed edge_index.
:rtype: Graph
"""
edge_index = cls.compute_radius_graph(pos, radius)
@@ -269,7 +296,8 @@ class RadiusGraph(GraphBuilder):
@staticmethod
def compute_radius_graph(points, radius):
"""
Computes a radius-based graph for a given set of points.
Computes edge_index for a given set of points base on the radius.
Each point is connected to all the points within the radius.
:param points: A tensor of shape (N, D) representing the positions of
N points in D-dimensional space.
@@ -295,7 +323,7 @@ class KNNGraph(GraphBuilder):
def __new__(cls, pos, neighbours, **kwargs):
"""
Creates a new instance of the Graph class using k-nearest neighbors
to compute edge_index.
algorithm to define the edges.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
@@ -323,8 +351,9 @@ class KNNGraph(GraphBuilder):
N points in D-dimensional space.
:type points: torch.Tensor | LabelTensor
:param int k: The number of nearest neighbors to find for each point.
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
:return: A tensor of shape (2, E), where E is the number of
edges, representing the edge indices of the KNN graph.
:rtype: torch.Tensor
"""
dist = torch.cdist(points, points, p=2)
@@ -343,6 +372,11 @@ class LabelBatch(Batch):
def from_data_list(cls, data_list):
"""
Create a Batch object from a list of Data objects.
:param data_list: List of Data/Graph objects
:type data_list: list[Data] | list[Graph]
:return: A Batch object containing the data in the list
:rtype: Batch
"""
# Store the labels of Data/Graph objects (all data have the same labels)
# If the data do not contain labels, labels is an empty dictionary,