add exhaustive doc for condition module (#629)

This commit is contained in:
Giovanni Canali
2025-09-11 15:47:06 +02:00
committed by GitHub
parent f3ccfd4598
commit a0015c3af6
6 changed files with 366 additions and 246 deletions

View File

@@ -1,100 +1,91 @@
"""Module for the Condition class."""
import warnings
from .data_condition import DataCondition
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputEquationCondition
from .input_target_condition import InputTargetCondition
from ..utils import custom_warning_format
# Set the custom format for warnings
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=DeprecationWarning)
def warning_function(new, old):
"""Handle the deprecation warning.
:param new: Object to use instead of the old one.
:type new: str
:param old: Object to deprecate.
:type old: str
"""
warnings.warn(
f"'{old}' is deprecated and will be removed "
f"in future versions. Please use '{new}' instead.",
DeprecationWarning,
)
class Condition:
"""
Represents constraints (such as physical equations, boundary conditions,
etc.) that must be satisfied in a given problem. Condition objects are used
to formulate the PINA
:class:`~pina.problem.abstract_problem.AbstractProblem` object.
The :class:`Condition` class is a core component of the PINA framework that
provides a unified interface to define heterogeneous constraints that must
be satisfied by a :class:`~pina.problem.abstract_problem.AbstractProblem`.
There are different types of conditions:
It encapsulates all types of constraints - physical, boundary, initial, or
data-driven - that the solver must satisfy during training. The specific
behavior is inferred from the arguments passed to the constructor.
Multiple types of conditions can be used within the same problem, allowing
for a high degree of flexibility in defining complex problems.
The :class:`Condition` class behavior specializes internally based on the
arguments provided during instantiation. Depending on the specified keyword
arguments, the class automatically selects the appropriate internal
implementation.
Available `Condition` types:
- :class:`~pina.condition.input_target_condition.InputTargetCondition`:
Defined by specifying both the input and the target of the condition. In
this case, the model is trained to produce the target given the input. The
input and output data must be one of the :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`,
:class:`~torch_geometric.data.Data`, or :class:`~pina.graph.Graph`.
Different implementations exist depending on the type of input and target.
For more details, see
:class:`~pina.condition.input_target_condition.InputTargetCondition`.
represents a supervised condition defined by both ``input`` and ``target``
data. The model is trained to reproduce the ``target`` values given the
``input``. Supported data types include :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
:class:`~torch_geometric.data.Data`.
The class automatically selects the appropriate implementation based on
the types of ``input`` and ``target``.
- :class:`~pina.condition.domain_equation_condition.DomainEquationCondition`
: Defined by specifying both the domain and the equation of the condition.
Here, the model is trained to minimize the equation residual by evaluating
it at sampled points within the domain.
: represents a general physics-informed condition defined by a ``domain``
and an ``equation``. The model learns to minimize the equation residual
through evaluations performed at points sampled from the specified domain.
- :class:`~pina.condition.input_equation_condition.InputEquationCondition`:
Defined by specifying the input and the equation of the condition. In this
case, the model is trained to minimize the equation residual by evaluating
it at the provided input. The input must be either a
:class:`~pina.label_tensor.LabelTensor` or a :class:`~pina.graph.Graph`.
Different implementations exist depending on the type of input. For more
details, see
:class:`~pina.condition.input_equation_condition.InputEquationCondition`.
represents a general physics-informed condition defined by ``input``
points and an ``equation``. The model learns to minimize the equation
residual through evaluations performed at the provided ``input``.
Supported data types for the ``input`` include
:class:`~pina.label_tensor.LabelTensor` or :class:`~pina.graph.Graph`.
The class automatically selects the appropriate implementation based on
the types of the ``input``.
- :class:`~pina.condition.data_condition.DataCondition`:
Defined by specifying only the input. In this case, the model is trained
with an unsupervised custom loss while using the provided data during
training. The input data must be one of :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`,
:class:`~torch_geometric.data.Data`, or :class:`~pina.graph.Graph`.
Additionally, conditional variables can be provided when the model
depends on extra parameters. These conditional variables must be either
:class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`.
Different implementations exist depending on the type of input.
For more details, see
:class:`~pina.condition.data_condition.DataCondition`.
- :class:`~pina.condition.data_condition.DataCondition`: represents an
unsupervised, data-driven condition defined by the ``input`` only.
The model is trained using a custom unsupervised loss determined by the
chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging the
provided data during training. Optional ``conditional_variables`` can be
specified when the model depends on additional parameters.
Supported data types include :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
:class:`~torch_geometric.data.Data`.
The class automatically selects the appropriate implementation based on
the type of the ``input``.
.. note::
The user should always instantiate :class:`Condition` directly, without
manually creating subclass instances. Please refer to the specific
:class:`Condition` classes for implementation details.
:Example:
>>> from pina import Condition
>>> condition = Condition(
... input=input,
... target=target
... )
>>> condition = Condition(
... domain=location,
... equation=equation
... )
>>> condition = Condition(
... input=input,
... equation=equation
... )
>>> condition = Condition(
... input=data,
... conditional_variables=conditional_variables
... )
>>> # Example of InputTargetCondition signature
>>> condition = Condition(input=input, target=target)
>>> # Example of DomainEquationCondition signature
>>> condition = Condition(domain=domain, equation=equation)
>>> # Example of InputEquationCondition signature
>>> condition = Condition(input=input, equation=equation)
>>> # Example of DataCondition signature
>>> condition = Condition(input=data, conditional_variables=cond_vars)
"""
# Combine all possible keyword arguments from the different Condition types
__slots__ = list(
set(
InputTargetCondition.__slots__
@@ -106,46 +97,45 @@ class Condition:
def __new__(cls, *args, **kwargs):
"""
Instantiate the appropriate Condition object based on the keyword
arguments passed.
Instantiate the appropriate :class:`Condition` object based on the
keyword arguments passed.
:raises ValueError: If no keyword arguments are passed.
:param tuple args: The positional arguments (should be empty).
:param dict kwargs: The keyword arguments corresponding to the
parameters of the specific :class:`Condition` type to instantiate.
:raises ValueError: If unexpected positional arguments are provided.
:raises ValueError: If the keyword arguments are invalid.
:return: The appropriate Condition object.
:return: The appropriate :class:`Condition` object.
:rtype: ConditionInterface
"""
# Check keyword arguments
if len(args) != 0:
raise ValueError(
"Condition takes only the following keyword "
f"arguments: {Condition.__slots__}."
)
# back-compatibility 0.1
keys = list(kwargs.keys())
if "location" in keys:
kwargs["domain"] = kwargs.pop("location")
warning_function(new="domain", old="location")
if "input_points" in keys:
kwargs["input"] = kwargs.pop("input_points")
warning_function(new="input", old="input_points")
if "output_points" in keys:
kwargs["target"] = kwargs.pop("output_points")
warning_function(new="target", old="output_points")
# Class specialization based on keyword arguments
sorted_keys = sorted(kwargs.keys())
# Input - Target Condition
if sorted_keys == sorted(InputTargetCondition.__slots__):
return InputTargetCondition(**kwargs)
# Input - Equation Condition
if sorted_keys == sorted(InputEquationCondition.__slots__):
return InputEquationCondition(**kwargs)
# Domain - Equation Condition
if sorted_keys == sorted(DomainEquationCondition.__slots__):
return DomainEquationCondition(**kwargs)
# Data Condition
if (
sorted_keys == sorted(DataCondition.__slots__)
or sorted_keys[0] == DataCondition.__slots__[0]
):
return DataCondition(**kwargs)
# Invalid keyword arguments
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

View File

@@ -8,24 +8,25 @@ from ..graph import Graph
class ConditionInterface(metaclass=ABCMeta):
"""
Abstract class which defines a common interface for all the conditions.
It defined a common interface for all the conditions.
Abstract base class for PINA conditions. All specific conditions must
inherit from this interface.
Refer to :class:`pina.condition.condition.Condition` for a thorough
description of all available conditions and how to instantiate them.
"""
def __init__(self):
"""
Initialize the ConditionInterface object.
Initialization of the :class:`ConditionInterface` class.
"""
self._problem = None
@property
def problem(self):
"""
Return the problem to which the condition is associated.
Return the problem associated with this condition.
:return: Problem to which the condition is associated.
:return: Problem associated with this condition.
:rtype: ~pina.problem.abstract_problem.AbstractProblem
"""
return self._problem
@@ -33,30 +34,31 @@ class ConditionInterface(metaclass=ABCMeta):
@problem.setter
def problem(self, value):
"""
Set the problem to which the condition is associated.
Set the problem associated with this condition.
:param pina.problem.abstract_problem.AbstractProblem value: Problem to
which the condition is associated
:param pina.problem.abstract_problem.AbstractProblem value: The problem
to associate with this condition
"""
self._problem = value
@staticmethod
def _check_graph_list_consistency(data_list):
"""
Check the consistency of the list of Data/Graph objects. It performs
the following checks:
Check the consistency of the list of Data | Graph objects.
The following checks are performed:
1. All elements in the list must be of the same type (either Data or
Graph).
2. All elements in the list must have the same keys.
3. The type of each tensor must be consistent across all elements in
the list.
4. If the tensor is a LabelTensor, the labels must be consistent across
all elements in the list.
- All elements in the list must be of the same type (either
:class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
:param data_list: List of Data/Graph objects to check
- All elements in the list must have the same keys.
- The data type of each tensor must be consistent across all elements.
- If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
must also be consistent across all elements.
:param data_list: The list of Data | Graph objects to check.
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
:raises ValueError: If the input types are invalid.
:raises ValueError: If all elements in the list do not have the same
keys.
@@ -65,51 +67,45 @@ class ConditionInterface(metaclass=ABCMeta):
:raises ValueError: If the labels of the LabelTensors are not consistent
across all elements in the list.
"""
# If the data is a Graph or Data object, return (do not need to check
# anything)
# If the data is a Graph or Data object, perform no checks
if isinstance(data_list, (Graph, Data)):
return
# check all elements in the list are of the same type
# Check all elements in the list are of the same type
if not all(isinstance(i, (Graph, Data)) for i in data_list):
raise ValueError(
"Invalid input types. "
"Please provide either Data or Graph objects."
"Invalid input. Please, provide either Data or Graph objects."
)
# Store the keys, data types and labels of the first element
data = data_list[0]
# Store the keys of the first element in the list
keys = sorted(list(data.keys()))
# Store the type of each tensor inside first element Data/Graph object
data_types = {name: tensor.__class__ for name, tensor in data.items()}
# Store the labels of each LabelTensor inside first element Data/Graph
# object
labels = {
name: tensor.labels
for name, tensor in data.items()
if isinstance(tensor, LabelTensor)
}
# Iterate over the list of Data/Graph objects
# Iterate over the list of Data | Graph objects
for data in data_list[1:]:
# Check if the keys of the current element are the same as the first
# element
# Check that all elements in the list have the same keys
if sorted(list(data.keys())) != keys:
raise ValueError(
"All elements in the list must have the same keys."
)
# Iterate over the tensors in the current element
for name, tensor in data.items():
# Check if the type of each tensor inside the current element
# is the same as the first element
# Check that the type of each tensor is consistent
if tensor.__class__ is not data_types[name]:
raise ValueError(
f"Data {name} must be a {data_types[name]}, got "
f"{tensor.__class__}"
)
# If the tensor is a LabelTensor, check if the labels are the
# same as the first element
# Check that the labels of each LabelTensor are consistent
if isinstance(tensor, LabelTensor):
if tensor.labels != labels[name]:
raise ValueError(
@@ -117,6 +113,13 @@ class ConditionInterface(metaclass=ABCMeta):
)
def __getattribute__(self, name):
"""
Get an attribute from the object.
:param str name: The name of the attribute to get.
:return: The requested attribute.
:rtype: Any
"""
to_return = super().__getattribute__(name)
if isinstance(to_return, (Graph, Data)):
to_return = [to_return]

View File

@@ -9,16 +9,35 @@ from ..graph import Graph
class DataCondition(ConditionInterface):
"""
Condition defined by input data and conditional variables. It can be used
in unsupervised learning problems. Based on the type of the input,
different condition implementations are available:
The class :class:`DataCondition` defines an unsupervised condition based on
``input`` data. This condition is typically used in data-driven problems,
where the model is trained using a custom unsupervised loss determined by
the chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging
the provided data during training. Optional ``conditional_variables`` can be
specified when the model depends on additional parameters.
- :class:`TensorDataCondition`: For :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` input data.
- :class:`GraphDataCondition`: For :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data` input data.
The class automatically selects the appropriate implementation based on the
type of the ``input`` data. Depending on whether the ``input`` is a tensor
or graph-based data, one of the following specialized subclasses is
instantiated:
- :class:`TensorDataCondition`: For cases where the ``input`` is either a
:class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object.
- :class:`GraphDataCondition`: For cases where the ``input`` is either a
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` object.
:Example:
>>> from pina import Condition, LabelTensor
>>> import torch
>>> pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> cond_vars = LabelTensor(torch.randn(100, 1), labels=["w"])
>>> condition = Condition(input=pts, conditional_variables=cond_vars)
"""
# Available input data types
__slots__ = ["input", "conditional_variables"]
_avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
_avail_conditional_variables_cls = (torch.Tensor, LabelTensor)
@@ -26,33 +45,36 @@ class DataCondition(ConditionInterface):
def __new__(cls, input, conditional_variables=None):
"""
Instantiate the appropriate subclass of :class:`DataCondition` based on
the type of ``input``.
the type of the ``input``.
:param input: Input data for the condition.
:param input: The input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph |
Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor, optional
:return: Subclass of DataCondition.
:param conditional_variables: The conditional variables for the
condition. Default is ``None``.
:type conditional_variables: torch.Tensor | LabelTensor
:return: The subclass of DataCondition.
:rtype: pina.condition.data_condition.TensorDataCondition |
pina.condition.data_condition.GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:raises ValueError: If ``input`` is not of type :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
or :class:`~torch_geometric.data.Data`.
"""
if cls != DataCondition:
return super().__new__(cls)
# If the input is a tensor
if isinstance(input, (torch.Tensor, LabelTensor)):
subclass = TensorDataCondition
return subclass.__new__(subclass, input, conditional_variables)
# If the input is a graph
if isinstance(input, (Graph, Data, list, tuple)):
cls._check_graph_list_consistency(input)
subclass = GraphDataCondition
return subclass.__new__(subclass, input, conditional_variables)
# If the input is not of the correct type raise an error
raise ValueError(
"Invalid input types. "
"Please provide either torch_geometric.data.Data or Graph objects."
@@ -60,21 +82,22 @@ class DataCondition(ConditionInterface):
def __init__(self, input, conditional_variables=None):
"""
Initialize the object by storing the input and conditional
variables (if any).
Initialization of the :class:`DataCondition` class.
:param input: Input data for the condition.
:param input: The input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:param conditional_variables: Conditional variables for the condition.
:param conditional_variables: The conditional variables for the
condition. Default is ``None``.
:type conditional_variables: torch.Tensor | LabelTensor
.. note::
If ``input`` consists of a list of :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data`, all elements must have the same
structure (keys and data types)
"""
If ``input`` is a list of :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data`, all elements in
the list must share the same structure, with matching keys and
consistent data types.
"""
super().__init__()
self.input = input
self.conditional_variables = conditional_variables
@@ -82,13 +105,15 @@ class DataCondition(ConditionInterface):
class TensorDataCondition(DataCondition):
"""
DataCondition for :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` input data
Specialization of the :class:`DataCondition` class for the case where
``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a
:class:`torch.Tensor` object.
"""
class GraphDataCondition(DataCondition):
"""
DataCondition for :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data` input data
Specialization of the :class:`DataCondition` class for the case where
``input`` is either a :class:`~pina.graph.Graph` object or a
:class:`~torch_geometric.data.Data` object.
"""

View File

@@ -8,31 +8,57 @@ from ..equation.equation_interface import EquationInterface
class DomainEquationCondition(ConditionInterface):
"""
Condition defined by a domain and an equation. It can be used in Physics
Informed problems. Before using this condition, make sure that input data
are correctly sampled from the domain.
The class :class:`DomainEquationCondition` defines a condition based on a
``domain`` and an ``equation``. This condition is typically used in
physics-informed problems, where the model is trained to satisfy a given
``equation`` over a specified ``domain``. The ``domain`` is used to sample
points where the ``equation`` residual is evaluated and minimized during
training.
:Example:
>>> from pina.domain import CartesianDomain
>>> from pina.equation import Equation
>>> from pina import Condition
>>> # Equation to be satisfied over the domain: # x^2 + y^2 - 1 = 0
>>> def dummy_equation(pts):
... return pts["x"]**2 + pts["y"]**2 - 1
>>> domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
>>> condition = Condition(domain=domain, equation=Equation(dummy_equation))
"""
# Available slots
__slots__ = ["domain", "equation"]
def __init__(self, domain, equation):
"""
Initialise the object by storing the domain and equation.
Initialization of the :class:`DomainEquationCondition` class.
:param DomainInterface domain: Domain object containing the domain data.
:param EquationInterface equation: Equation object containing the
equation data.
:param DomainInterface domain: The domain over which the equation is
defined.
:param EquationInterface equation: The equation to be satisfied over the
specified domain.
"""
super().__init__()
self.domain = domain
self.equation = equation
def __setattr__(self, key, value):
"""
Set the attribute value with type checking.
:param str key: The attribute name.
:param any value: The value to set for the attribute.
"""
if key == "domain":
check_consistency(value, (DomainInterface, str))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key == "equation":
check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key in ("_problem"):
super().__setattr__(key, value)

View File

@@ -1,6 +1,5 @@
"""Module for the InputEquationCondition class and its subclasses."""
from torch_geometric.data import Data
from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
from ..graph import Graph
@@ -10,16 +9,38 @@ from ..equation.equation_interface import EquationInterface
class InputEquationCondition(ConditionInterface):
"""
Condition defined by input data and an equation. This condition can be
used in a Physics Informed problems. Based on the type of the input,
different condition implementations are available:
The class :class:`InputEquationCondition` defines a condition based on
``input`` data and an ``equation``. This condition is typically used in
physics-informed problems, where the model is trained to satisfy a given
``equation`` through the evaluation of the residual performed at the
provided ``input``.
- :class:`InputTensorEquationCondition`: For \
:class:`~pina.label_tensor.LabelTensor` input data.
- :class:`InputGraphEquationCondition`: For :class:`~pina.graph.Graph` \
input data.
The class automatically selects the appropriate implementation based on
the type of the ``input`` data. Depending on whether the ``input`` is a
tensor or graph-based data, one of the following specialized subclasses is
instantiated:
- :class:`InputTensorEquationCondition`: For cases where the ``input``
data is a :class:`~pina.label_tensor.LabelTensor` object.
- :class:`InputGraphEquationCondition`: For cases where the ``input`` data
is a :class:`~pina.graph.Graph` object.
:Example:
>>> from pina import Condition, LabelTensor
>>> from pina.equation import Equation
>>> import torch
>>> # Equation to be satisfied over the input points: # x^2 + y^2 - 1 = 0
>>> def dummy_equation(pts):
... return pts["x"]**2 + pts["y"]**2 - 1
>>> pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> condition = Condition(input=pts, equation=Equation(dummy_equation))
"""
# Available input data types
__slots__ = ["input", "equation"]
_avail_input_cls = (LabelTensor, Graph, list, tuple)
_avail_equation_cls = EquationInterface
@@ -27,31 +48,31 @@ class InputEquationCondition(ConditionInterface):
def __new__(cls, input, equation):
"""
Instantiate the appropriate subclass of :class:`InputEquationCondition`
based on the type of ``input``.
based on the type of ``input`` data.
:param input: Input data for the condition.
:param input: The input data for the condition.
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
:param EquationInterface equation: Equation object containing the
equation function.
:return: Subclass of InputEquationCondition, based on the input type.
:param EquationInterface equation: The equation to be satisfied over the
specified ``input`` data.
:return: The subclass of InputEquationCondition.
:rtype: pina.condition.input_equation_condition.
InputTensorEquationCondition |
pina.condition.input_equation_condition.InputGraphEquationCondition
:raises ValueError: If input is not of type
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`.
:raises ValueError: If input is not of type :class:`~pina.graph.Graph`
or :class:`~pina.label_tensor.LabelTensor`.
"""
# If the class is already a subclass, return the instance
if cls != InputEquationCondition:
return super().__new__(cls)
# Instanciate the correct subclass
if isinstance(input, (Graph, Data, list, tuple)):
# If the input is a Graph object
if isinstance(input, (Graph, list, tuple)):
subclass = InputGraphEquationCondition
cls._check_graph_list_consistency(input)
subclass._check_label_tensor(input)
return subclass.__new__(subclass, input, equation)
# If the input is a LabelTensor
if isinstance(input, LabelTensor):
subclass = InputTensorEquationCondition
return subclass.__new__(subclass, input, equation)
@@ -63,69 +84,74 @@ class InputEquationCondition(ConditionInterface):
def __init__(self, input, equation):
"""
Initialize the object by storing the input data and equation object.
Initialization of the :class:`InputEquationCondition` class.
:param input: Input data for the condition.
:type input: LabelTensor | Graph |
list[Graph] | tuple[Graph]
:param EquationInterface equation: Equation object containing the
equation function.
:param input: The input data for the condition.
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
:param EquationInterface equation: The equation to be satisfied over the
specified input points.
.. note::
If ``input`` consists of a list of :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data`, all elements must have the same
structure (keys and data types)
"""
If ``input`` is a list of :class:`~pina.graph.Graph` all elements in
the list must share the same structure, with matching keys and
consistent data types.
"""
super().__init__()
self.input = input
self.equation = equation
def __setattr__(self, key, value):
"""
Set the attribute value with type checking.
:param str key: The attribute name.
:param any value: The value to set for the attribute.
"""
if key == "input":
check_consistency(value, self._avail_input_cls)
InputEquationCondition.__dict__[key].__set__(self, value)
elif key == "equation":
check_consistency(value, self._avail_equation_cls)
InputEquationCondition.__dict__[key].__set__(self, value)
elif key in ("_problem"):
super().__setattr__(key, value)
class InputTensorEquationCondition(InputEquationCondition):
"""
InputEquationCondition subclass for :class:`~pina.label_tensor.LabelTensor`
input data.
Specialization of the :class:`InputEquationCondition` class for the case
where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object.
"""
class InputGraphEquationCondition(InputEquationCondition):
"""
InputEquationCondition subclass for :class:`~pina.graph.Graph` input data.
Specialization of the :class:`InputEquationCondition` class for the case
where ``input`` is a :class:`~pina.graph.Graph` object.
"""
@staticmethod
def _check_label_tensor(input):
"""
Check if at least one :class:`~pina.label_tensor.LabelTensor` is present
in the :class:`~pina.graph.Graph` object.
:param input: Input data.
:type input: torch.Tensor | Graph | Data
in the ``input`` object.
:param input: The input data.
:type input: torch.Tensor | Graph | list[Graph] | tuple[Graph]
:raises ValueError: If the input data object does not contain at least
one LabelTensor.
"""
# Store the fist element of the list/tuple if input is a list/tuple
# it is anougth to check the first element because all elements must
# have the same type and structure (already checked)
# Store the first element: it is sufficient to check this since all
# elements must have the same type and structure (already checked).
data = input[0] if isinstance(input, (list, tuple)) else input
# Check if the input data contains at least one LabelTensor
for v in data.values():
if isinstance(v, LabelTensor):
return
raise ValueError(
"The input data object must contain at least one LabelTensor."
)
raise ValueError("The input must contain at least one LabelTensor.")

View File

@@ -11,39 +11,66 @@ from .condition_interface import ConditionInterface
class InputTargetCondition(ConditionInterface):
"""
Condition defined by input and target data. This condition can be used in
both supervised learning and Physics-informed problems. Based on the type of
the input and target, different condition implementations are available:
The :class:`InputTargetCondition` class represents a supervised condition
defined by both ``input`` and ``target`` data. The model is trained to
reproduce the ``target`` values given the ``input``. Supported data types
include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
- :class:`TensorInputTensorTargetCondition`: For :class:`torch.Tensor` or \
:class:`~pina.label_tensor.LabelTensor` input and target data.
- :class:`TensorInputGraphTargetCondition`: For :class:`torch.Tensor` or \
:class:`~pina.label_tensor.LabelTensor` input and \
:class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data` \
target data.
- :class:`GraphInputTensorTargetCondition`: For :class:`~pina.graph.Graph` \
or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` \
or :class:`~pina.label_tensor.LabelTensor` target data.
- :class:`GraphInputGraphTargetCondition`: For :class:`~pina.graph.Graph` \
or :class:`~torch_geometric.data.Data` input and target data.
The class automatically selects the appropriate implementation based on
the types of ``input`` and ``target``. Depending on whether the ``input``
and ``target`` are tensors or graph-based data, one of the following
specialized subclasses is instantiated:
- :class:`TensorInputTensorTargetCondition`: For cases where both ``input``
and ``target`` data are either :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor`.
- :class:`TensorInputGraphTargetCondition`: For cases where ``input`` is
either a :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`
and ``target`` is either a :class:`~pina.graph.Graph` or a
:class:`torch_geometric.data.Data`.
- :class:`GraphInputTensorTargetCondition`: For cases where ``input`` is
either a :class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data`
and ``target`` is either a :class:`torch.Tensor` or a
:class:`~pina.label_tensor.LabelTensor`.
- :class:`GraphInputGraphTargetCondition`: For cases where both ``input``
and ``target`` are either :class:`~pina.graph.Graph` or
:class:`torch_geometric.data.Data`.
:Example:
>>> from pina import Condition, LabelTensor
>>> from pina.graph import Graph
>>> import torch
>>> pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> edge_index = torch.randint(0, 100, (2, 300))
>>> graph = Graph(pos=pos, edge_index=edge_index)
>>> input = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> condition = Condition(input=input, target=graph)
"""
# Available input and target data types
__slots__ = ["input", "target"]
_avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
_avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
def __new__(cls, input, target):
"""
Instantiate the appropriate subclass of InputTargetCondition based on
the types of input and target data.
Instantiate the appropriate subclass of :class:`InputTargetCondition`
based on the types of both ``input`` and ``target`` data.
:param input: Input data for the condition.
:param input: The input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:param target: Target data for the condition.
:param target: The target data for the condition.
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:return: Subclass of InputTargetCondition
:return: The subclass of InputTargetCondition.
:rtype: pina.condition.input_target_condition.
TensorInputTensorTargetCondition |
pina.condition.input_target_condition.
@@ -59,11 +86,14 @@ class InputTargetCondition(ConditionInterface):
if cls != InputTargetCondition:
return super().__new__(cls)
# Tensor - Tensor
if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance(
target, (torch.Tensor, LabelTensor)
):
subclass = TensorInputTensorTargetCondition
return subclass.__new__(subclass, input, target)
# Tensor - Graph
if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance(
target, (Graph, Data, list, tuple)
):
@@ -71,6 +101,7 @@ class InputTargetCondition(ConditionInterface):
subclass = TensorInputGraphTargetCondition
return subclass.__new__(subclass, input, target)
# Graph - Tensor
if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
target, (torch.Tensor, LabelTensor)
):
@@ -78,6 +109,7 @@ class InputTargetCondition(ConditionInterface):
subclass = GraphInputTensorTargetCondition
return subclass.__new__(subclass, input, target)
# Graph - Graph
if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
target, (Graph, Data, list, tuple)
):
@@ -86,30 +118,31 @@ class InputTargetCondition(ConditionInterface):
subclass = GraphInputGraphTargetCondition
return subclass.__new__(subclass, input, target)
# If the input and/or target are not of the correct type raise an error
raise ValueError(
"Invalid input/target types. "
"Invalid input | target types."
"Please provide either torch_geometric.data.Data, Graph, "
"LabelTensor or torch.Tensor objects."
)
def __init__(self, input, target):
"""
Initialize the object by storing the ``input`` and ``target`` data.
Initialization of the :class:`InputTargetCondition` class.
:param input: Input data for the condition.
:param input: The input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:param target: Target data for the condition.
:param target: The target data for the condition.
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
.. note::
If either input or target consists of a list of
:class:~pina.graph.Graph or :class:~torch_geometric.data.Data
objects, all elements must have the same structure (matching
keys and data types).
"""
If either ``input`` or ``target`` is a list of
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
objects, all elements in the list must share the same structure,
with matching keys and consistent data types.
"""
super().__init__()
self._check_input_target_len(input, target)
self.input = input
@@ -117,10 +150,24 @@ class InputTargetCondition(ConditionInterface):
@staticmethod
def _check_input_target_len(input, target):
"""
Check that the length of the input and target lists are the same.
:param input: The input data.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:param target: The target data.
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:raises ValueError: If the lengths of the input and target lists do not
match.
"""
if isinstance(input, (Graph, Data)) or isinstance(
target, (Graph, Data)
):
return
# Raise an error if the lengths of the input and target do not match
if len(input) != len(target):
raise ValueError(
"The input and target lists must have the same length."
@@ -129,30 +176,33 @@ class InputTargetCondition(ConditionInterface):
class TensorInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` ``input`` and ``target`` data.
Specialization of the :class:`InputTargetCondition` class for the case where
both ``input`` and ``target`` are :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` objects.
"""
class TensorInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` ``input`` and
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target`
data.
Specialization of the :class:`InputTargetCondition` class for the case where
``input`` is either a :class:`torch.Tensor` or a
:class:`~pina.label_tensor.LabelTensor` object and ``target`` is either a
:class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object.
"""
class GraphInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
:class:`~torch_geometric.data.Data` ``input`` and :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` ``target`` data.
Specialization of the :class:`InputTargetCondition` class for the case where
``input`` is either a :class:`~pina.graph.Graph` or
:class:`torch_geometric.data.Data` object and ``target`` is either a
:class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object.
"""
class GraphInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
:class:`~torch_geometric.data.Data` ``input`` and ``target`` data.
Specialization of the :class:`InputTargetCondition` class for the case where
both ``input`` and ``target`` are either :class:`~pina.graph.Graph` or
:class:`torch_geometric.data.Data` objects.
"""