Fix doc
This commit is contained in:
committed by
Nicola Demo
parent
59e6ee595c
commit
ae796ce34c
114
pina/graph.py
114
pina/graph.py
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
This module provides an interface to build torch_geometric.data.Data objects.
|
||||
Module to build Graph objects and perform operations on them.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -11,7 +11,8 @@ from .utils import check_consistency, is_function
|
||||
|
||||
class Graph(Data):
|
||||
"""
|
||||
A class to build torch_geometric.data.Data objects.
|
||||
Extends :class:`~torch_geometric.data.Data` class to include additional
|
||||
checks and functionlities.
|
||||
"""
|
||||
|
||||
def __new__(
|
||||
@@ -19,13 +20,16 @@ class Graph(Data):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiates a new instance of the Graph class, performing type
|
||||
consistency checks.
|
||||
Create a new instance of the :class:`pina.graph.Graph` class by checking
|
||||
the consistency of the input data and storing the attributes.
|
||||
|
||||
:param kwargs: Parameters to construct the Graph object.
|
||||
:return: A new instance of the Graph class.
|
||||
:param kwargs: Parameters used to initialize the
|
||||
:class:`pina.graph.Graph` object.
|
||||
:type kwargs: dict
|
||||
:return: A new instance of the :class:`pina.graph.Graph` class.
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
# create class instance
|
||||
instance = Data.__new__(cls)
|
||||
|
||||
@@ -45,27 +49,28 @@ class Graph(Data):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the Graph object by setting the node features, edge index,
|
||||
Initialize the object by setting the node features, edge index,
|
||||
edge attributes, and positions. The edge index is preprocessed to make
|
||||
the graph undirected if required. For more details, see the
|
||||
:meth: `torch_geometric.data.Data`
|
||||
:meth:`torch_geometric.data.Data`
|
||||
|
||||
:param x: Optional tensor of node features (N, F) where F is the number
|
||||
of features per node.
|
||||
:param x: Optional tensor of node features `(N, F)` where `F` is the
|
||||
number of features per node.
|
||||
:type x: torch.Tensor, LabelTensor
|
||||
:param torch.Tensor edge_index: A tensor of shape (2, E) representing
|
||||
:param torch.Tensor edge_index: A tensor of shape `(2, E)` representing
|
||||
the indices of the graph's edges.
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||
points in `D`-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param edge_attr: Optional tensor of edge_featured (E, F') where F' is
|
||||
the number of edge features
|
||||
:param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'`
|
||||
is the number of edge features
|
||||
:param bool undirected: Whether to make the graph undirected
|
||||
:param kwargs: Additional keyword arguments passed to the
|
||||
`torch_geometric.data.Data` class constructor. If the argument
|
||||
is a `torch.Tensor` or `LabelTensor`, it is included in the Data
|
||||
object as a graph parameter.
|
||||
is a `torch.Tensor` or `LabelTensor`, it is included in the graph
|
||||
object.
|
||||
"""
|
||||
|
||||
# preprocessing
|
||||
self._preprocess_edge_index(edge_index, undirected)
|
||||
|
||||
@@ -107,6 +112,8 @@ class Graph(Data):
|
||||
Check if the position tensor is consistent.
|
||||
|
||||
:param torch.Tensor pos: The position tensor.
|
||||
|
||||
:raises ValueError: If the position tensor is not consistent.
|
||||
"""
|
||||
|
||||
if pos is not None:
|
||||
@@ -120,6 +127,8 @@ class Graph(Data):
|
||||
Check if the edge index is consistent.
|
||||
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
|
||||
:raises ValueError: If the edge index tensor is not consistent.
|
||||
"""
|
||||
|
||||
check_consistency(edge_index, (torch.Tensor, LabelTensor))
|
||||
@@ -136,6 +145,9 @@ class Graph(Data):
|
||||
|
||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
|
||||
:raises ValueError: If the edge attribute tensor is not consistent.
|
||||
|
||||
"""
|
||||
|
||||
if edge_attr is not None:
|
||||
@@ -217,27 +229,28 @@ class GraphBuilder:
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Compute the edge attributes and create a new instance of the Graph
|
||||
class.
|
||||
Compute the edge attributes and create a new instance of the
|
||||
:class:`pina.graph.Graph` class.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||
points in `D`-dimensional space.
|
||||
:type pos: torch.Tensor or LabelTensor
|
||||
:param edge_index: A tensor of shape (2, E) representing the indices of
|
||||
the graph's edges.
|
||||
:param edge_index: A tensor of shape `(2, E)` representing the indices
|
||||
of the graph's edges.
|
||||
:type edge_index: torch.Tensor
|
||||
:param x: Optional tensor of node features of shape (N, F), where F is
|
||||
the number of features per node.
|
||||
:param x: Optional tensor of node features of shape `(N, F)`, where `F`
|
||||
is the number of features per node.
|
||||
:type x: torch.Tensor | LabelTensor, optional
|
||||
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
|
||||
where F is the number of features per edge.
|
||||
:param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
|
||||
where `F` is the number of features per edge.
|
||||
:type edge_attr: torch.Tensor, optional
|
||||
:param custom_edge_func: A custom function to compute edge attributes.
|
||||
If provided, overrides `edge_attr`.
|
||||
:type custom_edge_func: callable, optional
|
||||
:param kwargs: Additional keyword arguments passed to the Graph class
|
||||
constructor.
|
||||
:return: A Graph instance constructed using the provided information.
|
||||
:param kwargs: Additional keyword arguments passed to the
|
||||
:class:`pina.graph.Graph` class constructor.
|
||||
:return: A :class:`pina.graph.Graph` instance constructed using the
|
||||
provided information.
|
||||
:rtype: Graph
|
||||
"""
|
||||
edge_attr = cls._create_edge_attr(
|
||||
@@ -271,42 +284,46 @@ class GraphBuilder:
|
||||
|
||||
class RadiusGraph(GraphBuilder):
|
||||
"""
|
||||
A class to build a radius graph.
|
||||
A class to build a graph based on a radius.
|
||||
"""
|
||||
|
||||
def __new__(cls, pos, radius, **kwargs):
|
||||
"""
|
||||
Extends the `GraphBuilder` class to compute edge_index based on a
|
||||
radius. Each point is connected to all the points within the radius.
|
||||
Extends the :class:`pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a radius. Each point is connected to all the points
|
||||
within the radius.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||
points in `D`-dimensional space.
|
||||
:type pos: torch.Tensor or LabelTensor
|
||||
:param radius: The radius within which points are connected.
|
||||
:type radius: float
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
`GraphBuilder` and `Graph` constructors.
|
||||
:return: A `Graph` instance containing the input information and the
|
||||
computed edge_index.
|
||||
:class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph`
|
||||
constructors.
|
||||
:return: A :class:`pina.graph.Graph` instance containing the input
|
||||
information and the computed edge_index.
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
edge_index = cls.compute_radius_graph(pos, radius)
|
||||
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def compute_radius_graph(points, radius):
|
||||
"""
|
||||
Computes edge_index for a given set of points base on the radius.
|
||||
Computes `edge_index` for a given set of points base on the radius.
|
||||
Each point is connected to all the points within the radius.
|
||||
|
||||
:param points: A tensor of shape (N, D) representing the positions of
|
||||
:param points: A tensor of shape `(N, D)` representing the positions of
|
||||
N points in D-dimensional space.
|
||||
:type points: torch.Tensor | LabelTensor
|
||||
:param float radius: The number of nearest neighbors to find for each
|
||||
point.
|
||||
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
|
||||
edges, representing the edge indices of the KNN graph.
|
||||
:rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number
|
||||
of edges, representing the edge indices of the KNN graph.
|
||||
"""
|
||||
|
||||
dist = torch.cdist(points, points, p=2)
|
||||
return (
|
||||
torch.nonzero(dist <= radius, as_tuple=False)
|
||||
@@ -317,13 +334,13 @@ class RadiusGraph(GraphBuilder):
|
||||
|
||||
class KNNGraph(GraphBuilder):
|
||||
"""
|
||||
A class to build a KNN graph.
|
||||
A class to build a K-nearest neighbors graph.
|
||||
"""
|
||||
|
||||
def __new__(cls, pos, neighbours, **kwargs):
|
||||
"""
|
||||
Creates a new instance of the Graph class using k-nearest neighbors
|
||||
algorithm to define the edges.
|
||||
Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index
|
||||
based on a K-nearest neighbors algorithm.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
@@ -334,8 +351,8 @@ class KNNGraph(GraphBuilder):
|
||||
The additional keyword arguments to be passed to GraphBuilder
|
||||
and Graph classes
|
||||
|
||||
:return: Graph instance containg the information passed in input and
|
||||
the computed edge_index
|
||||
:return: A :class:`pina.graph.Graph` instance containg the
|
||||
information passed in input and the computed edge_index
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
@@ -371,7 +388,8 @@ class LabelBatch(Batch):
|
||||
@classmethod
|
||||
def from_data_list(cls, data_list):
|
||||
"""
|
||||
Create a Batch object from a list of Data objects.
|
||||
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||
objects.
|
||||
|
||||
:param data_list: List of Data/Graph objects
|
||||
:type data_list: list[Data] | list[Graph]
|
||||
|
||||
Reference in New Issue
Block a user