From 10ccae3a334451e58b56ac50e90c8d7411dd37fa Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 14 Mar 2025 10:20:56 +0100 Subject: [PATCH] Tmp fixes --- pina/collector.py | 14 ++-- pina/condition/condition.py | 4 +- pina/condition/data_condition.py | 4 +- pina/condition/input_equation_condition.py | 4 +- pina/condition/input_target_condition.py | 16 ++-- pina/data/data_module.py | 6 +- pina/data/dataset.py | 6 +- pina/graph.py | 89 +++++++++++----------- 8 files changed, 74 insertions(+), 69 deletions(-) diff --git a/pina/collector.py b/pina/collector.py index c7a60be..ce4d77e 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -37,10 +37,11 @@ class Collector: @property def full(self): """ - Returns True if the collector is full. The collector is considered full - if all conditions have entries in the data_collection dictionary. + Returns ``True`` if the collector is full. The collector is considered + full if all conditions have entries in the ``data_collection`` + dictionary. - :return: True if all conditions are ready, False otherwise. + :return: ``True`` if all conditions are ready, ``False`` otherwise. :rtype: bool """ @@ -49,9 +50,9 @@ class Collector: @full.setter def full(self, value): """ - Set the full variable. + Set the ``_full`` variable. - :param bool value: The value to set the full property to. + :param bool value: The value to set the ``_full`` variable. """ check_consistency(value, bool) @@ -117,7 +118,8 @@ class Collector: """ Store inside data collections the sampled data of the problem. These comes from the conditions that require sampling (e.g. - DomainEquationCondition). + :class:`~pina.condition.domain_equation_condition. + DomainEquationCondition`). """ for condition_name in self.problem.conditions: diff --git a/pina/condition/condition.py b/pina/condition/condition.py index 4194259..f8a81f9 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -32,7 +32,7 @@ class Condition: The class ``Condition`` is used to represent the constraints (physical equations, boundary conditions, etc.) that should be satisfied in the problem at hand. Condition objects are used to formulate the - PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object. + PINA :class:`~pina.problem.abstract_problem.AbstractProblem` object. Conditions can be specified in four ways: 1. By specifying the input and target of the condition; in such a @@ -61,7 +61,7 @@ class Condition: input, there are different implementations of the condition. For more details, see :class:`~pina.condition.data_condition.DataCondition`. - Example:: + :Example: >>> from pina import Condition >>> condition = Condition( diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 4dd7eb1..196c3d2 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -28,7 +28,7 @@ class DataCondition(ConditionInterface): def __new__(cls, input, conditional_variables=None): """ Instantiate the appropriate subclass of :class:`DataCondition` based on - the type of `input`. + the type of ``input``. :param input: Input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | @@ -72,7 +72,7 @@ class DataCondition(ConditionInterface): :type conditional_variables: torch.Tensor or LabelTensor .. note:: - If either `input` is composed by a list of + If either ``input`` is composed by a list of :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`, all elements must have the same structure (keys and data types) diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 3494e1c..16a521c 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -29,7 +29,7 @@ class InputEquationCondition(ConditionInterface): def __new__(cls, input, equation): """ Instantiate the appropriate subclass of :class:`InputEquationCondition` - based on the type of `input`. + based on the type of ``input``. :param input: Input data for the condition. :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] @@ -74,7 +74,7 @@ class InputEquationCondition(ConditionInterface): equation function. .. note:: - If `input` is composed by a list of :class:`~pina.graph.Graph` + If ``input`` is composed by a list of :class:`~pina.graph.Graph` objects, all elements must have the same structure (keys and data types). Moreover, at least one attribute must be a :class:`~pina.label_tensor.LabelTensor`. diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 2465038..3ed571d 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -52,7 +52,7 @@ class InputTargetCondition(ConditionInterface): GraphInputTensorTargetCondition | pina.condition.input_target_condition.GraphInputGraphTargetCondition - :raises ValueError: If `input` and/or `target` are not of type + :raises ValueError: If ``input`` and/or ``target`` are not of type :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. """ @@ -94,7 +94,7 @@ class InputTargetCondition(ConditionInterface): def __init__(self, input, target): """ - Initialize the object by storing the `input` and `target` data. + Initialize the object by storing the ``input`` and ``target`` data. :param input: Input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | @@ -104,7 +104,7 @@ class InputTargetCondition(ConditionInterface): list[Data] | tuple[Graph] | tuple[Data] .. note:: - If either `input` or `target` are composed by a list of + If either ``input`` or ``target`` are composed by a list of :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` objects, all elements must have the same structure (keys and data types) @@ -130,14 +130,14 @@ class InputTargetCondition(ConditionInterface): class TensorInputTensorTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` `input` and `target` data. + :class:`~pina.label_tensor.LabelTensor` ``input`` and ``target`` data. """ class TensorInputGraphTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` `input` and + :class:`~pina.label_tensor.LabelTensor` ``input`` and :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target` data. """ @@ -146,13 +146,13 @@ class TensorInputGraphTargetCondition(InputTargetCondition): class GraphInputTensorTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`~pina.graph.Graph` o - :class:`~torch_geometric.data.Data` `input` and :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` `target` data. + :class:`~torch_geometric.data.Data` ``input`` and :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor` ``target`` data. """ class GraphInputGraphTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`~pina.graph.Graph`/ - :class:`~torch_geometric.data.Data` `input` and `target` data. + :class:`~torch_geometric.data.Data` ``input`` and ``target`` data. """ diff --git a/pina/data/data_module.py b/pina/data/data_module.py index f5f7d98..56d84e3 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -17,7 +17,7 @@ from ..collector import Collector class DummyDataloader: """ - Dataloader used when batch size is `None`. It returns the entire dataset + Dataloader used when batch size is ``None``. It returns the entire dataset in a single batch. """ @@ -38,7 +38,7 @@ class DummyDataloader: :param dataset: The dataset object to be processed. :type dataset: PinaDataset - .. note:: This data loader is used when the batch size is `None`. + .. note:: This data loader is used when the batch size is ``None``. """ if ( @@ -273,7 +273,7 @@ class PinaDataModule(LightningDataModule): be in the range [0, 1]. :param float val_size: Fraction of elements in the validation split. It must be in the range [0, 1]. - :param batch_size: The batch size used for training. If `None`, the + :param batch_size: The batch size used for training. If ``None``, the entire dataset is returned in a single batch. :type batch_size: int | None :param bool shuffle: Whether to shuffle the dataset before splitting. diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 0e24d22..452da22 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -33,8 +33,8 @@ class PinaDatasetFactory: :param dict conditions_dict: Dictionary containing all the conditions to be included in the dataset instance. :return: A subclass of :class:`~pina.data.dataset.PinaDataset`. - :rtype: :class:`~pina.data.dataset.PinaTensorDataset` | - :class:`~pina.data.dataset.PinaGraphDataset` + :rtype: pina.data.dataset.PinaTensorDataset | + pina.data.dataset.PinaGraphDataset :raises ValueError: If an empty dictionary is provided. """ @@ -255,7 +255,7 @@ class PinaGraphDataset(PinaDataset): def _create_tensor_batch(self, data): """ - Reshape properly `data` tensor to be processed handle by the graph + Reshape properly ``data`` tensor to be processed handle by the graph based models. :param data: torch.Tensor object of shape (N, ...) where N is the diff --git a/pina/graph.py b/pina/graph.py index 45cd0e6..74dc91f 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -29,7 +29,6 @@ class Graph(Data): :return: A new instance of the :class:`~pina.graph.Graph` class. :rtype: Graph """ - # create class instance instance = Data.__new__(cls) @@ -54,23 +53,20 @@ class Graph(Data): the graph undirected if required. For more details, see the :meth:`torch_geometric.data.Data` - :param x: Optional tensor of node features `(N, F)` where `F` is the + :param x: Optional tensor of node features ``(N, F)`` where ``F`` is the number of features per node. :type x: torch.Tensor, LabelTensor - :param torch.Tensor edge_index: A tensor of shape `(2, E)` representing + :param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing the indices of the graph's edges. - :param pos: A tensor of shape `(N, D)` representing the positions of `N` - points in `D`-dimensional space. + :param pos: A tensor of shape ``(N, D)`` representing the positions of + ``N`` points in ``D``-dimensional space. :type pos: torch.Tensor | LabelTensor - :param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'` - is the number of edge features + :param edge_attr: Optional tensor of edge_featured ``(E, F')`` where + ``F'`` is the number of edge features :param bool undirected: Whether to make the graph undirected :param kwargs: Additional keyword arguments passed to the - `torch_geometric.data.Data` class constructor. If the argument - is a `torch.Tensor` or `LabelTensor`, it is included in the graph - object. + :class:`~torch_geometric.data.Data` class constructor. """ - # preprocessing self._preprocess_edge_index(edge_index, undirected) @@ -86,7 +82,6 @@ class Graph(Data): :param kwargs: Attributes to be checked for consistency. :type kwargs: dict """ - # default types, specified in cls.__new__, by default they are Nont # if specified in **kwargs they get override x, pos, edge_index, edge_attr = None, None, None, None @@ -110,12 +105,9 @@ class Graph(Data): def _check_pos_consistency(pos): """ Check if the position tensor is consistent. - :param torch.Tensor pos: The position tensor. - :raises ValueError: If the position tensor is not consistent. """ - if pos is not None: check_consistency(pos, (torch.Tensor, LabelTensor)) if pos.ndim != 2: @@ -127,10 +119,8 @@ class Graph(Data): Check if the edge index is consistent. :param torch.Tensor edge_index: The edge index tensor. - :raises ValueError: If the edge index tensor is not consistent. """ - check_consistency(edge_index, (torch.Tensor, LabelTensor)) if edge_index.ndim != 2: raise ValueError("edge_index must be a 2D tensor.") @@ -145,11 +135,8 @@ class Graph(Data): :param torch.Tensor edge_attr: The edge attribute tensor. :param torch.Tensor edge_index: The edge index tensor. - :raises ValueError: If the edge attribute tensor is not consistent. - """ - if edge_attr is not None: check_consistency(edge_attr, (torch.Tensor, LabelTensor)) if edge_attr.ndim != 2: @@ -171,7 +158,6 @@ class Graph(Data): :param torch.Tensor x: The input tensor. :param torch.Tensor pos: The position tensor. """ - if x is not None: check_consistency(x, (torch.Tensor, LabelTensor)) if x.ndim != 2: @@ -193,7 +179,6 @@ class Graph(Data): :return: The preprocessed edge index. :rtype: torch.Tensor """ - if undirected: edge_index = to_undirected(edge_index) return edge_index @@ -246,7 +231,7 @@ class GraphBuilder: :type edge_attr: torch.Tensor, optional :param custom_edge_func: A custom function to compute edge attributes. If provided, overrides `edge_attr`. - :type custom_edge_func: callable, optional + :type custom_edge_func: Callable, optional :param kwargs: Additional keyword arguments passed to the :class:`~pina.graph.Graph` class constructor. :return: A :class:`~pina.graph.Graph` instance constructed using the @@ -266,6 +251,18 @@ class GraphBuilder: @staticmethod def _create_edge_attr(pos, edge_index, edge_attr, func): + """ + Create the edge attributes based on the input parameters. + + :param pos: Positions of the points. + :type pos: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: Edge indices. + :param bool edge_attr: Whether to compute the edge attributes. + :param Callable func: Function to compute the edge attributes. + :raises ValueError: If ``func`` is not a function. + :return: The edge attributes. + :rtype: torch.Tensor | LabelTensor | None + """ check_consistency(edge_attr, bool) if edge_attr: if is_function(func): @@ -275,6 +272,14 @@ class GraphBuilder: @staticmethod def _build_edge_attr(pos, edge_index): + """ + Default function to compute the edge attributes. + :param pos: Positions of the points. + :type pos: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: Edge indices. + :return: The edge attributes. + :rtype: torch.Tensor + """ return ( (pos[edge_index[0]] - pos[edge_index[1]]) .abs() @@ -293,37 +298,34 @@ class RadiusGraph(GraphBuilder): edge_index based on a radius. Each point is connected to all the points within the radius. - :param pos: A tensor of shape `(N, D)` representing the positions of `N` - points in `D`-dimensional space. - :type pos: torch.Tensor or LabelTensor - :param radius: The radius within which points are connected. - :type radius: float + :param pos: A tensor of shape ``(N, D)`` representing the positions of + ``N`` points in ``D``-dimensional space. + :type pos: torch.Tensor | LabelTensor + :param float radius: The radius within which points are connected. :param kwargs: Additional keyword arguments to be passed to the :class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph` constructors. :return: A :class:`~pina.graph.Graph` instance containing the input - information and the computed edge_index. + information and the computed ``edge_index``. :rtype: Graph """ - edge_index = cls.compute_radius_graph(pos, radius) return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) @staticmethod def compute_radius_graph(points, radius): """ - Computes `edge_index` for a given set of points base on the radius. + Computes ``edge_index`` for a given set of points base on the radius. Each point is connected to all the points within the radius. - :param points: A tensor of shape `(N, D)` representing the positions of - N points in D-dimensional space. + :param points: A tensor of shape ``(N, D)`` representing the positions + of ``N`` points in ``D``-dimensional space. :type points: torch.Tensor | LabelTensor :param float radius: The number of nearest neighbors to find for each point. - :rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number - of edges, representing the edge indices of the KNN graph. + :rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the + number of edges, representing the edge indices of the KNN graph. """ - dist = torch.cdist(points, points, p=2) return ( torch.nonzero(dist <= radius, as_tuple=False) @@ -342,8 +344,8 @@ class KNNGraph(GraphBuilder): Extends the :class:`~pina.graph.GraphBuilder` class to compute edge_index based on a K-nearest neighbors algorithm. - :param pos: A tensor of shape (N, D) representing the positions of N - points in D-dimensional space. + :param pos: A tensor of shape ``(N, D)`` representing the positions of + ``N`` points in ``D``-dimensional space. :type pos: torch.Tensor | LabelTensor :param int neighbours: The number of nearest neighbors to consider when building the graph. @@ -352,7 +354,7 @@ class KNNGraph(GraphBuilder): and Graph classes :return: A :class:`~pina.graph.Graph` instance containg the - information passed in input and the computed edge_index + information passed in input and the computed ``edge_index`` :rtype: Graph """ @@ -364,11 +366,11 @@ class KNNGraph(GraphBuilder): """ Computes the edge_index based k-nearest neighbors graph algorithm - :param points: A tensor of shape (N, D) representing the positions of - N points in D-dimensional space. + :param points: A tensor of shape ``(N, D)`` representing the positions + of ``N`` points in ``D``-dimensional space. :type points: torch.Tensor | LabelTensor :param int k: The number of nearest neighbors to find for each point. - :return: A tensor of shape (2, E), where E is the number of + :return: A tensor of shape ``(2, E)``, where ``E`` is the number of edges, representing the edge indices of the KNN graph. :rtype: torch.Tensor """ @@ -391,7 +393,8 @@ class LabelBatch(Batch): Create a Batch object from a list of :class:`~torch_geometric.data.Data` objects. - :param data_list: List of Data/Graph objects + :param data_list: List of :class:`~torch_geometric.data.Data` or + :class:`~pina.graph.Graph` objects. :type data_list: list[Data] | list[Graph] :return: A Batch object containing the data in the list :rtype: Batch