From ae796ce34ce20608b9c7d1d4d1f67d1a4d204cd1 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 12 Mar 2025 12:03:47 +0100 Subject: [PATCH] Fix doc --- pina/condition/condition_interface.py | 4 +- pina/condition/data_condition.py | 5 +- pina/condition/input_equation_condition.py | 6 +- pina/condition/input_target_condition.py | 21 ++-- pina/graph.py | 114 ++++++++++++--------- pina/label_tensor.py | 75 ++++++++------ 6 files changed, 128 insertions(+), 97 deletions(-) diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index d40c246..9e5b4df 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -32,8 +32,8 @@ class ConditionInterface(metaclass=ABCMeta): """ Set the problem to which the condition is associated. - :param value: Problem to which the condition is associated. - :type value: pina.problem.AbstractProblem + :param pina.problem.AbstractProblem value: Problem to which the + condition is associated. """ self._problem = value diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 15f3ef6..83ae598 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -12,7 +12,7 @@ from ..graph import Graph class DataCondition(ConditionInterface): """ This condition must be used every time a Unsupervised Loss is needed in - the Solver. The conditionalvariable can be passed as extra-input when + the Solver. The `conditional_variable` can be passed as extra-input when the model learns a conditional distribution. """ @@ -31,7 +31,8 @@ class DataCondition(ConditionInterface): :param conditional_variables: Conditional variables for the condition. :type conditional_variables: torch.Tensor | LabelTensor :return: Subclass of DataCondition. - :rtype: TensorDataCondition | GraphDataCondition + :rtype: pina.condition.data_condition.TensorDataCondition | + pina.condition.data_condition.GraphDataCondition :raises ValueError: If input is not of type :class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`, diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 61973ff..c6938da 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -30,7 +30,9 @@ class InputEquationCondition(ConditionInterface): :param EquationInterface equation: Equation object containing the equation function. :return: Subclass of InputEquationCondition, based on the input type. - :rtype: InputTensorEquationCondition | InputGraphEquationCondition + :rtype: pina.condition.input_equation_condition. + InputTensorEquationCondition | + pina.condition.input_equation_condition.InputGraphEquationCondition :raises ValueError: If input is not of type :class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`. @@ -105,7 +107,7 @@ class InputGraphEquationCondition(InputEquationCondition): in the :class:`pina.graph.Graph` object. :param input: Input data. - :type input: torch.Tensor | Graph | torch_geometric.data.Data + :type input: torch.Tensor | Graph | Data :raises ValueError: If the input data object does not contain at least one LabelTensor. diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index c58bdaa..cc8e292 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -25,23 +25,26 @@ class InputTargetCondition(ConditionInterface): the types of input and target data. :param input: Input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | - torch_geometric.data.Data | list[Graph] | + :type input: torch.Tensor | LabelTensor | Graph | + torch_geometric.data.Data | list[Graph] | list[torch_geometric.data.Data] | tuple[Graph] | tuple[torch_geometric.data.Data] :param target: Target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | - torch_geometric.data.Data | list[Graph] | + :type target: torch.Tensor | LabelTensor | Graph | + torch_geometric.data.Data | list[Graph] | list[torch_geometric.data.Data] | tuple[Graph] | tuple[torch_geometric.data.Data] :return: Subclass of InputTargetCondition - :rtype: TensorInputTensorTargetCondition | \ - TensorInputGraphTargetCondition | \ - GraphInputTensorTargetCondition | \ - GraphInputGraphTargetCondition + :rtype: pina.condition.input_target_condition. + TensorInputTensorTargetCondition | + pina.condition.input_target_condition. + TensorInputGraphTargetCondition | + pina.condition.input_target_condition. + GraphInputTensorTargetCondition | + pina.condition.input_target_condition.GraphInputGraphTargetCondition :raises ValueError: If input and or target are not of type - :class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`, + :class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. """ if cls != InputTargetCondition: diff --git a/pina/graph.py b/pina/graph.py index d4a5a19..f89e728 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -1,5 +1,5 @@ """ -This module provides an interface to build torch_geometric.data.Data objects. +Module to build Graph objects and perform operations on them. """ import torch @@ -11,7 +11,8 @@ from .utils import check_consistency, is_function class Graph(Data): """ - A class to build torch_geometric.data.Data objects. + Extends :class:`~torch_geometric.data.Data` class to include additional + checks and functionlities. """ def __new__( @@ -19,13 +20,16 @@ class Graph(Data): **kwargs, ): """ - Instantiates a new instance of the Graph class, performing type - consistency checks. + Create a new instance of the :class:`pina.graph.Graph` class by checking + the consistency of the input data and storing the attributes. - :param kwargs: Parameters to construct the Graph object. - :return: A new instance of the Graph class. + :param kwargs: Parameters used to initialize the + :class:`pina.graph.Graph` object. + :type kwargs: dict + :return: A new instance of the :class:`pina.graph.Graph` class. :rtype: Graph """ + # create class instance instance = Data.__new__(cls) @@ -45,27 +49,28 @@ class Graph(Data): **kwargs, ): """ - Initialize the Graph object by setting the node features, edge index, + Initialize the object by setting the node features, edge index, edge attributes, and positions. The edge index is preprocessed to make the graph undirected if required. For more details, see the - :meth: `torch_geometric.data.Data` + :meth:`torch_geometric.data.Data` - :param x: Optional tensor of node features (N, F) where F is the number - of features per node. + :param x: Optional tensor of node features `(N, F)` where `F` is the + number of features per node. :type x: torch.Tensor, LabelTensor - :param torch.Tensor edge_index: A tensor of shape (2, E) representing + :param torch.Tensor edge_index: A tensor of shape `(2, E)` representing the indices of the graph's edges. - :param pos: A tensor of shape (N, D) representing the positions of N - points in D-dimensional space. + :param pos: A tensor of shape `(N, D)` representing the positions of `N` + points in `D`-dimensional space. :type pos: torch.Tensor | LabelTensor - :param edge_attr: Optional tensor of edge_featured (E, F') where F' is - the number of edge features + :param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'` + is the number of edge features :param bool undirected: Whether to make the graph undirected :param kwargs: Additional keyword arguments passed to the `torch_geometric.data.Data` class constructor. If the argument - is a `torch.Tensor` or `LabelTensor`, it is included in the Data - object as a graph parameter. + is a `torch.Tensor` or `LabelTensor`, it is included in the graph + object. """ + # preprocessing self._preprocess_edge_index(edge_index, undirected) @@ -107,6 +112,8 @@ class Graph(Data): Check if the position tensor is consistent. :param torch.Tensor pos: The position tensor. + + :raises ValueError: If the position tensor is not consistent. """ if pos is not None: @@ -120,6 +127,8 @@ class Graph(Data): Check if the edge index is consistent. :param torch.Tensor edge_index: The edge index tensor. + + :raises ValueError: If the edge index tensor is not consistent. """ check_consistency(edge_index, (torch.Tensor, LabelTensor)) @@ -136,6 +145,9 @@ class Graph(Data): :param torch.Tensor edge_attr: The edge attribute tensor. :param torch.Tensor edge_index: The edge index tensor. + + :raises ValueError: If the edge attribute tensor is not consistent. + """ if edge_attr is not None: @@ -217,27 +229,28 @@ class GraphBuilder: **kwargs, ): """ - Compute the edge attributes and create a new instance of the Graph - class. + Compute the edge attributes and create a new instance of the + :class:`pina.graph.Graph` class. - :param pos: A tensor of shape (N, D) representing the positions of N - points in D-dimensional space. + :param pos: A tensor of shape `(N, D)` representing the positions of `N` + points in `D`-dimensional space. :type pos: torch.Tensor or LabelTensor - :param edge_index: A tensor of shape (2, E) representing the indices of - the graph's edges. + :param edge_index: A tensor of shape `(2, E)` representing the indices + of the graph's edges. :type edge_index: torch.Tensor - :param x: Optional tensor of node features of shape (N, F), where F is - the number of features per node. + :param x: Optional tensor of node features of shape `(N, F)`, where `F` + is the number of features per node. :type x: torch.Tensor | LabelTensor, optional - :param edge_attr: Optional tensor of edge attributes of shape (E, F), - where F is the number of features per edge. + :param edge_attr: Optional tensor of edge attributes of shape `(E, F)`, + where `F` is the number of features per edge. :type edge_attr: torch.Tensor, optional :param custom_edge_func: A custom function to compute edge attributes. If provided, overrides `edge_attr`. :type custom_edge_func: callable, optional - :param kwargs: Additional keyword arguments passed to the Graph class - constructor. - :return: A Graph instance constructed using the provided information. + :param kwargs: Additional keyword arguments passed to the + :class:`pina.graph.Graph` class constructor. + :return: A :class:`pina.graph.Graph` instance constructed using the + provided information. :rtype: Graph """ edge_attr = cls._create_edge_attr( @@ -271,42 +284,46 @@ class GraphBuilder: class RadiusGraph(GraphBuilder): """ - A class to build a radius graph. + A class to build a graph based on a radius. """ def __new__(cls, pos, radius, **kwargs): """ - Extends the `GraphBuilder` class to compute edge_index based on a - radius. Each point is connected to all the points within the radius. + Extends the :class:`pina.graph.GraphBuilder` class to compute + edge_index based on a radius. Each point is connected to all the points + within the radius. - :param pos: A tensor of shape (N, D) representing the positions of N - points in D-dimensional space. + :param pos: A tensor of shape `(N, D)` representing the positions of `N` + points in `D`-dimensional space. :type pos: torch.Tensor or LabelTensor :param radius: The radius within which points are connected. :type radius: float :param kwargs: Additional keyword arguments to be passed to the - `GraphBuilder` and `Graph` constructors. - :return: A `Graph` instance containing the input information and the - computed edge_index. + :class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph` + constructors. + :return: A :class:`pina.graph.Graph` instance containing the input + information and the computed edge_index. :rtype: Graph """ + edge_index = cls.compute_radius_graph(pos, radius) return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) @staticmethod def compute_radius_graph(points, radius): """ - Computes edge_index for a given set of points base on the radius. + Computes `edge_index` for a given set of points base on the radius. Each point is connected to all the points within the radius. - :param points: A tensor of shape (N, D) representing the positions of + :param points: A tensor of shape `(N, D)` representing the positions of N points in D-dimensional space. :type points: torch.Tensor | LabelTensor :param float radius: The number of nearest neighbors to find for each point. - :rtype torch.Tensor: A tensor of shape (2, E), where E is the number of - edges, representing the edge indices of the KNN graph. + :rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number + of edges, representing the edge indices of the KNN graph. """ + dist = torch.cdist(points, points, p=2) return ( torch.nonzero(dist <= radius, as_tuple=False) @@ -317,13 +334,13 @@ class RadiusGraph(GraphBuilder): class KNNGraph(GraphBuilder): """ - A class to build a KNN graph. + A class to build a K-nearest neighbors graph. """ def __new__(cls, pos, neighbours, **kwargs): """ - Creates a new instance of the Graph class using k-nearest neighbors - algorithm to define the edges. + Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index + based on a K-nearest neighbors algorithm. :param pos: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. @@ -334,8 +351,8 @@ class KNNGraph(GraphBuilder): The additional keyword arguments to be passed to GraphBuilder and Graph classes - :return: Graph instance containg the information passed in input and - the computed edge_index + :return: A :class:`pina.graph.Graph` instance containg the + information passed in input and the computed edge_index :rtype: Graph """ @@ -371,7 +388,8 @@ class LabelBatch(Batch): @classmethod def from_data_list(cls, data_list): """ - Create a Batch object from a list of Data objects. + Create a Batch object from a list of :class:`~torch_geometric.data.Data` + objects. :param data_list: List of Data/Graph objects :type data_list: list[Data] | list[Graph] diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 9109083..eb67a56 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -14,13 +14,15 @@ class LabelTensor(torch.Tensor): @staticmethod def __new__(cls, x, labels, *args, **kwargs): """ - Create a new instance of the :class:`LabelTensor` class. + Create a new instance of the :class:`pina.label_tensor.LabelTensor` + class. :param torch.Tensor x: :class:`torch.tensor` instance to be casted as a - :class:`LabelTensor`. + :class:`pina.label_tensor.LabelTensor`. :param labels: Labels to assign to the tensor. :type labels: str | list[str] | dict - :return: The instance of the :class:`LabelTensor` class. + :return: The instance of the :class:`pina.label_tensor.LabelTensor` + class. :rtype: LabelTensor """ @@ -31,9 +33,10 @@ class LabelTensor(torch.Tensor): @property def tensor(self): """ - Give the tensor part of the :class:`LabelTensor` object. + Give the tensor part of the :class:`pina.label_tensor.LabelTensor` + object. - :return: tensor part of the :class:`LabelTensor`. + :return: tensor part of the :class:`pina.label_tensor.LabelTensor`. :rtype: torch.Tensor """ @@ -41,9 +44,10 @@ class LabelTensor(torch.Tensor): def __init__(self, x, labels): """ - Construct a :class:`LabelTensor` by passing a dict of the labels and a - :class:`torch.Tensor`. Internally, the initialization method will store - check the compatibility of the labels with the tensor shape. + Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of + the labels and a :class:`torch.Tensor`. Internally, the initialization + method will store check the compatibility of the labels with the tensor + shape. :Example: >>> from pina import LabelTensor @@ -271,9 +275,10 @@ class LabelTensor(torch.Tensor): def __str__(self): """ - The string representation of the :class:`LabelTensor`. + The string representation of the :class:`pina.label_tensor.LabelTensor`. - :return: String representation of the :class:`LabelTensor` instance. + :return: String representation of the + :class:`pina.label_tensor.LabelTensor` instance. :rtype: str """ @@ -290,8 +295,8 @@ class LabelTensor(torch.Tensor): Concatenate a list of tensors along a specified dimension. For more details, see :meth:`torch.cat`. - :param list[LabelTensor] tensors: :class:`LabelTensor` instances to - concatenate + :param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor` + instances to concatenate :param int dim: dimensions on which you want to perform the operation (default is 0) :return: A new :class:`LabelTensor' instance obtained by concatenating @@ -346,8 +351,8 @@ class LabelTensor(torch.Tensor): :param list[LabelTensor] tensors: A list of tensors to stack. All tensors must have the same shape. - :return: A new :class:`LabelTensor` instance obtained by stacking the - input tensors, with the updated labels. + :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained + by stacking the input tensors, with the updated labels. :rtype: LabelTensor """ @@ -368,8 +373,8 @@ class LabelTensor(torch.Tensor): :param bool mode: A boolean value indicating whether the tensor should track gradients.If `True`, the tensor will track gradients; if `False`, it will not. - :return: The :class:`LabelTensor` itself with the updated - `requires_grad` state and retained labels. + :return: The :class:`pina.label_tensor.LabelTensor` itself with the + updated `requires_grad` state and retained labels. :rtype: LabelTensor """ @@ -394,8 +399,8 @@ class LabelTensor(torch.Tensor): Performs Tensor dtype and/or device conversion. For more details, see :meth:`torch.Tensor.to`. - :return: A new :class:`LabelTensor` instance with the updated dtype - and/or device and retained labels. + :return: A new :class:`pina.label_tensor.LabelTensor` instance with the + updated dtype and/or device and retained labels. :rtype: LabelTensor """ @@ -405,11 +410,11 @@ class LabelTensor(torch.Tensor): def clone(self, *args, **kwargs): """ - Clone the :class:`LabelTensor`. For more details, see + Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see :meth:`torch.Tensor.clone`. - :return: A new :class:`LabelTensor` instance with the same data and - labels but allocated in a different memory location. + :return: A new :class:`pina.label_tensor.LabelTensor` instance with the + same data and labels but allocated in a different memory location. :rtype: LabelTensor """ @@ -466,10 +471,11 @@ class LabelTensor(torch.Tensor): """ Stack tensors vertically. For more details, see :meth:`torch.vstack`. - :param list of LabelTensor label_tensors: The :class:`LabelTensor` - instances to stack. They need to have equal labels. - :return: A new :class:`LabelTensor` instance obtained by stacking the - input tensors vertically. + :param list of LabelTensor label_tensors: The + :class:`pina.label_tensor.LabelTensor` instances to stack. They need + to have equal labels. + :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained + by stacking the input tensors vertically. :rtype: LabelTensor """ @@ -524,15 +530,15 @@ class LabelTensor(torch.Tensor): def __getitem__(self, index): """ " Override the __getitem__ method to handle the labels of the - :class:`LabelTensor` instance. It first performs __getitem__ operation - on the :class:`torch.Tensor` part of the instance, then updates the - labels based on the index. + :class:`pina.label_tensor.LabelTensor` instance. It first performs + __getitem__ operation on the :class:`torch.Tensor` part of the instance, + then updates the labels based on the index. :param index: The index used to access the item :type index: int | str | tuple of int | list ot int | torch.Tensor - :return: A new :class:`LabelTensor` instance obtained __getitem__ - operation on :class:`torch.Tensor` part of the instance, with the - updated labels. + :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained + `__getitem__` operation on :class:`torch.Tensor` part of the + instance, with the updated labels. :rtype: LabelTensor :raises KeyError: If an invalid label index is provided. @@ -665,7 +671,8 @@ class LabelTensor(torch.Tensor): @staticmethod def summation(tensors): """ - Computes the summation of a list of :class:`LabelTensor` instances. + Computes the summation of a list of + :class:`pina.label_tensor.LabelTensor` instances. :param list[LabelTensor] tensors: A list of tensors to sum. All @@ -712,8 +719,8 @@ class LabelTensor(torch.Tensor): For more details, see :meth:`torch.Tensor.reshape`. :param tuple of int shape: The new shape of the tensor. - :return: A new :class:`LabelTensor` instance with the updated shape and - labels. + :return: A new :class:`pina.label_tensor.LabelTensor` instance with the + updated shape and labels. :rtype: LabelTensor """