Improve doc condition
This commit is contained in:
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""Condition module."""
|
||||||
Condition module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from .data_condition import DataCondition
|
from .data_condition import DataCondition
|
||||||
@@ -15,11 +13,12 @@ warnings.filterwarnings("always", category=DeprecationWarning)
|
|||||||
|
|
||||||
|
|
||||||
def warning_function(new, old):
|
def warning_function(new, old):
|
||||||
"""
|
"""Handle the deprecation warning.
|
||||||
Handle the deprecation warning.
|
|
||||||
|
|
||||||
:param str new: Object to use instead of the old one.
|
:param new: Object to use instead of the old one.
|
||||||
:param str old: Object to deprecate.
|
:type new: str
|
||||||
|
:param old: Object to deprecate.
|
||||||
|
:type old: str
|
||||||
"""
|
"""
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"'{old}' is deprecated and will be removed "
|
f"'{old}' is deprecated and will be removed "
|
||||||
@@ -30,29 +29,37 @@ def warning_function(new, old):
|
|||||||
|
|
||||||
class Condition:
|
class Condition:
|
||||||
"""
|
"""
|
||||||
The class `Condition` is used to represent the constraints (physical
|
The class ``Condition`` is used to represent the constraints (physical
|
||||||
equations, boundary conditions, etc.) that should be satisfied in the
|
equations, boundary conditions, etc.) that should be satisfied in the
|
||||||
problem at hand. Condition objects are used to formulate the
|
problem at hand. Condition objects are used to formulate the
|
||||||
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
|
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
|
||||||
Conditions can be specified in four ways:
|
Conditions can be specified in four ways:
|
||||||
|
|
||||||
1. By specifying the input and output points of the condition; in such a
|
1. By specifying the input and target of the condition; in such a
|
||||||
case, the model is trained to produce the output points given the input
|
case, the model is trained to produce the output points given the input
|
||||||
points. Those points can either be torch.Tensor, LabelTensors, Graph
|
points. Those points can either be torch.Tensor, LabelTensors, Graph.
|
||||||
|
Based on the type of the input and target, there are different
|
||||||
|
implementations of the condition. For more details, see
|
||||||
|
:class:`~pina.condition.input_target_condition.InputTargetCondition`.
|
||||||
|
|
||||||
2. By specifying the location and the equation of the condition; in such
|
2. By specifying the domain and the equation of the condition; in such
|
||||||
a case, the model is trained to minimize the equation residual by
|
a case, the model is trained to minimize the equation residual by
|
||||||
evaluating it at some samples of the location.
|
evaluating it at some samples of the domain.
|
||||||
|
|
||||||
3. By specifying the input points and the equation of the condition; in
|
3. By specifying the input and the equation of the condition; in
|
||||||
such a case, the model is trained to minimize the equation residual by
|
such a case, the model is trained to minimize the equation residual by
|
||||||
evaluating it at the passed input points. The input points must be
|
evaluating it at the passed input points. The input points must be
|
||||||
a LabelTensor.
|
a LabelTensor. Based on the type of the input, there are different
|
||||||
|
implementations of the condition. For more details, see
|
||||||
|
:class:`~pina.condition.input_equation_condition.InputEquationCondition`
|
||||||
|
.
|
||||||
|
|
||||||
4. By specifying only the data matrix; in such a case the model is
|
4. By specifying only the input data; in such a case the model is
|
||||||
trained with an unsupervised costum loss and uses the data in training.
|
trained with an unsupervised costum loss and uses the data in training.
|
||||||
Additionaly conditioning variables can be passed, whenever the model
|
Additionaly conditioning variables can be passed, whenever the model
|
||||||
has extra conditioning variable it depends on.
|
has extra conditioning variable it depends on. Based on the type of the
|
||||||
|
input, there are different implementations of the condition. For more
|
||||||
|
details, see :class:`~pina.condition.data_condition.DataCondition`.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -73,6 +80,7 @@ class Condition:
|
|||||||
... input=data,
|
... input=data,
|
||||||
... conditional_variables=conditional_variables
|
... conditional_variables=conditional_variables
|
||||||
... )
|
... )
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = list(
|
__slots__ = list(
|
||||||
@@ -86,24 +94,14 @@ class Condition:
|
|||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a new condition object based on the keyword arguments passed.
|
Check the input arguments and return the appropriate Condition object.
|
||||||
|
|
||||||
- `input` and `target`:
|
:raises ValueError: If no keyword arguments are passed.
|
||||||
:class:`~pina.condition.input_target_condition.InputTargetCondition`
|
:raises ValueError: If the keyword arguments are invalid.
|
||||||
- `domain` and `equation`:
|
:return: The appropriate Condition object.
|
||||||
:class:`~pina.condition.domain_equation_condition.
|
:rtype: ConditionInterface
|
||||||
DomainEquationCondition`
|
|
||||||
- `input` and `equation`: :class:`~pina.condition.
|
|
||||||
input_equation_condition.InputEquationCondition`
|
|
||||||
- `input`: :class:`~pina.condition.data_condition.DataCondition`
|
|
||||||
- `input` and `conditional_variables`:
|
|
||||||
:class:`~pina.condition.data_condition.DataCondition`
|
|
||||||
:return: A new condition instance belonging to the proper class.
|
|
||||||
:rtype: InputTargetCondition | DomainEquationCondition |
|
|
||||||
InputEquationCondition | DataCondition
|
|
||||||
|
|
||||||
:raises ValueError: No valid condition has been found.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if len(args) != 0:
|
if len(args) != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Condition takes only the following keyword "
|
"Condition takes only the following keyword "
|
||||||
|
|||||||
@@ -11,9 +11,15 @@ from ..graph import Graph
|
|||||||
class ConditionInterface(metaclass=ABCMeta):
|
class ConditionInterface(metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
Abstract class which defines a common interface for all the conditions.
|
Abstract class which defines a common interface for all the conditions.
|
||||||
|
It defined a common interface for all the conditions.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the ConditionInterface object.
|
||||||
|
"""
|
||||||
|
|
||||||
self._problem = None
|
self._problem = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -21,10 +27,9 @@ class ConditionInterface(metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Return the problem to which the condition is associated.
|
Return the problem to which the condition is associated.
|
||||||
|
|
||||||
:return: Problem to which the condition is associated.
|
:return: Problem to which the condition is associated
|
||||||
:rtype: pina.problem.AbstractProblem
|
:rtype: pina.problem.AbstractProblem
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._problem
|
return self._problem
|
||||||
|
|
||||||
@problem.setter
|
@problem.setter
|
||||||
@@ -32,26 +37,35 @@ class ConditionInterface(metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Set the problem to which the condition is associated.
|
Set the problem to which the condition is associated.
|
||||||
|
|
||||||
:param pina.problem.AbstractProblem value: Problem to which the
|
:param pina.problem.abstract_problem.AbstractProblem value: Problem to
|
||||||
condition is associated.
|
which the condition is associated
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._problem = value
|
self._problem = value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_graph_list_consistency(data_list):
|
def _check_graph_list_consistency(data_list):
|
||||||
"""
|
"""
|
||||||
Check if the list of :class:`~torch_geometric.data.Data` or
|
Check the consistency of the list of Data/Graph objects. It performs
|
||||||
class:`pina.graphGraph` objects is consistent.
|
the following checks:
|
||||||
|
|
||||||
:param data_list: List of graph type objects.
|
1. All elements in the list must be of the same type (either Data or
|
||||||
:type data_list: Data | Graph | list[Data] | list[Graph]
|
Graph).
|
||||||
|
2. All elements in the list must have the same keys.
|
||||||
|
3. The type of each tensor must be consistent across all elements in
|
||||||
|
the list.
|
||||||
|
4. If the tensor is a LabelTensor, the labels must be consistent across
|
||||||
|
all elements in the list.
|
||||||
|
|
||||||
:raises ValueError: Input data must be either Data
|
:param data_list: List of Data/Graph objects to check
|
||||||
or Graph objects.
|
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
|
||||||
:raises ValueError: All elements in the list must have the same keys.
|
|
||||||
:raises ValueError: Type mismatch in data tensors.
|
:raises ValueError: If the input types are invalid.
|
||||||
:raises ValueError: Label mismatch in LabelTensors.
|
:raises ValueError: If all elements in the list do not have the same
|
||||||
|
keys.
|
||||||
|
:raises ValueError: If the type of each tensor is not consistent across
|
||||||
|
all elements in the list.
|
||||||
|
:raises ValueError: If the labels of the LabelTensors are not consistent
|
||||||
|
across all elements in the list.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If the data is a Graph or Data object, return (do not need to check
|
# If the data is a Graph or Data object, return (do not need to check
|
||||||
|
|||||||
@@ -12,7 +12,13 @@ from ..graph import Graph
|
|||||||
class DataCondition(ConditionInterface):
|
class DataCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition defined by input data and conditional variables. It can be used
|
Condition defined by input data and conditional variables. It can be used
|
||||||
in unsupervised learning problems.
|
in unsupervised learning problems. Based on the type of the input,
|
||||||
|
different condition implementations are available:
|
||||||
|
|
||||||
|
- :class:`TensorDataCondition`: For :class:`torch.Tensor` or
|
||||||
|
:class:`~pina.label_tensor.LabelTensor` input data.
|
||||||
|
- :class:`GraphDataCondition`: For :class:`~pina.graph.Graph` or
|
||||||
|
:class:`~torch_geometric.data.Data` input data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["input", "conditional_variables"]
|
__slots__ = ["input", "conditional_variables"]
|
||||||
|
|||||||
@@ -13,7 +13,13 @@ from ..equation.equation_interface import EquationInterface
|
|||||||
class InputEquationCondition(ConditionInterface):
|
class InputEquationCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition defined by input data and an equation. This condition can be
|
Condition defined by input data and an equation. This condition can be
|
||||||
used in a Physics Informed problems.
|
used in a Physics Informed problems. Based on the type of the input,
|
||||||
|
different condition implementations are available:
|
||||||
|
|
||||||
|
- :class:`InputTensorEquationCondition`: For
|
||||||
|
:class:`~pina.label_tensor.LabelTensor` input data.
|
||||||
|
- :class:`InputGraphEquationCondition`: For :class:`~pina.graph.Graph`
|
||||||
|
input data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["input", "equation"]
|
__slots__ = ["input", "equation"]
|
||||||
|
|||||||
@@ -12,7 +12,20 @@ from .condition_interface import ConditionInterface
|
|||||||
class InputTargetCondition(ConditionInterface):
|
class InputTargetCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition defined by input and target data. This condition can be used in
|
Condition defined by input and target data. This condition can be used in
|
||||||
both supervised learning and Physics-informed problems.
|
both supervised learning and Physics-informed problems. Based on the type of
|
||||||
|
the input and target, different condition implementations are available:
|
||||||
|
|
||||||
|
- :class:`TensorInputTensorTargetCondition`: For :class:`torch.Tensor` or
|
||||||
|
:class:`~pina.label_tensor.LabelTensor` input and target data.
|
||||||
|
- :class:`TensorInputGraphTargetCondition`: For :class:`torch.Tensor` or
|
||||||
|
:class:`~pina.label_tensor.LabelTensor` input and
|
||||||
|
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
|
||||||
|
target data.
|
||||||
|
- :class:`GraphInputTensorTargetCondition`: For :class:`~pina.graph.Graph`
|
||||||
|
or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor`
|
||||||
|
or :class:`~pina.label_tensor.LabelTensor` target data.
|
||||||
|
- :class:`GraphInputGraphTargetCondition`: For :class:`~pina.graph.Graph` or
|
||||||
|
:class:`~torch_geometric.data.Data` input and target data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["input", "target"]
|
__slots__ = ["input", "target"]
|
||||||
|
|||||||
Reference in New Issue
Block a user