Improve doc condition

This commit is contained in:
FilippoOlivo
2025-03-13 15:49:50 +01:00
parent b9b25e7b4a
commit 08de548e34
5 changed files with 103 additions and 66 deletions

View File

@@ -1,6 +1,4 @@
""" """Condition module."""
Condition module.
"""
import warnings import warnings
from .data_condition import DataCondition from .data_condition import DataCondition
@@ -15,11 +13,12 @@ warnings.filterwarnings("always", category=DeprecationWarning)
def warning_function(new, old): def warning_function(new, old):
""" """Handle the deprecation warning.
Handle the deprecation warning.
:param str new: Object to use instead of the old one. :param new: Object to use instead of the old one.
:param str old: Object to deprecate. :type new: str
:param old: Object to deprecate.
:type old: str
""" """
warnings.warn( warnings.warn(
f"'{old}' is deprecated and will be removed " f"'{old}' is deprecated and will be removed "
@@ -30,49 +29,58 @@ def warning_function(new, old):
class Condition: class Condition:
""" """
The class `Condition` is used to represent the constraints (physical The class ``Condition`` is used to represent the constraints (physical
equations, boundary conditions, etc.) that should be satisfied in the equations, boundary conditions, etc.) that should be satisfied in the
problem at hand. Condition objects are used to formulate the problem at hand. Condition objects are used to formulate the
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object. PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
Conditions can be specified in four ways: Conditions can be specified in four ways:
1. By specifying the input and output points of the condition; in such a 1. By specifying the input and target of the condition; in such a
case, the model is trained to produce the output points given the input case, the model is trained to produce the output points given the input
points. Those points can either be torch.Tensor, LabelTensors, Graph points. Those points can either be torch.Tensor, LabelTensors, Graph.
Based on the type of the input and target, there are different
implementations of the condition. For more details, see
:class:`~pina.condition.input_target_condition.InputTargetCondition`.
2. By specifying the location and the equation of the condition; in such 2. By specifying the domain and the equation of the condition; in such
a case, the model is trained to minimize the equation residual by a case, the model is trained to minimize the equation residual by
evaluating it at some samples of the location. evaluating it at some samples of the domain.
3. By specifying the input points and the equation of the condition; in 3. By specifying the input and the equation of the condition; in
such a case, the model is trained to minimize the equation residual by such a case, the model is trained to minimize the equation residual by
evaluating it at the passed input points. The input points must be evaluating it at the passed input points. The input points must be
a LabelTensor. a LabelTensor. Based on the type of the input, there are different
implementations of the condition. For more details, see
:class:`~pina.condition.input_equation_condition.InputEquationCondition`
.
4. By specifying only the data matrix; in such a case the model is 4. By specifying only the input data; in such a case the model is
trained with an unsupervised costum loss and uses the data in training. trained with an unsupervised costum loss and uses the data in training.
Additionaly conditioning variables can be passed, whenever the model Additionaly conditioning variables can be passed, whenever the model
has extra conditioning variable it depends on. has extra conditioning variable it depends on. Based on the type of the
input, there are different implementations of the condition. For more
details, see :class:`~pina.condition.data_condition.DataCondition`.
Example:: Example::
>>> from pina import Condition >>> from pina import Condition
>>> condition = Condition( >>> condition = Condition(
... input=input, ... input=input,
... target=target ... target=target
... ) ... )
>>> condition = Condition( >>> condition = Condition(
... domain=location, ... domain=location,
... equation=equation ... equation=equation
... ) ... )
>>> condition = Condition( >>> condition = Condition(
... input=input, ... input=input,
... equation=equation ... equation=equation
... ) ... )
>>> condition = Condition( >>> condition = Condition(
... input=data, ... input=data,
... conditional_variables=conditional_variables ... conditional_variables=conditional_variables
... ) ... )
""" """
__slots__ = list( __slots__ = list(
@@ -86,24 +94,14 @@ class Condition:
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
Create a new condition object based on the keyword arguments passed. Check the input arguments and return the appropriate Condition object.
- `input` and `target`: :raises ValueError: If no keyword arguments are passed.
:class:`~pina.condition.input_target_condition.InputTargetCondition` :raises ValueError: If the keyword arguments are invalid.
- `domain` and `equation`: :return: The appropriate Condition object.
:class:`~pina.condition.domain_equation_condition. :rtype: ConditionInterface
DomainEquationCondition`
- `input` and `equation`: :class:`~pina.condition.
input_equation_condition.InputEquationCondition`
- `input`: :class:`~pina.condition.data_condition.DataCondition`
- `input` and `conditional_variables`:
:class:`~pina.condition.data_condition.DataCondition`
:return: A new condition instance belonging to the proper class.
:rtype: InputTargetCondition | DomainEquationCondition |
InputEquationCondition | DataCondition
:raises ValueError: No valid condition has been found.
""" """
if len(args) != 0: if len(args) != 0:
raise ValueError( raise ValueError(
"Condition takes only the following keyword " "Condition takes only the following keyword "

View File

@@ -11,9 +11,15 @@ from ..graph import Graph
class ConditionInterface(metaclass=ABCMeta): class ConditionInterface(metaclass=ABCMeta):
""" """
Abstract class which defines a common interface for all the conditions. Abstract class which defines a common interface for all the conditions.
It defined a common interface for all the conditions.
""" """
def __init__(self): def __init__(self):
"""
Initialize the ConditionInterface object.
"""
self._problem = None self._problem = None
@property @property
@@ -21,10 +27,9 @@ class ConditionInterface(metaclass=ABCMeta):
""" """
Return the problem to which the condition is associated. Return the problem to which the condition is associated.
:return: Problem to which the condition is associated. :return: Problem to which the condition is associated
:rtype: pina.problem.AbstractProblem :rtype: pina.problem.AbstractProblem
""" """
return self._problem return self._problem
@problem.setter @problem.setter
@@ -32,26 +37,35 @@ class ConditionInterface(metaclass=ABCMeta):
""" """
Set the problem to which the condition is associated. Set the problem to which the condition is associated.
:param pina.problem.AbstractProblem value: Problem to which the :param pina.problem.abstract_problem.AbstractProblem value: Problem to
condition is associated. which the condition is associated
""" """
self._problem = value self._problem = value
@staticmethod @staticmethod
def _check_graph_list_consistency(data_list): def _check_graph_list_consistency(data_list):
""" """
Check if the list of :class:`~torch_geometric.data.Data` or Check the consistency of the list of Data/Graph objects. It performs
class:`pina.graphGraph` objects is consistent. the following checks:
:param data_list: List of graph type objects. 1. All elements in the list must be of the same type (either Data or
:type data_list: Data | Graph | list[Data] | list[Graph] Graph).
2. All elements in the list must have the same keys.
3. The type of each tensor must be consistent across all elements in
the list.
4. If the tensor is a LabelTensor, the labels must be consistent across
all elements in the list.
:raises ValueError: Input data must be either Data :param data_list: List of Data/Graph objects to check
or Graph objects. :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
:raises ValueError: All elements in the list must have the same keys.
:raises ValueError: Type mismatch in data tensors. :raises ValueError: If the input types are invalid.
:raises ValueError: Label mismatch in LabelTensors. :raises ValueError: If all elements in the list do not have the same
keys.
:raises ValueError: If the type of each tensor is not consistent across
all elements in the list.
:raises ValueError: If the labels of the LabelTensors are not consistent
across all elements in the list.
""" """
# If the data is a Graph or Data object, return (do not need to check # If the data is a Graph or Data object, return (do not need to check

View File

@@ -12,7 +12,13 @@ from ..graph import Graph
class DataCondition(ConditionInterface): class DataCondition(ConditionInterface):
""" """
Condition defined by input data and conditional variables. It can be used Condition defined by input data and conditional variables. It can be used
in unsupervised learning problems. in unsupervised learning problems. Based on the type of the input,
different condition implementations are available:
- :class:`TensorDataCondition`: For :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` input data.
- :class:`GraphDataCondition`: For :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data` input data.
""" """
__slots__ = ["input", "conditional_variables"] __slots__ = ["input", "conditional_variables"]

View File

@@ -13,7 +13,13 @@ from ..equation.equation_interface import EquationInterface
class InputEquationCondition(ConditionInterface): class InputEquationCondition(ConditionInterface):
""" """
Condition defined by input data and an equation. This condition can be Condition defined by input data and an equation. This condition can be
used in a Physics Informed problems. used in a Physics Informed problems. Based on the type of the input,
different condition implementations are available:
- :class:`InputTensorEquationCondition`: For
:class:`~pina.label_tensor.LabelTensor` input data.
- :class:`InputGraphEquationCondition`: For :class:`~pina.graph.Graph`
input data.
""" """
__slots__ = ["input", "equation"] __slots__ = ["input", "equation"]

View File

@@ -12,7 +12,20 @@ from .condition_interface import ConditionInterface
class InputTargetCondition(ConditionInterface): class InputTargetCondition(ConditionInterface):
""" """
Condition defined by input and target data. This condition can be used in Condition defined by input and target data. This condition can be used in
both supervised learning and Physics-informed problems. both supervised learning and Physics-informed problems. Based on the type of
the input and target, different condition implementations are available:
- :class:`TensorInputTensorTargetCondition`: For :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` input and target data.
- :class:`TensorInputGraphTargetCondition`: For :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` input and
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
target data.
- :class:`GraphInputTensorTargetCondition`: For :class:`~pina.graph.Graph`
or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor`
or :class:`~pina.label_tensor.LabelTensor` target data.
- :class:`GraphInputGraphTargetCondition`: For :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data` input and target data.
""" """
__slots__ = ["input", "target"] __slots__ = ["input", "target"]