From 08de548e346dcafe90d6b53fb278b7b3596f82a7 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 13 Mar 2025 15:49:50 +0100 Subject: [PATCH] Improve doc condition --- pina/condition/condition.py | 96 +++++++++++----------- pina/condition/condition_interface.py | 42 ++++++---- pina/condition/data_condition.py | 8 +- pina/condition/input_equation_condition.py | 8 +- pina/condition/input_target_condition.py | 15 +++- 5 files changed, 103 insertions(+), 66 deletions(-) diff --git a/pina/condition/condition.py b/pina/condition/condition.py index 0e8bd34..4194259 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -1,6 +1,4 @@ -""" -Condition module. -""" +"""Condition module.""" import warnings from .data_condition import DataCondition @@ -15,11 +13,12 @@ warnings.filterwarnings("always", category=DeprecationWarning) def warning_function(new, old): - """ - Handle the deprecation warning. + """Handle the deprecation warning. - :param str new: Object to use instead of the old one. - :param str old: Object to deprecate. + :param new: Object to use instead of the old one. + :type new: str + :param old: Object to deprecate. + :type old: str """ warnings.warn( f"'{old}' is deprecated and will be removed " @@ -30,49 +29,58 @@ def warning_function(new, old): class Condition: """ - The class `Condition` is used to represent the constraints (physical + The class ``Condition`` is used to represent the constraints (physical equations, boundary conditions, etc.) that should be satisfied in the problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object. Conditions can be specified in four ways: - 1. By specifying the input and output points of the condition; in such a + 1. By specifying the input and target of the condition; in such a case, the model is trained to produce the output points given the input - points. Those points can either be torch.Tensor, LabelTensors, Graph + points. Those points can either be torch.Tensor, LabelTensors, Graph. + Based on the type of the input and target, there are different + implementations of the condition. For more details, see + :class:`~pina.condition.input_target_condition.InputTargetCondition`. - 2. By specifying the location and the equation of the condition; in such + 2. By specifying the domain and the equation of the condition; in such a case, the model is trained to minimize the equation residual by - evaluating it at some samples of the location. + evaluating it at some samples of the domain. - 3. By specifying the input points and the equation of the condition; in + 3. By specifying the input and the equation of the condition; in such a case, the model is trained to minimize the equation residual by evaluating it at the passed input points. The input points must be - a LabelTensor. + a LabelTensor. Based on the type of the input, there are different + implementations of the condition. For more details, see + :class:`~pina.condition.input_equation_condition.InputEquationCondition` + . - 4. By specifying only the data matrix; in such a case the model is + 4. By specifying only the input data; in such a case the model is trained with an unsupervised costum loss and uses the data in training. Additionaly conditioning variables can be passed, whenever the model - has extra conditioning variable it depends on. + has extra conditioning variable it depends on. Based on the type of the + input, there are different implementations of the condition. For more + details, see :class:`~pina.condition.data_condition.DataCondition`. Example:: - >>> from pina import Condition - >>> condition = Condition( - ... input=input, - ... target=target - ... ) - >>> condition = Condition( - ... domain=location, - ... equation=equation - ... ) - >>> condition = Condition( - ... input=input, - ... equation=equation - ... ) - >>> condition = Condition( - ... input=data, - ... conditional_variables=conditional_variables - ... ) + >>> from pina import Condition + >>> condition = Condition( + ... input=input, + ... target=target + ... ) + >>> condition = Condition( + ... domain=location, + ... equation=equation + ... ) + >>> condition = Condition( + ... input=input, + ... equation=equation + ... ) + >>> condition = Condition( + ... input=data, + ... conditional_variables=conditional_variables + ... ) + """ __slots__ = list( @@ -86,24 +94,14 @@ class Condition: def __new__(cls, *args, **kwargs): """ - Create a new condition object based on the keyword arguments passed. + Check the input arguments and return the appropriate Condition object. - - `input` and `target`: - :class:`~pina.condition.input_target_condition.InputTargetCondition` - - `domain` and `equation`: - :class:`~pina.condition.domain_equation_condition. - DomainEquationCondition` - - `input` and `equation`: :class:`~pina.condition. - input_equation_condition.InputEquationCondition` - - `input`: :class:`~pina.condition.data_condition.DataCondition` - - `input` and `conditional_variables`: - :class:`~pina.condition.data_condition.DataCondition` - :return: A new condition instance belonging to the proper class. - :rtype: InputTargetCondition | DomainEquationCondition | - InputEquationCondition | DataCondition - - :raises ValueError: No valid condition has been found. + :raises ValueError: If no keyword arguments are passed. + :raises ValueError: If the keyword arguments are invalid. + :return: The appropriate Condition object. + :rtype: ConditionInterface """ + if len(args) != 0: raise ValueError( "Condition takes only the following keyword " diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index 9e5b4df..41dfa7b 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -11,9 +11,15 @@ from ..graph import Graph class ConditionInterface(metaclass=ABCMeta): """ Abstract class which defines a common interface for all the conditions. + It defined a common interface for all the conditions. + """ def __init__(self): + """ + Initialize the ConditionInterface object. + """ + self._problem = None @property @@ -21,10 +27,9 @@ class ConditionInterface(metaclass=ABCMeta): """ Return the problem to which the condition is associated. - :return: Problem to which the condition is associated. + :return: Problem to which the condition is associated :rtype: pina.problem.AbstractProblem """ - return self._problem @problem.setter @@ -32,26 +37,35 @@ class ConditionInterface(metaclass=ABCMeta): """ Set the problem to which the condition is associated. - :param pina.problem.AbstractProblem value: Problem to which the - condition is associated. + :param pina.problem.abstract_problem.AbstractProblem value: Problem to + which the condition is associated """ - self._problem = value @staticmethod def _check_graph_list_consistency(data_list): """ - Check if the list of :class:`~torch_geometric.data.Data` or - class:`pina.graphGraph` objects is consistent. + Check the consistency of the list of Data/Graph objects. It performs + the following checks: - :param data_list: List of graph type objects. - :type data_list: Data | Graph | list[Data] | list[Graph] + 1. All elements in the list must be of the same type (either Data or + Graph). + 2. All elements in the list must have the same keys. + 3. The type of each tensor must be consistent across all elements in + the list. + 4. If the tensor is a LabelTensor, the labels must be consistent across + all elements in the list. - :raises ValueError: Input data must be either Data - or Graph objects. - :raises ValueError: All elements in the list must have the same keys. - :raises ValueError: Type mismatch in data tensors. - :raises ValueError: Label mismatch in LabelTensors. + :param data_list: List of Data/Graph objects to check + :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] + + :raises ValueError: If the input types are invalid. + :raises ValueError: If all elements in the list do not have the same + keys. + :raises ValueError: If the type of each tensor is not consistent across + all elements in the list. + :raises ValueError: If the labels of the LabelTensors are not consistent + across all elements in the list. """ # If the data is a Graph or Data object, return (do not need to check diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 4643b2a..4dd7eb1 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -12,7 +12,13 @@ from ..graph import Graph class DataCondition(ConditionInterface): """ Condition defined by input data and conditional variables. It can be used - in unsupervised learning problems. + in unsupervised learning problems. Based on the type of the input, + different condition implementations are available: + + - :class:`TensorDataCondition`: For :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor` input data. + - :class:`GraphDataCondition`: For :class:`~pina.graph.Graph` or + :class:`~torch_geometric.data.Data` input data. """ __slots__ = ["input", "conditional_variables"] diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index db78a80..3494e1c 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -13,7 +13,13 @@ from ..equation.equation_interface import EquationInterface class InputEquationCondition(ConditionInterface): """ Condition defined by input data and an equation. This condition can be - used in a Physics Informed problems. + used in a Physics Informed problems. Based on the type of the input, + different condition implementations are available: + + - :class:`InputTensorEquationCondition`: For + :class:`~pina.label_tensor.LabelTensor` input data. + - :class:`InputGraphEquationCondition`: For :class:`~pina.graph.Graph` + input data. """ __slots__ = ["input", "equation"] diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 6d4c524..2465038 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -12,7 +12,20 @@ from .condition_interface import ConditionInterface class InputTargetCondition(ConditionInterface): """ Condition defined by input and target data. This condition can be used in - both supervised learning and Physics-informed problems. + both supervised learning and Physics-informed problems. Based on the type of + the input and target, different condition implementations are available: + + - :class:`TensorInputTensorTargetCondition`: For :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor` input and target data. + - :class:`TensorInputGraphTargetCondition`: For :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor` input and + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` + target data. + - :class:`GraphInputTensorTargetCondition`: For :class:`~pina.graph.Graph` + or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` + or :class:`~pina.label_tensor.LabelTensor` target data. + - :class:`GraphInputGraphTargetCondition`: For :class:`~pina.graph.Graph` or + :class:`~torch_geometric.data.Data` input and target data. """ __slots__ = ["input", "target"]