From f587c3bf65afa729ea52afde5cf81b00a865435a Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 12 Mar 2025 12:29:18 +0100 Subject: [PATCH] Other fixes --- pina/collector.py | 10 +++--- pina/condition/condition.py | 10 +++--- pina/condition/data_condition.py | 8 ++--- pina/condition/input_equation_condition.py | 17 ++++----- pina/condition/input_target_condition.py | 18 +++++----- pina/data/data_module.py | 8 ++--- pina/data/dataset.py | 34 +++++++++--------- pina/graph.py | 22 ++++++------ pina/label_tensor.py | 40 +++++++++++----------- 9 files changed, 85 insertions(+), 82 deletions(-) diff --git a/pina/collector.py b/pina/collector.py index f912fb8..4a7fcdc 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -18,7 +18,8 @@ class Collector: and the problem and initializing the data collections (dictionary where data will be stored). - :param AbstractProblem problem: The problem to collect data from. + :param pina.problem.abstract_problem.AbstractProblem problem: The + problem to collect data from. """ # creating a hook between collector and problem self.problem = problem @@ -71,16 +72,17 @@ class Collector: Problem connected to the collector. :return: The problem from which the data is collected. - :rtype: AbstractProblem + :rtype: pina.problem.abstract_problem.AbstractProblem """ return self._problem @problem.setter def problem(self, value): """ - Ser the problem connected to the collector. + Set the problem connected to the collector. - :param AbstractProblem value: The problem to connect to the collector. + :param pina.problem.abstract_problem.AbstractProblem value: The problem + to connect to the collector. """ self._problem = value diff --git a/pina/condition/condition.py b/pina/condition/condition.py index c828a82..d44f124 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -89,15 +89,15 @@ class Condition: Create a new condition object based on the keyword arguments passed. - `input` and `target`: - :class:`pina.condition.input_target_condition.InputTargetCondition` + :class:`~pina.condition.input_target_condition.InputTargetCondition` - `domain` and `equation`: - :class:`pina.condition.domain_equation_condition. + :class:`~pina.condition.domain_equation_condition. DomainEquationCondition` - - `input` and `equation`: :class:`pina.condition. + - `input` and `equation`: :class:`~pina.condition. input_equation_condition.InputEquationCondition` - - `input`: :class:`pina.condition.data_condition.DataCondition` + - `input`: :class:`~pina.condition.data_condition.DataCondition` - `input` and `conditional_variables`: - :class:`pina.condition.data_condition.DataCondition` + :class:`~pina.condition.data_condition.DataCondition` :raises ValueError: No valid condition has been found. :return: A new condition instance belonging to the proper class. diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 83ae598..d0d0702 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -35,7 +35,7 @@ class DataCondition(ConditionInterface): pina.condition.data_condition.GraphDataCondition :raises ValueError: If input is not of type :class:`torch.Tensor`, - :class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`, + :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. @@ -68,7 +68,7 @@ class DataCondition(ConditionInterface): :type conditional_variables: torch.Tensor or LabelTensor .. note:: - If either `input` is composed by a list of :class:`pina.graph.Graph` + If either `input` is composed by a list of :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` objects, all elements must have the same structure (keys and data types) """ @@ -80,12 +80,12 @@ class DataCondition(ConditionInterface): class TensorDataCondition(DataCondition): """ DataCondition for :class:`torch.Tensor` or - :class:`pina.label_tensor.LabelTensor` input data + :class:`~pina.label_tensor.LabelTensor` input data """ class GraphDataCondition(DataCondition): """ - DataCondition for :class:`pina.graph.Graph` or + DataCondition for :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` input data """ diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index c6938da..7553906 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -35,7 +35,7 @@ class InputEquationCondition(ConditionInterface): pina.condition.input_equation_condition.InputGraphEquationCondition :raises ValueError: If input is not of type - :class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`. + :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`. """ # If the class is already a subclass, return the instance @@ -62,15 +62,16 @@ class InputEquationCondition(ConditionInterface): Initialize the InputEquationCondition by storing the input and equation. :param input: Input data for the condition. - :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] + :type input: pina.label_tensor.LabelTensor | pina.graph.Graph | + list[pina.graph.Graph] | tuple[pina.graph.Graph] :param EquationInterface equation: Equation object containing the equation function. .. note:: - If `input` is composed by a list of :class:`pina.graph.Graph` + If `input` is composed by a list of :class:`~pina.graph.Graph` objects, all elements must have the same structure (keys and data types). Moreover, at least one attribute must be a - :class:`pina.label_tensor.LabelTensor`. + :class:`~pina.label_tensor.LabelTensor`. """ super().__init__() @@ -90,21 +91,21 @@ class InputEquationCondition(ConditionInterface): class InputTensorEquationCondition(InputEquationCondition): """ - InputEquationCondition subclass for :class:`pina.label_tensor.LabelTensor` + InputEquationCondition subclass for :class:`~pina.label_tensor.LabelTensor` input data. """ class InputGraphEquationCondition(InputEquationCondition): """ - InputEquationCondition subclass for :class:`pina.graph.Graph` input data. + InputEquationCondition subclass for :class:`~pina.graph.Graph` input data. """ @staticmethod def _check_label_tensor(input): """ - Check if at least one :class:`pina.label_tensor.LabelTensor` is present - in the :class:`pina.graph.Graph` object. + Check if at least one :class:`~pina.label_tensor.LabelTensor` is present + in the :class:`~pina.graph.Graph` object. :param input: Input data. :type input: torch.Tensor | Graph | Data diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index cc8e292..11070fe 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -44,8 +44,8 @@ class InputTargetCondition(ConditionInterface): pina.condition.input_target_condition.GraphInputGraphTargetCondition :raises ValueError: If input and or target are not of type - :class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`, - :class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. + :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, + :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. """ if cls != InputTargetCondition: return super().__new__(cls) @@ -97,7 +97,7 @@ class InputTargetCondition(ConditionInterface): .. note:: If either `input` or `target` are composed by a list of - :class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data` + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` objects, all elements must have the same structure (keys and data types) """ @@ -122,29 +122,29 @@ class InputTargetCondition(ConditionInterface): class TensorInputTensorTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`torch.Tensor` or - :class:`pina.label_tensor.LabelTensor` input and target data. + :class:`~pina.label_tensor.LabelTensor` input and target data. """ class TensorInputGraphTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`torch.Tensor` or - :class:`pina.label_tensor.LabelTensor` input and - :class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data` target + :class:`~pina.label_tensor.LabelTensor` input and + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target data. """ class GraphInputTensorTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for :class:`pina.graph.Graph` o + InputTargetCondition subclass for :class:`~pina.graph.Graph` o :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or - :class:`pina.label_tensor.LabelTensor` target data. + :class:`~pina.label_tensor.LabelTensor` target data. """ class GraphInputGraphTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for :class:`pina.graph.Graph`/ + InputTargetCondition subclass for :class:`~pina.graph.Graph`/ :class:`~torch_geometric.data.Data` input and target data. """ diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 56a8f50..e8bf702 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -157,7 +157,7 @@ class Collator: def _collate_tensor_dataset(data_list): """ Function used to collate the data when the dataset is a - :class:`pina.data.dataset.PinaTensorDataset`. + :class:`~pina.data.dataset.PinaTensorDataset`. :param data_list: Elements to be collated. :type data_list: list[torch.Tensor] | list[LabelTensor] @@ -165,7 +165,7 @@ class Collator: :rtype: dict :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a - :class:`pina.label_tensor.LabelTensor`. + :class:`~pina.label_tensor.LabelTensor`. """ if isinstance(data_list[0], LabelTensor): @@ -177,7 +177,7 @@ class Collator: def _collate_graph_dataset(self, data_list): """ Function used to collate the data when the dataset is a - :class:`pina.data.dataset.PinaGraphDataset`. + :class:`~pina.data.dataset.PinaGraphDataset`. :param data_list: Elememts to be collated. :type data_list: list[Data] | list[Graph] @@ -185,7 +185,7 @@ class Collator: :rtype: dict :raises RuntimeError: If the data is not a - :class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`. + :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`. """ if isinstance(data_list[0], LabelTensor): diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 2effe60..9b7ff87 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -15,26 +15,26 @@ class PinaDatasetFactory: Depending on the type inside the conditions, it creates a different dataset object: - - :class:`pina.data.dataset.PinaTensorDataset` for handling - :class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data. - - :class:`pina.data.dataset.PinaGraphDataset` for handling - :class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data. + - :class:`~pina.data.dataset.PinaTensorDataset` for handling + :class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data. + - :class:`~pina.data.dataset.PinaGraphDataset` for handling + :class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data. """ def __new__(cls, conditions_dict, **kwargs): """ Instantiate the appropriate subclass of - :class:`pina.data.dataset.PinaDataset`. + :class:`~pina.data.dataset.PinaDataset`. If a graph is present in the conditions, returns a - :class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a - :class:`pina.data.dataset.PinaTensorDataset`. + :class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a + :class:`~pina.data.dataset.PinaTensorDataset`. :param dict conditions_dict: Dictionary containing all the conditions to be included in the dataset instance. - :return: A subclass of :class:`pina.data.dataset.PinaDataset`. - :rtype: :class:`pina.data.dataset.PinaTensorDataset` | - :class:`pina.data.dataset.PinaGraphDataset` + :return: A subclass of :class:`~pina.data.dataset.PinaDataset`. + :rtype: :class:`~pina.data.dataset.PinaTensorDataset` | + :class:`~pina.data.dataset.PinaGraphDataset` :raises ValueError: If an empty dictionary is provided. """ @@ -75,15 +75,15 @@ class PinaDatasetFactory: class PinaDataset(Dataset, ABC): """ Abstract class for the PINA dataset. It defines the common interface for - the :class:`pina.data.dataset.PinaTensorDataset` and - :class:`pina.data.dataset.PinaGraphDataset` classes. + the :class:`~pina.data.dataset.PinaTensorDataset` and + :class:`~pina.data.dataset.PinaGraphDataset` classes. """ def __init__( self, conditions_dict, max_conditions_lengths, automatic_batching ): """ - Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing + Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing the providedconditions dictionary, the maximum number of conditions to consider, and the automatic batching flag. @@ -93,7 +93,7 @@ class PinaDataset(Dataset, ABC): points to include in a single batch for each condition. :param bool automatic_batching: Indicates whether PyTorch automatic batching is enabled in - :class:`pina.data.data_module.PinaDataModule`. + :class:`~pina.data.data_module.PinaDataModule`. """ # Store the conditions dictionary @@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC): class PinaTensorDataset(PinaDataset): """ Dataset class for the PINA dataset with :class:`torch.Tensor` and - :class:`pina.label_tensor.LabelTensor` data. + :class:`~pina.label_tensor.LabelTensor` data. """ # Override _retrive_data method for torch.Tensor data @@ -215,7 +215,7 @@ class PinaTensorDataset(PinaDataset): :param dict data: Dictionary containing the data (only :class:`torch.Tensor` or - :class:`pina.label_tensor.LabelTensor`). + :class:`~pina.label_tensor.LabelTensor`). :param list[int] idx_list: indices to retrieve. :return: Dictionary containing the data at the given indices. :rtype: dict @@ -237,7 +237,7 @@ class PinaTensorDataset(PinaDataset): class PinaGraphDataset(PinaDataset): """ Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data` - and :class:`pina.graph.Graph` data. + and :class:`~pina.graph.Graph` data. """ def _create_graph_batch(self, data): diff --git a/pina/graph.py b/pina/graph.py index f89e728..c6f5335 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -20,13 +20,13 @@ class Graph(Data): **kwargs, ): """ - Create a new instance of the :class:`pina.graph.Graph` class by checking + Create a new instance of the :class:`~pina.graph.Graph` class by checking the consistency of the input data and storing the attributes. :param kwargs: Parameters used to initialize the - :class:`pina.graph.Graph` object. + :class:`~pina.graph.Graph` object. :type kwargs: dict - :return: A new instance of the :class:`pina.graph.Graph` class. + :return: A new instance of the :class:`~pina.graph.Graph` class. :rtype: Graph """ @@ -230,7 +230,7 @@ class GraphBuilder: ): """ Compute the edge attributes and create a new instance of the - :class:`pina.graph.Graph` class. + :class:`~pina.graph.Graph` class. :param pos: A tensor of shape `(N, D)` representing the positions of `N` points in `D`-dimensional space. @@ -248,8 +248,8 @@ class GraphBuilder: If provided, overrides `edge_attr`. :type custom_edge_func: callable, optional :param kwargs: Additional keyword arguments passed to the - :class:`pina.graph.Graph` class constructor. - :return: A :class:`pina.graph.Graph` instance constructed using the + :class:`~pina.graph.Graph` class constructor. + :return: A :class:`~pina.graph.Graph` instance constructed using the provided information. :rtype: Graph """ @@ -289,7 +289,7 @@ class RadiusGraph(GraphBuilder): def __new__(cls, pos, radius, **kwargs): """ - Extends the :class:`pina.graph.GraphBuilder` class to compute + Extends the :class:`~pina.graph.GraphBuilder` class to compute edge_index based on a radius. Each point is connected to all the points within the radius. @@ -299,9 +299,9 @@ class RadiusGraph(GraphBuilder): :param radius: The radius within which points are connected. :type radius: float :param kwargs: Additional keyword arguments to be passed to the - :class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph` + :class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph` constructors. - :return: A :class:`pina.graph.Graph` instance containing the input + :return: A :class:`~pina.graph.Graph` instance containing the input information and the computed edge_index. :rtype: Graph """ @@ -339,7 +339,7 @@ class KNNGraph(GraphBuilder): def __new__(cls, pos, neighbours, **kwargs): """ - Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index + Extends the :class:`~pina.graph.GraphBuilder` class to compute edge_index based on a K-nearest neighbors algorithm. :param pos: A tensor of shape (N, D) representing the positions of N @@ -351,7 +351,7 @@ class KNNGraph(GraphBuilder): The additional keyword arguments to be passed to GraphBuilder and Graph classes - :return: A :class:`pina.graph.Graph` instance containg the + :return: A :class:`~pina.graph.Graph` instance containg the information passed in input and the computed edge_index :rtype: Graph """ diff --git a/pina/label_tensor.py b/pina/label_tensor.py index eb67a56..48e7f62 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -14,14 +14,14 @@ class LabelTensor(torch.Tensor): @staticmethod def __new__(cls, x, labels, *args, **kwargs): """ - Create a new instance of the :class:`pina.label_tensor.LabelTensor` + Create a new instance of the :class:`~pina.label_tensor.LabelTensor` class. :param torch.Tensor x: :class:`torch.tensor` instance to be casted as a - :class:`pina.label_tensor.LabelTensor`. + :class:`~pina.label_tensor.LabelTensor`. :param labels: Labels to assign to the tensor. :type labels: str | list[str] | dict - :return: The instance of the :class:`pina.label_tensor.LabelTensor` + :return: The instance of the :class:`~pina.label_tensor.LabelTensor` class. :rtype: LabelTensor """ @@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor): @property def tensor(self): """ - Give the tensor part of the :class:`pina.label_tensor.LabelTensor` + Give the tensor part of the :class:`~pina.label_tensor.LabelTensor` object. - :return: tensor part of the :class:`pina.label_tensor.LabelTensor`. + :return: tensor part of the :class:`~pina.label_tensor.LabelTensor`. :rtype: torch.Tensor """ @@ -44,7 +44,7 @@ class LabelTensor(torch.Tensor): def __init__(self, x, labels): """ - Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of + Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of the labels and a :class:`torch.Tensor`. Internally, the initialization method will store check the compatibility of the labels with the tensor shape. @@ -275,10 +275,10 @@ class LabelTensor(torch.Tensor): def __str__(self): """ - The string representation of the :class:`pina.label_tensor.LabelTensor`. + The string representation of the :class:`~pina.label_tensor.LabelTensor`. :return: String representation of the - :class:`pina.label_tensor.LabelTensor` instance. + :class:`~pina.label_tensor.LabelTensor` instance. :rtype: str """ @@ -295,7 +295,7 @@ class LabelTensor(torch.Tensor): Concatenate a list of tensors along a specified dimension. For more details, see :meth:`torch.cat`. - :param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor` + :param list[LabelTensor] tensors: :class:`~pina.label_tensor.LabelTensor` instances to concatenate :param int dim: dimensions on which you want to perform the operation (default is 0) @@ -351,7 +351,7 @@ class LabelTensor(torch.Tensor): :param list[LabelTensor] tensors: A list of tensors to stack. All tensors must have the same shape. - :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained + :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained by stacking the input tensors, with the updated labels. :rtype: LabelTensor """ @@ -373,7 +373,7 @@ class LabelTensor(torch.Tensor): :param bool mode: A boolean value indicating whether the tensor should track gradients.If `True`, the tensor will track gradients; if `False`, it will not. - :return: The :class:`pina.label_tensor.LabelTensor` itself with the + :return: The :class:`~pina.label_tensor.LabelTensor` itself with the updated `requires_grad` state and retained labels. :rtype: LabelTensor """ @@ -399,7 +399,7 @@ class LabelTensor(torch.Tensor): Performs Tensor dtype and/or device conversion. For more details, see :meth:`torch.Tensor.to`. - :return: A new :class:`pina.label_tensor.LabelTensor` instance with the + :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the updated dtype and/or device and retained labels. :rtype: LabelTensor """ @@ -410,10 +410,10 @@ class LabelTensor(torch.Tensor): def clone(self, *args, **kwargs): """ - Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see + Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see :meth:`torch.Tensor.clone`. - :return: A new :class:`pina.label_tensor.LabelTensor` instance with the + :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the same data and labels but allocated in a different memory location. :rtype: LabelTensor """ @@ -472,9 +472,9 @@ class LabelTensor(torch.Tensor): Stack tensors vertically. For more details, see :meth:`torch.vstack`. :param list of LabelTensor label_tensors: The - :class:`pina.label_tensor.LabelTensor` instances to stack. They need + :class:`~pina.label_tensor.LabelTensor` instances to stack. They need to have equal labels. - :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained + :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained by stacking the input tensors vertically. :rtype: LabelTensor """ @@ -530,13 +530,13 @@ class LabelTensor(torch.Tensor): def __getitem__(self, index): """ " Override the __getitem__ method to handle the labels of the - :class:`pina.label_tensor.LabelTensor` instance. It first performs + :class:`~pina.label_tensor.LabelTensor` instance. It first performs __getitem__ operation on the :class:`torch.Tensor` part of the instance, then updates the labels based on the index. :param index: The index used to access the item :type index: int | str | tuple of int | list ot int | torch.Tensor - :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained + :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained `__getitem__` operation on :class:`torch.Tensor` part of the instance, with the updated labels. :rtype: LabelTensor @@ -672,7 +672,7 @@ class LabelTensor(torch.Tensor): def summation(tensors): """ Computes the summation of a list of - :class:`pina.label_tensor.LabelTensor` instances. + :class:`~pina.label_tensor.LabelTensor` instances. :param list[LabelTensor] tensors: A list of tensors to sum. All @@ -719,7 +719,7 @@ class LabelTensor(torch.Tensor): For more details, see :meth:`torch.Tensor.reshape`. :param tuple of int shape: The new shape of the tensor. - :return: A new :class:`pina.label_tensor.LabelTensor` instance with the + :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the updated shape and labels. :rtype: LabelTensor """