Documentation and docstring graph and data
This commit is contained in:
@@ -19,6 +19,9 @@ class Graph(Data):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiates a new instance of the Graph class, performing type
|
||||
consistency checks.
|
||||
|
||||
:param kwargs: Parameters to construct the Graph object.
|
||||
:return: A new instance of the Graph class.
|
||||
:rtype: Graph
|
||||
@@ -42,7 +45,10 @@ class Graph(Data):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the Graph object.
|
||||
Initialize the Graph object by setting the node features, edge index,
|
||||
edge attributes, and positions. The edge index is preprocessed to make
|
||||
the graph undirected if required. For more details, see the
|
||||
:meth: `torch_geometric.data.Data`
|
||||
|
||||
:param x: Optional tensor of node features (N, F) where F is the number
|
||||
of features per node.
|
||||
@@ -69,6 +75,13 @@ class Graph(Data):
|
||||
)
|
||||
|
||||
def _check_type_consistency(self, **kwargs):
|
||||
"""
|
||||
Check the consistency of the types of the input data.
|
||||
|
||||
:param kwargs: Attributes to be checked for consistency.
|
||||
:type kwargs: dict
|
||||
"""
|
||||
|
||||
# default types, specified in cls.__new__, by default they are Nont
|
||||
# if specified in **kwargs they get override
|
||||
x, pos, edge_index, edge_attr = None, None, None, None
|
||||
@@ -92,8 +105,10 @@ class Graph(Data):
|
||||
def _check_pos_consistency(pos):
|
||||
"""
|
||||
Check if the position tensor is consistent.
|
||||
|
||||
:param torch.Tensor pos: The position tensor.
|
||||
"""
|
||||
|
||||
if pos is not None:
|
||||
check_consistency(pos, (torch.Tensor, LabelTensor))
|
||||
if pos.ndim != 2:
|
||||
@@ -103,8 +118,10 @@ class Graph(Data):
|
||||
def _check_edge_index_consistency(edge_index):
|
||||
"""
|
||||
Check if the edge index is consistent.
|
||||
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
"""
|
||||
|
||||
check_consistency(edge_index, (torch.Tensor, LabelTensor))
|
||||
if edge_index.ndim != 2:
|
||||
raise ValueError("edge_index must be a 2D tensor.")
|
||||
@@ -114,11 +131,13 @@ class Graph(Data):
|
||||
@staticmethod
|
||||
def _check_edge_attr_consistency(edge_attr, edge_index):
|
||||
"""
|
||||
Check if the edge attr is consistent.
|
||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||
Check if the edge attribute tensor is consistent in type and shape
|
||||
with the edge index.
|
||||
|
||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
"""
|
||||
|
||||
if edge_attr is not None:
|
||||
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
|
||||
if edge_attr.ndim != 2:
|
||||
@@ -134,10 +153,13 @@ class Graph(Data):
|
||||
@staticmethod
|
||||
def _check_x_consistency(x, pos=None):
|
||||
"""
|
||||
Check if the input tensor x is consistent with the position tensor pos.
|
||||
Check if the input tensor x is consistent with the position tensor
|
||||
`pos`.
|
||||
|
||||
:param torch.Tensor x: The input tensor.
|
||||
:param torch.Tensor pos: The position tensor.
|
||||
"""
|
||||
|
||||
if x is not None:
|
||||
check_consistency(x, (torch.Tensor, LabelTensor))
|
||||
if x.ndim != 2:
|
||||
@@ -152,22 +174,24 @@ class Graph(Data):
|
||||
@staticmethod
|
||||
def _preprocess_edge_index(edge_index, undirected):
|
||||
"""
|
||||
Preprocess the edge index.
|
||||
Preprocess the edge index to make the graph undirected (if required).
|
||||
|
||||
:param torch.Tensor edge_index: The edge index.
|
||||
:param bool undirected: Whether the graph is undirected.
|
||||
:return: The preprocessed edge index.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
if undirected:
|
||||
edge_index = to_undirected(edge_index)
|
||||
return edge_index
|
||||
|
||||
def extract(self, labels, attr="x"):
|
||||
"""
|
||||
Perform extraction of labels on node features (x)
|
||||
Perform extraction of labels from the attribute specified by `attr`.
|
||||
|
||||
:param labels: Labels to extract
|
||||
:type labels: list[str] | tuple[str] | str
|
||||
:type labels: list[str] | tuple[str] | str | dict
|
||||
:return: Batch object with extraction performed on x
|
||||
:rtype: PinaBatch
|
||||
"""
|
||||
@@ -193,21 +217,24 @@ class GraphBuilder:
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a new instance of the Graph class.
|
||||
Compute the edge attributes and create a new instance of the Graph
|
||||
class.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:type pos: torch.Tensor or LabelTensor
|
||||
:param edge_index: A tensor of shape (2, E) representing the indices of
|
||||
the graph's edges.
|
||||
:type edge_index: torch.Tensor
|
||||
:param x: Optional tensor of node features (N, F) where F is the number
|
||||
of features per node.
|
||||
:type x: torch.Tensor, LabelTensor
|
||||
:param bool edge_attr: Optional edge attributes (E, F) where F is the
|
||||
number of features per edge.
|
||||
:param callable custom_edge_func: A custom function to compute edge
|
||||
attributes.
|
||||
:param x: Optional tensor of node features of shape (N, F), where F is
|
||||
the number of features per node.
|
||||
:type x: torch.Tensor | LabelTensor, optional
|
||||
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
|
||||
where F is the number of features per edge.
|
||||
:type edge_attr: torch.Tensor, optional
|
||||
:param custom_edge_func: A custom function to compute edge attributes.
|
||||
If provided, overrides `edge_attr`.
|
||||
:type custom_edge_func: callable, optional
|
||||
:param kwargs: Additional keyword arguments passed to the Graph class
|
||||
constructor.
|
||||
:return: A Graph instance constructed using the provided information.
|
||||
@@ -249,18 +276,18 @@ class RadiusGraph(GraphBuilder):
|
||||
|
||||
def __new__(cls, pos, radius, **kwargs):
|
||||
"""
|
||||
Creates a new instance of the Graph class using a radius-based graph
|
||||
construction.
|
||||
Extends the `GraphBuilder` class to compute edge_index based on a
|
||||
radius. Each point is connected to all the points within the radius.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param float radius: The radius within which points are connected.
|
||||
:Keyword Arguments:
|
||||
The additional keyword arguments to be passed to GraphBuilder
|
||||
and Graph classes
|
||||
:return: Graph instance containg the information passed in input and
|
||||
the computed edge_index
|
||||
:type pos: torch.Tensor or LabelTensor
|
||||
:param radius: The radius within which points are connected.
|
||||
:type radius: float
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
`GraphBuilder` and `Graph` constructors.
|
||||
:return: A `Graph` instance containing the input information and the
|
||||
computed edge_index.
|
||||
:rtype: Graph
|
||||
"""
|
||||
edge_index = cls.compute_radius_graph(pos, radius)
|
||||
@@ -269,7 +296,8 @@ class RadiusGraph(GraphBuilder):
|
||||
@staticmethod
|
||||
def compute_radius_graph(points, radius):
|
||||
"""
|
||||
Computes a radius-based graph for a given set of points.
|
||||
Computes edge_index for a given set of points base on the radius.
|
||||
Each point is connected to all the points within the radius.
|
||||
|
||||
:param points: A tensor of shape (N, D) representing the positions of
|
||||
N points in D-dimensional space.
|
||||
@@ -295,7 +323,7 @@ class KNNGraph(GraphBuilder):
|
||||
def __new__(cls, pos, neighbours, **kwargs):
|
||||
"""
|
||||
Creates a new instance of the Graph class using k-nearest neighbors
|
||||
to compute edge_index.
|
||||
algorithm to define the edges.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
@@ -323,8 +351,9 @@ class KNNGraph(GraphBuilder):
|
||||
N points in D-dimensional space.
|
||||
:type points: torch.Tensor | LabelTensor
|
||||
:param int k: The number of nearest neighbors to find for each point.
|
||||
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
|
||||
:return: A tensor of shape (2, E), where E is the number of
|
||||
edges, representing the edge indices of the KNN graph.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
dist = torch.cdist(points, points, p=2)
|
||||
@@ -343,6 +372,11 @@ class LabelBatch(Batch):
|
||||
def from_data_list(cls, data_list):
|
||||
"""
|
||||
Create a Batch object from a list of Data objects.
|
||||
|
||||
:param data_list: List of Data/Graph objects
|
||||
:type data_list: list[Data] | list[Graph]
|
||||
:return: A Batch object containing the data in the list
|
||||
:rtype: Batch
|
||||
"""
|
||||
# Store the labels of Data/Graph objects (all data have the same labels)
|
||||
# If the data do not contain labels, labels is an empty dictionary,
|
||||
|
||||
Reference in New Issue
Block a user