From a0015c3af6ef09e228e4f009accb7a5736d07d22 Mon Sep 17 00:00:00 2001 From: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:47:06 +0200 Subject: [PATCH] add exhaustive doc for condition module (#629) --- pina/condition/condition.py | 172 +++++++++----------- pina/condition/condition_interface.py | 83 +++++----- pina/condition/data_condition.py | 79 ++++++--- pina/condition/domain_equation_condition.py | 40 ++++- pina/condition/input_equation_condition.py | 110 ++++++++----- pina/condition/input_target_condition.py | 128 ++++++++++----- 6 files changed, 366 insertions(+), 246 deletions(-) diff --git a/pina/condition/condition.py b/pina/condition/condition.py index 05a377e..ad8764c 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -1,100 +1,91 @@ """Module for the Condition class.""" -import warnings from .data_condition import DataCondition from .domain_equation_condition import DomainEquationCondition from .input_equation_condition import InputEquationCondition from .input_target_condition import InputTargetCondition -from ..utils import custom_warning_format - -# Set the custom format for warnings -warnings.formatwarning = custom_warning_format -warnings.filterwarnings("always", category=DeprecationWarning) - - -def warning_function(new, old): - """Handle the deprecation warning. - - :param new: Object to use instead of the old one. - :type new: str - :param old: Object to deprecate. - :type old: str - """ - warnings.warn( - f"'{old}' is deprecated and will be removed " - f"in future versions. Please use '{new}' instead.", - DeprecationWarning, - ) class Condition: """ - Represents constraints (such as physical equations, boundary conditions, - etc.) that must be satisfied in a given problem. Condition objects are used - to formulate the PINA - :class:`~pina.problem.abstract_problem.AbstractProblem` object. + The :class:`Condition` class is a core component of the PINA framework that + provides a unified interface to define heterogeneous constraints that must + be satisfied by a :class:`~pina.problem.abstract_problem.AbstractProblem`. - There are different types of conditions: + It encapsulates all types of constraints - physical, boundary, initial, or + data-driven - that the solver must satisfy during training. The specific + behavior is inferred from the arguments passed to the constructor. + + Multiple types of conditions can be used within the same problem, allowing + for a high degree of flexibility in defining complex problems. + + The :class:`Condition` class behavior specializes internally based on the + arguments provided during instantiation. Depending on the specified keyword + arguments, the class automatically selects the appropriate internal + implementation. + + + Available `Condition` types: - :class:`~pina.condition.input_target_condition.InputTargetCondition`: - Defined by specifying both the input and the target of the condition. In - this case, the model is trained to produce the target given the input. The - input and output data must be one of the :class:`torch.Tensor`, - :class:`~pina.label_tensor.LabelTensor`, - :class:`~torch_geometric.data.Data`, or :class:`~pina.graph.Graph`. - Different implementations exist depending on the type of input and target. - For more details, see - :class:`~pina.condition.input_target_condition.InputTargetCondition`. + represents a supervised condition defined by both ``input`` and ``target`` + data. The model is trained to reproduce the ``target`` values given the + ``input``. Supported data types include :class:`torch.Tensor`, + :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or + :class:`~torch_geometric.data.Data`. + The class automatically selects the appropriate implementation based on + the types of ``input`` and ``target``. - :class:`~pina.condition.domain_equation_condition.DomainEquationCondition` - : Defined by specifying both the domain and the equation of the condition. - Here, the model is trained to minimize the equation residual by evaluating - it at sampled points within the domain. + : represents a general physics-informed condition defined by a ``domain`` + and an ``equation``. The model learns to minimize the equation residual + through evaluations performed at points sampled from the specified domain. - :class:`~pina.condition.input_equation_condition.InputEquationCondition`: - Defined by specifying the input and the equation of the condition. In this - case, the model is trained to minimize the equation residual by evaluating - it at the provided input. The input must be either a - :class:`~pina.label_tensor.LabelTensor` or a :class:`~pina.graph.Graph`. - Different implementations exist depending on the type of input. For more - details, see - :class:`~pina.condition.input_equation_condition.InputEquationCondition`. + represents a general physics-informed condition defined by ``input`` + points and an ``equation``. The model learns to minimize the equation + residual through evaluations performed at the provided ``input``. + Supported data types for the ``input`` include + :class:`~pina.label_tensor.LabelTensor` or :class:`~pina.graph.Graph`. + The class automatically selects the appropriate implementation based on + the types of the ``input``. - - :class:`~pina.condition.data_condition.DataCondition`: - Defined by specifying only the input. In this case, the model is trained - with an unsupervised custom loss while using the provided data during - training. The input data must be one of :class:`torch.Tensor`, - :class:`~pina.label_tensor.LabelTensor`, - :class:`~torch_geometric.data.Data`, or :class:`~pina.graph.Graph`. - Additionally, conditional variables can be provided when the model - depends on extra parameters. These conditional variables must be either - :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`. - Different implementations exist depending on the type of input. - For more details, see - :class:`~pina.condition.data_condition.DataCondition`. + - :class:`~pina.condition.data_condition.DataCondition`: represents an + unsupervised, data-driven condition defined by the ``input`` only. + The model is trained using a custom unsupervised loss determined by the + chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging the + provided data during training. Optional ``conditional_variables`` can be + specified when the model depends on additional parameters. + Supported data types include :class:`torch.Tensor`, + :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or + :class:`~torch_geometric.data.Data`. + The class automatically selects the appropriate implementation based on + the type of the ``input``. + + .. note:: + + The user should always instantiate :class:`Condition` directly, without + manually creating subclass instances. Please refer to the specific + :class:`Condition` classes for implementation details. :Example: >>> from pina import Condition - >>> condition = Condition( - ... input=input, - ... target=target - ... ) - >>> condition = Condition( - ... domain=location, - ... equation=equation - ... ) - >>> condition = Condition( - ... input=input, - ... equation=equation - ... ) - >>> condition = Condition( - ... input=data, - ... conditional_variables=conditional_variables - ... ) + >>> # Example of InputTargetCondition signature + >>> condition = Condition(input=input, target=target) + + >>> # Example of DomainEquationCondition signature + >>> condition = Condition(domain=domain, equation=equation) + + >>> # Example of InputEquationCondition signature + >>> condition = Condition(input=input, equation=equation) + + >>> # Example of DataCondition signature + >>> condition = Condition(input=data, conditional_variables=cond_vars) """ + # Combine all possible keyword arguments from the different Condition types __slots__ = list( set( InputTargetCondition.__slots__ @@ -106,46 +97,45 @@ class Condition: def __new__(cls, *args, **kwargs): """ - Instantiate the appropriate Condition object based on the keyword - arguments passed. + Instantiate the appropriate :class:`Condition` object based on the + keyword arguments passed. - :raises ValueError: If no keyword arguments are passed. + :param tuple args: The positional arguments (should be empty). + :param dict kwargs: The keyword arguments corresponding to the + parameters of the specific :class:`Condition` type to instantiate. + :raises ValueError: If unexpected positional arguments are provided. :raises ValueError: If the keyword arguments are invalid. - :return: The appropriate Condition object. + :return: The appropriate :class:`Condition` object. :rtype: ConditionInterface """ - + # Check keyword arguments if len(args) != 0: raise ValueError( "Condition takes only the following keyword " f"arguments: {Condition.__slots__}." ) - # back-compatibility 0.1 - keys = list(kwargs.keys()) - if "location" in keys: - kwargs["domain"] = kwargs.pop("location") - warning_function(new="domain", old="location") - - if "input_points" in keys: - kwargs["input"] = kwargs.pop("input_points") - warning_function(new="input", old="input_points") - - if "output_points" in keys: - kwargs["target"] = kwargs.pop("output_points") - warning_function(new="target", old="output_points") - + # Class specialization based on keyword arguments sorted_keys = sorted(kwargs.keys()) + + # Input - Target Condition if sorted_keys == sorted(InputTargetCondition.__slots__): return InputTargetCondition(**kwargs) + + # Input - Equation Condition if sorted_keys == sorted(InputEquationCondition.__slots__): return InputEquationCondition(**kwargs) + + # Domain - Equation Condition if sorted_keys == sorted(DomainEquationCondition.__slots__): return DomainEquationCondition(**kwargs) + + # Data Condition if ( sorted_keys == sorted(DataCondition.__slots__) or sorted_keys[0] == DataCondition.__slots__[0] ): return DataCondition(**kwargs) + # Invalid keyword arguments raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index ee20845..b026451 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -8,24 +8,25 @@ from ..graph import Graph class ConditionInterface(metaclass=ABCMeta): """ - Abstract class which defines a common interface for all the conditions. - It defined a common interface for all the conditions. + Abstract base class for PINA conditions. All specific conditions must + inherit from this interface. + Refer to :class:`pina.condition.condition.Condition` for a thorough + description of all available conditions and how to instantiate them. """ def __init__(self): """ - Initialize the ConditionInterface object. + Initialization of the :class:`ConditionInterface` class. """ - self._problem = None @property def problem(self): """ - Return the problem to which the condition is associated. + Return the problem associated with this condition. - :return: Problem to which the condition is associated. + :return: Problem associated with this condition. :rtype: ~pina.problem.abstract_problem.AbstractProblem """ return self._problem @@ -33,31 +34,32 @@ class ConditionInterface(metaclass=ABCMeta): @problem.setter def problem(self, value): """ - Set the problem to which the condition is associated. + Set the problem associated with this condition. - :param pina.problem.abstract_problem.AbstractProblem value: Problem to - which the condition is associated + :param pina.problem.abstract_problem.AbstractProblem value: The problem + to associate with this condition """ self._problem = value @staticmethod def _check_graph_list_consistency(data_list): """ - Check the consistency of the list of Data/Graph objects. It performs - the following checks: + Check the consistency of the list of Data | Graph objects. + The following checks are performed: - 1. All elements in the list must be of the same type (either Data or - Graph). - 2. All elements in the list must have the same keys. - 3. The type of each tensor must be consistent across all elements in - the list. - 4. If the tensor is a LabelTensor, the labels must be consistent across - all elements in the list. + - All elements in the list must be of the same type (either + :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). - :param data_list: List of Data/Graph objects to check + - All elements in the list must have the same keys. + + - The data type of each tensor must be consistent across all elements. + + - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels + must also be consistent across all elements. + + :param data_list: The list of Data | Graph objects to check. :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] - - :raises ValueError: If the input types are invalid. + :raises ValueError: If the input types are invalid. :raises ValueError: If all elements in the list do not have the same keys. :raises ValueError: If the type of each tensor is not consistent across @@ -65,51 +67,45 @@ class ConditionInterface(metaclass=ABCMeta): :raises ValueError: If the labels of the LabelTensors are not consistent across all elements in the list. """ - - # If the data is a Graph or Data object, return (do not need to check - # anything) + # If the data is a Graph or Data object, perform no checks if isinstance(data_list, (Graph, Data)): return - # check all elements in the list are of the same type + # Check all elements in the list are of the same type if not all(isinstance(i, (Graph, Data)) for i in data_list): raise ValueError( - "Invalid input types. " - "Please provide either Data or Graph objects." + "Invalid input. Please, provide either Data or Graph objects." ) + + # Store the keys, data types and labels of the first element data = data_list[0] - # Store the keys of the first element in the list keys = sorted(list(data.keys())) - - # Store the type of each tensor inside first element Data/Graph object data_types = {name: tensor.__class__ for name, tensor in data.items()} - - # Store the labels of each LabelTensor inside first element Data/Graph - # object labels = { name: tensor.labels for name, tensor in data.items() if isinstance(tensor, LabelTensor) } - # Iterate over the list of Data/Graph objects + # Iterate over the list of Data | Graph objects for data in data_list[1:]: - # Check if the keys of the current element are the same as the first - # element + + # Check that all elements in the list have the same keys if sorted(list(data.keys())) != keys: raise ValueError( "All elements in the list must have the same keys." ) + + # Iterate over the tensors in the current element for name, tensor in data.items(): - # Check if the type of each tensor inside the current element - # is the same as the first element + # Check that the type of each tensor is consistent if tensor.__class__ is not data_types[name]: raise ValueError( f"Data {name} must be a {data_types[name]}, got " f"{tensor.__class__}" ) - # If the tensor is a LabelTensor, check if the labels are the - # same as the first element + + # Check that the labels of each LabelTensor are consistent if isinstance(tensor, LabelTensor): if tensor.labels != labels[name]: raise ValueError( @@ -117,6 +113,13 @@ class ConditionInterface(metaclass=ABCMeta): ) def __getattribute__(self, name): + """ + Get an attribute from the object. + + :param str name: The name of the attribute to get. + :return: The requested attribute. + :rtype: Any + """ to_return = super().__getattribute__(name) if isinstance(to_return, (Graph, Data)): to_return = [to_return] diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 4ecd0ae..e948305 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -9,16 +9,35 @@ from ..graph import Graph class DataCondition(ConditionInterface): """ - Condition defined by input data and conditional variables. It can be used - in unsupervised learning problems. Based on the type of the input, - different condition implementations are available: + The class :class:`DataCondition` defines an unsupervised condition based on + ``input`` data. This condition is typically used in data-driven problems, + where the model is trained using a custom unsupervised loss determined by + the chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging + the provided data during training. Optional ``conditional_variables`` can be + specified when the model depends on additional parameters. - - :class:`TensorDataCondition`: For :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` input data. - - :class:`GraphDataCondition`: For :class:`~pina.graph.Graph` or - :class:`~torch_geometric.data.Data` input data. + The class automatically selects the appropriate implementation based on the + type of the ``input`` data. Depending on whether the ``input`` is a tensor + or graph-based data, one of the following specialized subclasses is + instantiated: + + - :class:`TensorDataCondition`: For cases where the ``input`` is either a + :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. + + - :class:`GraphDataCondition`: For cases where the ``input`` is either a + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` object. + + :Example: + + >>> from pina import Condition, LabelTensor + >>> import torch + + >>> pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + >>> cond_vars = LabelTensor(torch.randn(100, 1), labels=["w"]) + >>> condition = Condition(input=pts, conditional_variables=cond_vars) """ + # Available input data types __slots__ = ["input", "conditional_variables"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) @@ -26,33 +45,36 @@ class DataCondition(ConditionInterface): def __new__(cls, input, conditional_variables=None): """ Instantiate the appropriate subclass of :class:`DataCondition` based on - the type of ``input``. + the type of the ``input``. - :param input: Input data for the condition. + :param input: The input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :param conditional_variables: Conditional variables for the condition. - :type conditional_variables: torch.Tensor | LabelTensor, optional - :return: Subclass of DataCondition. + :param conditional_variables: The conditional variables for the + condition. Default is ``None``. + :type conditional_variables: torch.Tensor | LabelTensor + :return: The subclass of DataCondition. :rtype: pina.condition.data_condition.TensorDataCondition | pina.condition.data_condition.GraphDataCondition - - :raises ValueError: If input is not of type :class:`torch.Tensor`, + :raises ValueError: If ``input`` is not of type :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. """ - if cls != DataCondition: return super().__new__(cls) + + # If the input is a tensor if isinstance(input, (torch.Tensor, LabelTensor)): subclass = TensorDataCondition return subclass.__new__(subclass, input, conditional_variables) + # If the input is a graph if isinstance(input, (Graph, Data, list, tuple)): cls._check_graph_list_consistency(input) subclass = GraphDataCondition return subclass.__new__(subclass, input, conditional_variables) + # If the input is not of the correct type raise an error raise ValueError( "Invalid input types. " "Please provide either torch_geometric.data.Data or Graph objects." @@ -60,21 +82,22 @@ class DataCondition(ConditionInterface): def __init__(self, input, conditional_variables=None): """ - Initialize the object by storing the input and conditional - variables (if any). + Initialization of the :class:`DataCondition` class. - :param input: Input data for the condition. + :param input: The input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :param conditional_variables: Conditional variables for the condition. + :param conditional_variables: The conditional variables for the + condition. Default is ``None``. :type conditional_variables: torch.Tensor | LabelTensor .. note:: - If ``input`` consists of a list of :class:`~pina.graph.Graph` or - :class:`~torch_geometric.data.Data`, all elements must have the same - structure (keys and data types) - """ + If ``input`` is a list of :class:`~pina.graph.Graph` or + :class:`~torch_geometric.data.Data`, all elements in + the list must share the same structure, with matching keys and + consistent data types. + """ super().__init__() self.input = input self.conditional_variables = conditional_variables @@ -82,13 +105,15 @@ class DataCondition(ConditionInterface): class TensorDataCondition(DataCondition): """ - DataCondition for :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` input data + Specialization of the :class:`DataCondition` class for the case where + ``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a + :class:`torch.Tensor` object. """ class GraphDataCondition(DataCondition): """ - DataCondition for :class:`~pina.graph.Graph` or - :class:`~torch_geometric.data.Data` input data + Specialization of the :class:`DataCondition` class for the case where + ``input`` is either a :class:`~pina.graph.Graph` object or a + :class:`~torch_geometric.data.Data` object. """ diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index ee2b507..3565c0b 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -8,31 +8,57 @@ from ..equation.equation_interface import EquationInterface class DomainEquationCondition(ConditionInterface): """ - Condition defined by a domain and an equation. It can be used in Physics - Informed problems. Before using this condition, make sure that input data - are correctly sampled from the domain. + The class :class:`DomainEquationCondition` defines a condition based on a + ``domain`` and an ``equation``. This condition is typically used in + physics-informed problems, where the model is trained to satisfy a given + ``equation`` over a specified ``domain``. The ``domain`` is used to sample + points where the ``equation`` residual is evaluated and minimized during + training. + + :Example: + + >>> from pina.domain import CartesianDomain + >>> from pina.equation import Equation + >>> from pina import Condition + + >>> # Equation to be satisfied over the domain: # x^2 + y^2 - 1 = 0 + >>> def dummy_equation(pts): + ... return pts["x"]**2 + pts["y"]**2 - 1 + + >>> domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) + >>> condition = Condition(domain=domain, equation=Equation(dummy_equation)) """ + # Available slots __slots__ = ["domain", "equation"] def __init__(self, domain, equation): """ - Initialise the object by storing the domain and equation. + Initialization of the :class:`DomainEquationCondition` class. - :param DomainInterface domain: Domain object containing the domain data. - :param EquationInterface equation: Equation object containing the - equation data. + :param DomainInterface domain: The domain over which the equation is + defined. + :param EquationInterface equation: The equation to be satisfied over the + specified domain. """ super().__init__() self.domain = domain self.equation = equation def __setattr__(self, key, value): + """ + Set the attribute value with type checking. + + :param str key: The attribute name. + :param any value: The value to set for the attribute. + """ if key == "domain": check_consistency(value, (DomainInterface, str)) DomainEquationCondition.__dict__[key].__set__(self, value) + elif key == "equation": check_consistency(value, (EquationInterface)) DomainEquationCondition.__dict__[key].__set__(self, value) + elif key in ("_problem"): super().__setattr__(key, value) diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index a803a88..d325978 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,6 +1,5 @@ """Module for the InputEquationCondition class and its subclasses.""" -from torch_geometric.data import Data from .condition_interface import ConditionInterface from ..label_tensor import LabelTensor from ..graph import Graph @@ -10,16 +9,38 @@ from ..equation.equation_interface import EquationInterface class InputEquationCondition(ConditionInterface): """ - Condition defined by input data and an equation. This condition can be - used in a Physics Informed problems. Based on the type of the input, - different condition implementations are available: + The class :class:`InputEquationCondition` defines a condition based on + ``input`` data and an ``equation``. This condition is typically used in + physics-informed problems, where the model is trained to satisfy a given + ``equation`` through the evaluation of the residual performed at the + provided ``input``. - - :class:`InputTensorEquationCondition`: For \ - :class:`~pina.label_tensor.LabelTensor` input data. - - :class:`InputGraphEquationCondition`: For :class:`~pina.graph.Graph` \ - input data. + The class automatically selects the appropriate implementation based on + the type of the ``input`` data. Depending on whether the ``input`` is a + tensor or graph-based data, one of the following specialized subclasses is + instantiated: + + - :class:`InputTensorEquationCondition`: For cases where the ``input`` + data is a :class:`~pina.label_tensor.LabelTensor` object. + + - :class:`InputGraphEquationCondition`: For cases where the ``input`` data + is a :class:`~pina.graph.Graph` object. + + :Example: + + >>> from pina import Condition, LabelTensor + >>> from pina.equation import Equation + >>> import torch + + >>> # Equation to be satisfied over the input points: # x^2 + y^2 - 1 = 0 + >>> def dummy_equation(pts): + ... return pts["x"]**2 + pts["y"]**2 - 1 + + >>> pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + >>> condition = Condition(input=pts, equation=Equation(dummy_equation)) """ + # Available input data types __slots__ = ["input", "equation"] _avail_input_cls = (LabelTensor, Graph, list, tuple) _avail_equation_cls = EquationInterface @@ -27,31 +48,31 @@ class InputEquationCondition(ConditionInterface): def __new__(cls, input, equation): """ Instantiate the appropriate subclass of :class:`InputEquationCondition` - based on the type of ``input``. + based on the type of ``input`` data. - :param input: Input data for the condition. + :param input: The input data for the condition. :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] - :param EquationInterface equation: Equation object containing the - equation function. - :return: Subclass of InputEquationCondition, based on the input type. + :param EquationInterface equation: The equation to be satisfied over the + specified ``input`` data. + :return: The subclass of InputEquationCondition. :rtype: pina.condition.input_equation_condition. InputTensorEquationCondition | pina.condition.input_equation_condition.InputGraphEquationCondition - :raises ValueError: If input is not of type - :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`. + :raises ValueError: If input is not of type :class:`~pina.graph.Graph` + or :class:`~pina.label_tensor.LabelTensor`. """ - - # If the class is already a subclass, return the instance if cls != InputEquationCondition: return super().__new__(cls) - # Instanciate the correct subclass - if isinstance(input, (Graph, Data, list, tuple)): + # If the input is a Graph object + if isinstance(input, (Graph, list, tuple)): subclass = InputGraphEquationCondition cls._check_graph_list_consistency(input) subclass._check_label_tensor(input) return subclass.__new__(subclass, input, equation) + + # If the input is a LabelTensor if isinstance(input, LabelTensor): subclass = InputTensorEquationCondition return subclass.__new__(subclass, input, equation) @@ -63,69 +84,74 @@ class InputEquationCondition(ConditionInterface): def __init__(self, input, equation): """ - Initialize the object by storing the input data and equation object. + Initialization of the :class:`InputEquationCondition` class. - :param input: Input data for the condition. - :type input: LabelTensor | Graph | - list[Graph] | tuple[Graph] - :param EquationInterface equation: Equation object containing the - equation function. + :param input: The input data for the condition. + :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] + :param EquationInterface equation: The equation to be satisfied over the + specified input points. .. note:: - If ``input`` consists of a list of :class:`~pina.graph.Graph` or - :class:`~torch_geometric.data.Data`, all elements must have the same - structure (keys and data types) - """ + If ``input`` is a list of :class:`~pina.graph.Graph` all elements in + the list must share the same structure, with matching keys and + consistent data types. + """ super().__init__() self.input = input self.equation = equation def __setattr__(self, key, value): + """ + Set the attribute value with type checking. + + :param str key: The attribute name. + :param any value: The value to set for the attribute. + """ if key == "input": check_consistency(value, self._avail_input_cls) InputEquationCondition.__dict__[key].__set__(self, value) + elif key == "equation": check_consistency(value, self._avail_equation_cls) InputEquationCondition.__dict__[key].__set__(self, value) + elif key in ("_problem"): super().__setattr__(key, value) class InputTensorEquationCondition(InputEquationCondition): """ - InputEquationCondition subclass for :class:`~pina.label_tensor.LabelTensor` - input data. + Specialization of the :class:`InputEquationCondition` class for the case + where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object. """ class InputGraphEquationCondition(InputEquationCondition): """ - InputEquationCondition subclass for :class:`~pina.graph.Graph` input data. + Specialization of the :class:`InputEquationCondition` class for the case + where ``input`` is a :class:`~pina.graph.Graph` object. """ @staticmethod def _check_label_tensor(input): """ Check if at least one :class:`~pina.label_tensor.LabelTensor` is present - in the :class:`~pina.graph.Graph` object. - - :param input: Input data. - :type input: torch.Tensor | Graph | Data + in the ``input`` object. + :param input: The input data. + :type input: torch.Tensor | Graph | list[Graph] | tuple[Graph] :raises ValueError: If the input data object does not contain at least one LabelTensor. """ - # Store the fist element of the list/tuple if input is a list/tuple - # it is anougth to check the first element because all elements must - # have the same type and structure (already checked) + # Store the first element: it is sufficient to check this since all + # elements must have the same type and structure (already checked). data = input[0] if isinstance(input, (list, tuple)) else input # Check if the input data contains at least one LabelTensor for v in data.values(): if isinstance(v, LabelTensor): return - raise ValueError( - "The input data object must contain at least one LabelTensor." - ) + + raise ValueError("The input must contain at least one LabelTensor.") diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index d39fb28..07b07bb 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -11,39 +11,66 @@ from .condition_interface import ConditionInterface class InputTargetCondition(ConditionInterface): """ - Condition defined by input and target data. This condition can be used in - both supervised learning and Physics-informed problems. Based on the type of - the input and target, different condition implementations are available: + The :class:`InputTargetCondition` class represents a supervised condition + defined by both ``input`` and ``target`` data. The model is trained to + reproduce the ``target`` values given the ``input``. Supported data types + include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, + :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. - - :class:`TensorInputTensorTargetCondition`: For :class:`torch.Tensor` or \ - :class:`~pina.label_tensor.LabelTensor` input and target data. - - :class:`TensorInputGraphTargetCondition`: For :class:`torch.Tensor` or \ - :class:`~pina.label_tensor.LabelTensor` input and \ - :class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data` \ - target data. - - :class:`GraphInputTensorTargetCondition`: For :class:`~pina.graph.Graph` \ - or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` \ - or :class:`~pina.label_tensor.LabelTensor` target data. - - :class:`GraphInputGraphTargetCondition`: For :class:`~pina.graph.Graph` \ - or :class:`~torch_geometric.data.Data` input and target data. + The class automatically selects the appropriate implementation based on + the types of ``input`` and ``target``. Depending on whether the ``input`` + and ``target`` are tensors or graph-based data, one of the following + specialized subclasses is instantiated: + + - :class:`TensorInputTensorTargetCondition`: For cases where both ``input`` + and ``target`` data are either :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor`. + + - :class:`TensorInputGraphTargetCondition`: For cases where ``input`` is + either a :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor` + and ``target`` is either a :class:`~pina.graph.Graph` or a + :class:`torch_geometric.data.Data`. + + - :class:`GraphInputTensorTargetCondition`: For cases where ``input`` is + either a :class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data` + and ``target`` is either a :class:`torch.Tensor` or a + :class:`~pina.label_tensor.LabelTensor`. + + - :class:`GraphInputGraphTargetCondition`: For cases where both ``input`` + and ``target`` are either :class:`~pina.graph.Graph` or + :class:`torch_geometric.data.Data`. + + :Example: + + >>> from pina import Condition, LabelTensor + >>> from pina.graph import Graph + >>> import torch + + >>> pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + >>> edge_index = torch.randint(0, 100, (2, 300)) + >>> graph = Graph(pos=pos, edge_index=edge_index) + + >>> input = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + >>> condition = Condition(input=input, target=graph) """ + # Available input and target data types __slots__ = ["input", "target"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) def __new__(cls, input, target): """ - Instantiate the appropriate subclass of InputTargetCondition based on - the types of input and target data. + Instantiate the appropriate subclass of :class:`InputTargetCondition` + based on the types of both ``input`` and ``target`` data. - :param input: Input data for the condition. + :param input: The input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :param target: Target data for the condition. + :param target: The target data for the condition. :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :return: Subclass of InputTargetCondition + :return: The subclass of InputTargetCondition. :rtype: pina.condition.input_target_condition. TensorInputTensorTargetCondition | pina.condition.input_target_condition. @@ -59,11 +86,14 @@ class InputTargetCondition(ConditionInterface): if cls != InputTargetCondition: return super().__new__(cls) + # Tensor - Tensor if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( target, (torch.Tensor, LabelTensor) ): subclass = TensorInputTensorTargetCondition return subclass.__new__(subclass, input, target) + + # Tensor - Graph if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( target, (Graph, Data, list, tuple) ): @@ -71,6 +101,7 @@ class InputTargetCondition(ConditionInterface): subclass = TensorInputGraphTargetCondition return subclass.__new__(subclass, input, target) + # Graph - Tensor if isinstance(input, (Graph, Data, list, tuple)) and isinstance( target, (torch.Tensor, LabelTensor) ): @@ -78,6 +109,7 @@ class InputTargetCondition(ConditionInterface): subclass = GraphInputTensorTargetCondition return subclass.__new__(subclass, input, target) + # Graph - Graph if isinstance(input, (Graph, Data, list, tuple)) and isinstance( target, (Graph, Data, list, tuple) ): @@ -86,30 +118,31 @@ class InputTargetCondition(ConditionInterface): subclass = GraphInputGraphTargetCondition return subclass.__new__(subclass, input, target) + # If the input and/or target are not of the correct type raise an error raise ValueError( - "Invalid input/target types. " + "Invalid input | target types." "Please provide either torch_geometric.data.Data, Graph, " "LabelTensor or torch.Tensor objects." ) def __init__(self, input, target): """ - Initialize the object by storing the ``input`` and ``target`` data. + Initialization of the :class:`InputTargetCondition` class. - :param input: Input data for the condition. + :param input: The input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :param target: Target data for the condition. + :param target: The target data for the condition. :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] .. note:: - If either input or target consists of a list of - :class:~pina.graph.Graph or :class:~torch_geometric.data.Data - objects, all elements must have the same structure (matching - keys and data types). - """ + If either ``input`` or ``target`` is a list of + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` + objects, all elements in the list must share the same structure, + with matching keys and consistent data types. + """ super().__init__() self._check_input_target_len(input, target) self.input = input @@ -117,10 +150,24 @@ class InputTargetCondition(ConditionInterface): @staticmethod def _check_input_target_len(input, target): + """ + Check that the length of the input and target lists are the same. + + :param input: The input data. + :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + list[Data] | tuple[Graph] | tuple[Data] + :param target: The target data. + :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + list[Data] | tuple[Graph] | tuple[Data] + :raises ValueError: If the lengths of the input and target lists do not + match. + """ if isinstance(input, (Graph, Data)) or isinstance( target, (Graph, Data) ): return + + # Raise an error if the lengths of the input and target do not match if len(input) != len(target): raise ValueError( "The input and target lists must have the same length." @@ -129,30 +176,33 @@ class InputTargetCondition(ConditionInterface): class TensorInputTensorTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` ``input`` and ``target`` data. + Specialization of the :class:`InputTargetCondition` class for the case where + both ``input`` and ``target`` are :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor` objects. """ class TensorInputGraphTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` ``input`` and - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target` - data. + Specialization of the :class:`InputTargetCondition` class for the case where + ``input`` is either a :class:`torch.Tensor` or a + :class:`~pina.label_tensor.LabelTensor` object and ``target`` is either a + :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object. """ class GraphInputTensorTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for :class:`~pina.graph.Graph` o - :class:`~torch_geometric.data.Data` ``input`` and :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` ``target`` data. + Specialization of the :class:`InputTargetCondition` class for the case where + ``input`` is either a :class:`~pina.graph.Graph` or + :class:`torch_geometric.data.Data` object and ``target`` is either a + :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. """ class GraphInputGraphTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for :class:`~pina.graph.Graph`/ - :class:`~torch_geometric.data.Data` ``input`` and ``target`` data. + Specialization of the :class:`InputTargetCondition` class for the case where + both ``input`` and ``target`` are either :class:`~pina.graph.Graph` or + :class:`torch_geometric.data.Data` objects. """