add exhaustive doc for condition module (#629)

This commit is contained in:
Giovanni Canali
2025-09-11 15:47:06 +02:00
committed by GitHub
parent f3ccfd4598
commit a0015c3af6
6 changed files with 366 additions and 246 deletions

View File

@@ -1,100 +1,91 @@
"""Module for the Condition class.""" """Module for the Condition class."""
import warnings
from .data_condition import DataCondition from .data_condition import DataCondition
from .domain_equation_condition import DomainEquationCondition from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputEquationCondition from .input_equation_condition import InputEquationCondition
from .input_target_condition import InputTargetCondition from .input_target_condition import InputTargetCondition
from ..utils import custom_warning_format
# Set the custom format for warnings
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=DeprecationWarning)
def warning_function(new, old):
"""Handle the deprecation warning.
:param new: Object to use instead of the old one.
:type new: str
:param old: Object to deprecate.
:type old: str
"""
warnings.warn(
f"'{old}' is deprecated and will be removed "
f"in future versions. Please use '{new}' instead.",
DeprecationWarning,
)
class Condition: class Condition:
""" """
Represents constraints (such as physical equations, boundary conditions, The :class:`Condition` class is a core component of the PINA framework that
etc.) that must be satisfied in a given problem. Condition objects are used provides a unified interface to define heterogeneous constraints that must
to formulate the PINA be satisfied by a :class:`~pina.problem.abstract_problem.AbstractProblem`.
:class:`~pina.problem.abstract_problem.AbstractProblem` object.
There are different types of conditions: It encapsulates all types of constraints - physical, boundary, initial, or
data-driven - that the solver must satisfy during training. The specific
behavior is inferred from the arguments passed to the constructor.
Multiple types of conditions can be used within the same problem, allowing
for a high degree of flexibility in defining complex problems.
The :class:`Condition` class behavior specializes internally based on the
arguments provided during instantiation. Depending on the specified keyword
arguments, the class automatically selects the appropriate internal
implementation.
Available `Condition` types:
- :class:`~pina.condition.input_target_condition.InputTargetCondition`: - :class:`~pina.condition.input_target_condition.InputTargetCondition`:
Defined by specifying both the input and the target of the condition. In represents a supervised condition defined by both ``input`` and ``target``
this case, the model is trained to produce the target given the input. The data. The model is trained to reproduce the ``target`` values given the
input and output data must be one of the :class:`torch.Tensor`, ``input``. Supported data types include :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
:class:`~torch_geometric.data.Data`, or :class:`~pina.graph.Graph`. :class:`~torch_geometric.data.Data`.
Different implementations exist depending on the type of input and target. The class automatically selects the appropriate implementation based on
For more details, see the types of ``input`` and ``target``.
:class:`~pina.condition.input_target_condition.InputTargetCondition`.
- :class:`~pina.condition.domain_equation_condition.DomainEquationCondition` - :class:`~pina.condition.domain_equation_condition.DomainEquationCondition`
: Defined by specifying both the domain and the equation of the condition. : represents a general physics-informed condition defined by a ``domain``
Here, the model is trained to minimize the equation residual by evaluating and an ``equation``. The model learns to minimize the equation residual
it at sampled points within the domain. through evaluations performed at points sampled from the specified domain.
- :class:`~pina.condition.input_equation_condition.InputEquationCondition`: - :class:`~pina.condition.input_equation_condition.InputEquationCondition`:
Defined by specifying the input and the equation of the condition. In this represents a general physics-informed condition defined by ``input``
case, the model is trained to minimize the equation residual by evaluating points and an ``equation``. The model learns to minimize the equation
it at the provided input. The input must be either a residual through evaluations performed at the provided ``input``.
:class:`~pina.label_tensor.LabelTensor` or a :class:`~pina.graph.Graph`. Supported data types for the ``input`` include
Different implementations exist depending on the type of input. For more :class:`~pina.label_tensor.LabelTensor` or :class:`~pina.graph.Graph`.
details, see The class automatically selects the appropriate implementation based on
:class:`~pina.condition.input_equation_condition.InputEquationCondition`. the types of the ``input``.
- :class:`~pina.condition.data_condition.DataCondition`: - :class:`~pina.condition.data_condition.DataCondition`: represents an
Defined by specifying only the input. In this case, the model is trained unsupervised, data-driven condition defined by the ``input`` only.
with an unsupervised custom loss while using the provided data during The model is trained using a custom unsupervised loss determined by the
training. The input data must be one of :class:`torch.Tensor`, chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging the
:class:`~pina.label_tensor.LabelTensor`, provided data during training. Optional ``conditional_variables`` can be
:class:`~torch_geometric.data.Data`, or :class:`~pina.graph.Graph`. specified when the model depends on additional parameters.
Additionally, conditional variables can be provided when the model Supported data types include :class:`torch.Tensor`,
depends on extra parameters. These conditional variables must be either :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
:class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`. :class:`~torch_geometric.data.Data`.
Different implementations exist depending on the type of input. The class automatically selects the appropriate implementation based on
For more details, see the type of the ``input``.
:class:`~pina.condition.data_condition.DataCondition`.
.. note::
The user should always instantiate :class:`Condition` directly, without
manually creating subclass instances. Please refer to the specific
:class:`Condition` classes for implementation details.
:Example: :Example:
>>> from pina import Condition >>> from pina import Condition
>>> condition = Condition(
... input=input,
... target=target
... )
>>> condition = Condition(
... domain=location,
... equation=equation
... )
>>> condition = Condition(
... input=input,
... equation=equation
... )
>>> condition = Condition(
... input=data,
... conditional_variables=conditional_variables
... )
>>> # Example of InputTargetCondition signature
>>> condition = Condition(input=input, target=target)
>>> # Example of DomainEquationCondition signature
>>> condition = Condition(domain=domain, equation=equation)
>>> # Example of InputEquationCondition signature
>>> condition = Condition(input=input, equation=equation)
>>> # Example of DataCondition signature
>>> condition = Condition(input=data, conditional_variables=cond_vars)
""" """
# Combine all possible keyword arguments from the different Condition types
__slots__ = list( __slots__ = list(
set( set(
InputTargetCondition.__slots__ InputTargetCondition.__slots__
@@ -106,46 +97,45 @@ class Condition:
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
Instantiate the appropriate Condition object based on the keyword Instantiate the appropriate :class:`Condition` object based on the
arguments passed. keyword arguments passed.
:raises ValueError: If no keyword arguments are passed. :param tuple args: The positional arguments (should be empty).
:param dict kwargs: The keyword arguments corresponding to the
parameters of the specific :class:`Condition` type to instantiate.
:raises ValueError: If unexpected positional arguments are provided.
:raises ValueError: If the keyword arguments are invalid. :raises ValueError: If the keyword arguments are invalid.
:return: The appropriate Condition object. :return: The appropriate :class:`Condition` object.
:rtype: ConditionInterface :rtype: ConditionInterface
""" """
# Check keyword arguments
if len(args) != 0: if len(args) != 0:
raise ValueError( raise ValueError(
"Condition takes only the following keyword " "Condition takes only the following keyword "
f"arguments: {Condition.__slots__}." f"arguments: {Condition.__slots__}."
) )
# back-compatibility 0.1 # Class specialization based on keyword arguments
keys = list(kwargs.keys())
if "location" in keys:
kwargs["domain"] = kwargs.pop("location")
warning_function(new="domain", old="location")
if "input_points" in keys:
kwargs["input"] = kwargs.pop("input_points")
warning_function(new="input", old="input_points")
if "output_points" in keys:
kwargs["target"] = kwargs.pop("output_points")
warning_function(new="target", old="output_points")
sorted_keys = sorted(kwargs.keys()) sorted_keys = sorted(kwargs.keys())
# Input - Target Condition
if sorted_keys == sorted(InputTargetCondition.__slots__): if sorted_keys == sorted(InputTargetCondition.__slots__):
return InputTargetCondition(**kwargs) return InputTargetCondition(**kwargs)
# Input - Equation Condition
if sorted_keys == sorted(InputEquationCondition.__slots__): if sorted_keys == sorted(InputEquationCondition.__slots__):
return InputEquationCondition(**kwargs) return InputEquationCondition(**kwargs)
# Domain - Equation Condition
if sorted_keys == sorted(DomainEquationCondition.__slots__): if sorted_keys == sorted(DomainEquationCondition.__slots__):
return DomainEquationCondition(**kwargs) return DomainEquationCondition(**kwargs)
# Data Condition
if ( if (
sorted_keys == sorted(DataCondition.__slots__) sorted_keys == sorted(DataCondition.__slots__)
or sorted_keys[0] == DataCondition.__slots__[0] or sorted_keys[0] == DataCondition.__slots__[0]
): ):
return DataCondition(**kwargs) return DataCondition(**kwargs)
# Invalid keyword arguments
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

View File

@@ -8,24 +8,25 @@ from ..graph import Graph
class ConditionInterface(metaclass=ABCMeta): class ConditionInterface(metaclass=ABCMeta):
""" """
Abstract class which defines a common interface for all the conditions. Abstract base class for PINA conditions. All specific conditions must
It defined a common interface for all the conditions. inherit from this interface.
Refer to :class:`pina.condition.condition.Condition` for a thorough
description of all available conditions and how to instantiate them.
""" """
def __init__(self): def __init__(self):
""" """
Initialize the ConditionInterface object. Initialization of the :class:`ConditionInterface` class.
""" """
self._problem = None self._problem = None
@property @property
def problem(self): def problem(self):
""" """
Return the problem to which the condition is associated. Return the problem associated with this condition.
:return: Problem to which the condition is associated. :return: Problem associated with this condition.
:rtype: ~pina.problem.abstract_problem.AbstractProblem :rtype: ~pina.problem.abstract_problem.AbstractProblem
""" """
return self._problem return self._problem
@@ -33,30 +34,31 @@ class ConditionInterface(metaclass=ABCMeta):
@problem.setter @problem.setter
def problem(self, value): def problem(self, value):
""" """
Set the problem to which the condition is associated. Set the problem associated with this condition.
:param pina.problem.abstract_problem.AbstractProblem value: Problem to :param pina.problem.abstract_problem.AbstractProblem value: The problem
which the condition is associated to associate with this condition
""" """
self._problem = value self._problem = value
@staticmethod @staticmethod
def _check_graph_list_consistency(data_list): def _check_graph_list_consistency(data_list):
""" """
Check the consistency of the list of Data/Graph objects. It performs Check the consistency of the list of Data | Graph objects.
the following checks: The following checks are performed:
1. All elements in the list must be of the same type (either Data or - All elements in the list must be of the same type (either
Graph). :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
2. All elements in the list must have the same keys.
3. The type of each tensor must be consistent across all elements in
the list.
4. If the tensor is a LabelTensor, the labels must be consistent across
all elements in the list.
:param data_list: List of Data/Graph objects to check - All elements in the list must have the same keys.
- The data type of each tensor must be consistent across all elements.
- If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
must also be consistent across all elements.
:param data_list: The list of Data | Graph objects to check.
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
:raises ValueError: If the input types are invalid. :raises ValueError: If the input types are invalid.
:raises ValueError: If all elements in the list do not have the same :raises ValueError: If all elements in the list do not have the same
keys. keys.
@@ -65,51 +67,45 @@ class ConditionInterface(metaclass=ABCMeta):
:raises ValueError: If the labels of the LabelTensors are not consistent :raises ValueError: If the labels of the LabelTensors are not consistent
across all elements in the list. across all elements in the list.
""" """
# If the data is a Graph or Data object, perform no checks
# If the data is a Graph or Data object, return (do not need to check
# anything)
if isinstance(data_list, (Graph, Data)): if isinstance(data_list, (Graph, Data)):
return return
# check all elements in the list are of the same type # Check all elements in the list are of the same type
if not all(isinstance(i, (Graph, Data)) for i in data_list): if not all(isinstance(i, (Graph, Data)) for i in data_list):
raise ValueError( raise ValueError(
"Invalid input types. " "Invalid input. Please, provide either Data or Graph objects."
"Please provide either Data or Graph objects."
) )
# Store the keys, data types and labels of the first element
data = data_list[0] data = data_list[0]
# Store the keys of the first element in the list
keys = sorted(list(data.keys())) keys = sorted(list(data.keys()))
# Store the type of each tensor inside first element Data/Graph object
data_types = {name: tensor.__class__ for name, tensor in data.items()} data_types = {name: tensor.__class__ for name, tensor in data.items()}
# Store the labels of each LabelTensor inside first element Data/Graph
# object
labels = { labels = {
name: tensor.labels name: tensor.labels
for name, tensor in data.items() for name, tensor in data.items()
if isinstance(tensor, LabelTensor) if isinstance(tensor, LabelTensor)
} }
# Iterate over the list of Data/Graph objects # Iterate over the list of Data | Graph objects
for data in data_list[1:]: for data in data_list[1:]:
# Check if the keys of the current element are the same as the first
# element # Check that all elements in the list have the same keys
if sorted(list(data.keys())) != keys: if sorted(list(data.keys())) != keys:
raise ValueError( raise ValueError(
"All elements in the list must have the same keys." "All elements in the list must have the same keys."
) )
# Iterate over the tensors in the current element
for name, tensor in data.items(): for name, tensor in data.items():
# Check if the type of each tensor inside the current element # Check that the type of each tensor is consistent
# is the same as the first element
if tensor.__class__ is not data_types[name]: if tensor.__class__ is not data_types[name]:
raise ValueError( raise ValueError(
f"Data {name} must be a {data_types[name]}, got " f"Data {name} must be a {data_types[name]}, got "
f"{tensor.__class__}" f"{tensor.__class__}"
) )
# If the tensor is a LabelTensor, check if the labels are the
# same as the first element # Check that the labels of each LabelTensor are consistent
if isinstance(tensor, LabelTensor): if isinstance(tensor, LabelTensor):
if tensor.labels != labels[name]: if tensor.labels != labels[name]:
raise ValueError( raise ValueError(
@@ -117,6 +113,13 @@ class ConditionInterface(metaclass=ABCMeta):
) )
def __getattribute__(self, name): def __getattribute__(self, name):
"""
Get an attribute from the object.
:param str name: The name of the attribute to get.
:return: The requested attribute.
:rtype: Any
"""
to_return = super().__getattribute__(name) to_return = super().__getattribute__(name)
if isinstance(to_return, (Graph, Data)): if isinstance(to_return, (Graph, Data)):
to_return = [to_return] to_return = [to_return]

View File

@@ -9,16 +9,35 @@ from ..graph import Graph
class DataCondition(ConditionInterface): class DataCondition(ConditionInterface):
""" """
Condition defined by input data and conditional variables. It can be used The class :class:`DataCondition` defines an unsupervised condition based on
in unsupervised learning problems. Based on the type of the input, ``input`` data. This condition is typically used in data-driven problems,
different condition implementations are available: where the model is trained using a custom unsupervised loss determined by
the chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging
the provided data during training. Optional ``conditional_variables`` can be
specified when the model depends on additional parameters.
- :class:`TensorDataCondition`: For :class:`torch.Tensor` or The class automatically selects the appropriate implementation based on the
:class:`~pina.label_tensor.LabelTensor` input data. type of the ``input`` data. Depending on whether the ``input`` is a tensor
- :class:`GraphDataCondition`: For :class:`~pina.graph.Graph` or or graph-based data, one of the following specialized subclasses is
:class:`~torch_geometric.data.Data` input data. instantiated:
- :class:`TensorDataCondition`: For cases where the ``input`` is either a
:class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object.
- :class:`GraphDataCondition`: For cases where the ``input`` is either a
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` object.
:Example:
>>> from pina import Condition, LabelTensor
>>> import torch
>>> pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> cond_vars = LabelTensor(torch.randn(100, 1), labels=["w"])
>>> condition = Condition(input=pts, conditional_variables=cond_vars)
""" """
# Available input data types
__slots__ = ["input", "conditional_variables"] __slots__ = ["input", "conditional_variables"]
_avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
_avail_conditional_variables_cls = (torch.Tensor, LabelTensor) _avail_conditional_variables_cls = (torch.Tensor, LabelTensor)
@@ -26,33 +45,36 @@ class DataCondition(ConditionInterface):
def __new__(cls, input, conditional_variables=None): def __new__(cls, input, conditional_variables=None):
""" """
Instantiate the appropriate subclass of :class:`DataCondition` based on Instantiate the appropriate subclass of :class:`DataCondition` based on
the type of ``input``. the type of the ``input``.
:param input: Input data for the condition. :param input: The input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | :type input: torch.Tensor | LabelTensor | Graph |
Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
:param conditional_variables: Conditional variables for the condition. :param conditional_variables: The conditional variables for the
:type conditional_variables: torch.Tensor | LabelTensor, optional condition. Default is ``None``.
:return: Subclass of DataCondition. :type conditional_variables: torch.Tensor | LabelTensor
:return: The subclass of DataCondition.
:rtype: pina.condition.data_condition.TensorDataCondition | :rtype: pina.condition.data_condition.TensorDataCondition |
pina.condition.data_condition.GraphDataCondition pina.condition.data_condition.GraphDataCondition
:raises ValueError: If ``input`` is not of type :class:`torch.Tensor`,
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
or :class:`~torch_geometric.data.Data`. or :class:`~torch_geometric.data.Data`.
""" """
if cls != DataCondition: if cls != DataCondition:
return super().__new__(cls) return super().__new__(cls)
# If the input is a tensor
if isinstance(input, (torch.Tensor, LabelTensor)): if isinstance(input, (torch.Tensor, LabelTensor)):
subclass = TensorDataCondition subclass = TensorDataCondition
return subclass.__new__(subclass, input, conditional_variables) return subclass.__new__(subclass, input, conditional_variables)
# If the input is a graph
if isinstance(input, (Graph, Data, list, tuple)): if isinstance(input, (Graph, Data, list, tuple)):
cls._check_graph_list_consistency(input) cls._check_graph_list_consistency(input)
subclass = GraphDataCondition subclass = GraphDataCondition
return subclass.__new__(subclass, input, conditional_variables) return subclass.__new__(subclass, input, conditional_variables)
# If the input is not of the correct type raise an error
raise ValueError( raise ValueError(
"Invalid input types. " "Invalid input types. "
"Please provide either torch_geometric.data.Data or Graph objects." "Please provide either torch_geometric.data.Data or Graph objects."
@@ -60,21 +82,22 @@ class DataCondition(ConditionInterface):
def __init__(self, input, conditional_variables=None): def __init__(self, input, conditional_variables=None):
""" """
Initialize the object by storing the input and conditional Initialization of the :class:`DataCondition` class.
variables (if any).
:param input: Input data for the condition. :param input: The input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data] list[Data] | tuple[Graph] | tuple[Data]
:param conditional_variables: Conditional variables for the condition. :param conditional_variables: The conditional variables for the
condition. Default is ``None``.
:type conditional_variables: torch.Tensor | LabelTensor :type conditional_variables: torch.Tensor | LabelTensor
.. note:: .. note::
If ``input`` consists of a list of :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data`, all elements must have the same
structure (keys and data types)
"""
If ``input`` is a list of :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data`, all elements in
the list must share the same structure, with matching keys and
consistent data types.
"""
super().__init__() super().__init__()
self.input = input self.input = input
self.conditional_variables = conditional_variables self.conditional_variables = conditional_variables
@@ -82,13 +105,15 @@ class DataCondition(ConditionInterface):
class TensorDataCondition(DataCondition): class TensorDataCondition(DataCondition):
""" """
DataCondition for :class:`torch.Tensor` or Specialization of the :class:`DataCondition` class for the case where
:class:`~pina.label_tensor.LabelTensor` input data ``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a
:class:`torch.Tensor` object.
""" """
class GraphDataCondition(DataCondition): class GraphDataCondition(DataCondition):
""" """
DataCondition for :class:`~pina.graph.Graph` or Specialization of the :class:`DataCondition` class for the case where
:class:`~torch_geometric.data.Data` input data ``input`` is either a :class:`~pina.graph.Graph` object or a
:class:`~torch_geometric.data.Data` object.
""" """

View File

@@ -8,31 +8,57 @@ from ..equation.equation_interface import EquationInterface
class DomainEquationCondition(ConditionInterface): class DomainEquationCondition(ConditionInterface):
""" """
Condition defined by a domain and an equation. It can be used in Physics The class :class:`DomainEquationCondition` defines a condition based on a
Informed problems. Before using this condition, make sure that input data ``domain`` and an ``equation``. This condition is typically used in
are correctly sampled from the domain. physics-informed problems, where the model is trained to satisfy a given
``equation`` over a specified ``domain``. The ``domain`` is used to sample
points where the ``equation`` residual is evaluated and minimized during
training.
:Example:
>>> from pina.domain import CartesianDomain
>>> from pina.equation import Equation
>>> from pina import Condition
>>> # Equation to be satisfied over the domain: # x^2 + y^2 - 1 = 0
>>> def dummy_equation(pts):
... return pts["x"]**2 + pts["y"]**2 - 1
>>> domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
>>> condition = Condition(domain=domain, equation=Equation(dummy_equation))
""" """
# Available slots
__slots__ = ["domain", "equation"] __slots__ = ["domain", "equation"]
def __init__(self, domain, equation): def __init__(self, domain, equation):
""" """
Initialise the object by storing the domain and equation. Initialization of the :class:`DomainEquationCondition` class.
:param DomainInterface domain: Domain object containing the domain data. :param DomainInterface domain: The domain over which the equation is
:param EquationInterface equation: Equation object containing the defined.
equation data. :param EquationInterface equation: The equation to be satisfied over the
specified domain.
""" """
super().__init__() super().__init__()
self.domain = domain self.domain = domain
self.equation = equation self.equation = equation
def __setattr__(self, key, value): def __setattr__(self, key, value):
"""
Set the attribute value with type checking.
:param str key: The attribute name.
:param any value: The value to set for the attribute.
"""
if key == "domain": if key == "domain":
check_consistency(value, (DomainInterface, str)) check_consistency(value, (DomainInterface, str))
DomainEquationCondition.__dict__[key].__set__(self, value) DomainEquationCondition.__dict__[key].__set__(self, value)
elif key == "equation": elif key == "equation":
check_consistency(value, (EquationInterface)) check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value) DomainEquationCondition.__dict__[key].__set__(self, value)
elif key in ("_problem"): elif key in ("_problem"):
super().__setattr__(key, value) super().__setattr__(key, value)

View File

@@ -1,6 +1,5 @@
"""Module for the InputEquationCondition class and its subclasses.""" """Module for the InputEquationCondition class and its subclasses."""
from torch_geometric.data import Data
from .condition_interface import ConditionInterface from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor from ..label_tensor import LabelTensor
from ..graph import Graph from ..graph import Graph
@@ -10,16 +9,38 @@ from ..equation.equation_interface import EquationInterface
class InputEquationCondition(ConditionInterface): class InputEquationCondition(ConditionInterface):
""" """
Condition defined by input data and an equation. This condition can be The class :class:`InputEquationCondition` defines a condition based on
used in a Physics Informed problems. Based on the type of the input, ``input`` data and an ``equation``. This condition is typically used in
different condition implementations are available: physics-informed problems, where the model is trained to satisfy a given
``equation`` through the evaluation of the residual performed at the
provided ``input``.
- :class:`InputTensorEquationCondition`: For \ The class automatically selects the appropriate implementation based on
:class:`~pina.label_tensor.LabelTensor` input data. the type of the ``input`` data. Depending on whether the ``input`` is a
- :class:`InputGraphEquationCondition`: For :class:`~pina.graph.Graph` \ tensor or graph-based data, one of the following specialized subclasses is
input data. instantiated:
- :class:`InputTensorEquationCondition`: For cases where the ``input``
data is a :class:`~pina.label_tensor.LabelTensor` object.
- :class:`InputGraphEquationCondition`: For cases where the ``input`` data
is a :class:`~pina.graph.Graph` object.
:Example:
>>> from pina import Condition, LabelTensor
>>> from pina.equation import Equation
>>> import torch
>>> # Equation to be satisfied over the input points: # x^2 + y^2 - 1 = 0
>>> def dummy_equation(pts):
... return pts["x"]**2 + pts["y"]**2 - 1
>>> pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> condition = Condition(input=pts, equation=Equation(dummy_equation))
""" """
# Available input data types
__slots__ = ["input", "equation"] __slots__ = ["input", "equation"]
_avail_input_cls = (LabelTensor, Graph, list, tuple) _avail_input_cls = (LabelTensor, Graph, list, tuple)
_avail_equation_cls = EquationInterface _avail_equation_cls = EquationInterface
@@ -27,31 +48,31 @@ class InputEquationCondition(ConditionInterface):
def __new__(cls, input, equation): def __new__(cls, input, equation):
""" """
Instantiate the appropriate subclass of :class:`InputEquationCondition` Instantiate the appropriate subclass of :class:`InputEquationCondition`
based on the type of ``input``. based on the type of ``input`` data.
:param input: Input data for the condition. :param input: The input data for the condition.
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph] :type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
:param EquationInterface equation: Equation object containing the :param EquationInterface equation: The equation to be satisfied over the
equation function. specified ``input`` data.
:return: Subclass of InputEquationCondition, based on the input type. :return: The subclass of InputEquationCondition.
:rtype: pina.condition.input_equation_condition. :rtype: pina.condition.input_equation_condition.
InputTensorEquationCondition | InputTensorEquationCondition |
pina.condition.input_equation_condition.InputGraphEquationCondition pina.condition.input_equation_condition.InputGraphEquationCondition
:raises ValueError: If input is not of type :raises ValueError: If input is not of type :class:`~pina.graph.Graph`
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`. or :class:`~pina.label_tensor.LabelTensor`.
""" """
# If the class is already a subclass, return the instance
if cls != InputEquationCondition: if cls != InputEquationCondition:
return super().__new__(cls) return super().__new__(cls)
# Instanciate the correct subclass # If the input is a Graph object
if isinstance(input, (Graph, Data, list, tuple)): if isinstance(input, (Graph, list, tuple)):
subclass = InputGraphEquationCondition subclass = InputGraphEquationCondition
cls._check_graph_list_consistency(input) cls._check_graph_list_consistency(input)
subclass._check_label_tensor(input) subclass._check_label_tensor(input)
return subclass.__new__(subclass, input, equation) return subclass.__new__(subclass, input, equation)
# If the input is a LabelTensor
if isinstance(input, LabelTensor): if isinstance(input, LabelTensor):
subclass = InputTensorEquationCondition subclass = InputTensorEquationCondition
return subclass.__new__(subclass, input, equation) return subclass.__new__(subclass, input, equation)
@@ -63,69 +84,74 @@ class InputEquationCondition(ConditionInterface):
def __init__(self, input, equation): def __init__(self, input, equation):
""" """
Initialize the object by storing the input data and equation object. Initialization of the :class:`InputEquationCondition` class.
:param input: Input data for the condition. :param input: The input data for the condition.
:type input: LabelTensor | Graph | :type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
list[Graph] | tuple[Graph] :param EquationInterface equation: The equation to be satisfied over the
:param EquationInterface equation: Equation object containing the specified input points.
equation function.
.. note:: .. note::
If ``input`` consists of a list of :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data`, all elements must have the same
structure (keys and data types)
"""
If ``input`` is a list of :class:`~pina.graph.Graph` all elements in
the list must share the same structure, with matching keys and
consistent data types.
"""
super().__init__() super().__init__()
self.input = input self.input = input
self.equation = equation self.equation = equation
def __setattr__(self, key, value): def __setattr__(self, key, value):
"""
Set the attribute value with type checking.
:param str key: The attribute name.
:param any value: The value to set for the attribute.
"""
if key == "input": if key == "input":
check_consistency(value, self._avail_input_cls) check_consistency(value, self._avail_input_cls)
InputEquationCondition.__dict__[key].__set__(self, value) InputEquationCondition.__dict__[key].__set__(self, value)
elif key == "equation": elif key == "equation":
check_consistency(value, self._avail_equation_cls) check_consistency(value, self._avail_equation_cls)
InputEquationCondition.__dict__[key].__set__(self, value) InputEquationCondition.__dict__[key].__set__(self, value)
elif key in ("_problem"): elif key in ("_problem"):
super().__setattr__(key, value) super().__setattr__(key, value)
class InputTensorEquationCondition(InputEquationCondition): class InputTensorEquationCondition(InputEquationCondition):
""" """
InputEquationCondition subclass for :class:`~pina.label_tensor.LabelTensor` Specialization of the :class:`InputEquationCondition` class for the case
input data. where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object.
""" """
class InputGraphEquationCondition(InputEquationCondition): class InputGraphEquationCondition(InputEquationCondition):
""" """
InputEquationCondition subclass for :class:`~pina.graph.Graph` input data. Specialization of the :class:`InputEquationCondition` class for the case
where ``input`` is a :class:`~pina.graph.Graph` object.
""" """
@staticmethod @staticmethod
def _check_label_tensor(input): def _check_label_tensor(input):
""" """
Check if at least one :class:`~pina.label_tensor.LabelTensor` is present Check if at least one :class:`~pina.label_tensor.LabelTensor` is present
in the :class:`~pina.graph.Graph` object. in the ``input`` object.
:param input: Input data.
:type input: torch.Tensor | Graph | Data
:param input: The input data.
:type input: torch.Tensor | Graph | list[Graph] | tuple[Graph]
:raises ValueError: If the input data object does not contain at least :raises ValueError: If the input data object does not contain at least
one LabelTensor. one LabelTensor.
""" """
# Store the fist element of the list/tuple if input is a list/tuple # Store the first element: it is sufficient to check this since all
# it is anougth to check the first element because all elements must # elements must have the same type and structure (already checked).
# have the same type and structure (already checked)
data = input[0] if isinstance(input, (list, tuple)) else input data = input[0] if isinstance(input, (list, tuple)) else input
# Check if the input data contains at least one LabelTensor # Check if the input data contains at least one LabelTensor
for v in data.values(): for v in data.values():
if isinstance(v, LabelTensor): if isinstance(v, LabelTensor):
return return
raise ValueError(
"The input data object must contain at least one LabelTensor." raise ValueError("The input must contain at least one LabelTensor.")
)

View File

@@ -11,39 +11,66 @@ from .condition_interface import ConditionInterface
class InputTargetCondition(ConditionInterface): class InputTargetCondition(ConditionInterface):
""" """
Condition defined by input and target data. This condition can be used in The :class:`InputTargetCondition` class represents a supervised condition
both supervised learning and Physics-informed problems. Based on the type of defined by both ``input`` and ``target`` data. The model is trained to
the input and target, different condition implementations are available: reproduce the ``target`` values given the ``input``. Supported data types
include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
- :class:`TensorInputTensorTargetCondition`: For :class:`torch.Tensor` or \ The class automatically selects the appropriate implementation based on
:class:`~pina.label_tensor.LabelTensor` input and target data. the types of ``input`` and ``target``. Depending on whether the ``input``
- :class:`TensorInputGraphTargetCondition`: For :class:`torch.Tensor` or \ and ``target`` are tensors or graph-based data, one of the following
:class:`~pina.label_tensor.LabelTensor` input and \ specialized subclasses is instantiated:
:class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data` \
target data. - :class:`TensorInputTensorTargetCondition`: For cases where both ``input``
- :class:`GraphInputTensorTargetCondition`: For :class:`~pina.graph.Graph` \ and ``target`` data are either :class:`torch.Tensor` or
or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` \ :class:`~pina.label_tensor.LabelTensor`.
or :class:`~pina.label_tensor.LabelTensor` target data.
- :class:`GraphInputGraphTargetCondition`: For :class:`~pina.graph.Graph` \ - :class:`TensorInputGraphTargetCondition`: For cases where ``input`` is
or :class:`~torch_geometric.data.Data` input and target data. either a :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`
and ``target`` is either a :class:`~pina.graph.Graph` or a
:class:`torch_geometric.data.Data`.
- :class:`GraphInputTensorTargetCondition`: For cases where ``input`` is
either a :class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data`
and ``target`` is either a :class:`torch.Tensor` or a
:class:`~pina.label_tensor.LabelTensor`.
- :class:`GraphInputGraphTargetCondition`: For cases where both ``input``
and ``target`` are either :class:`~pina.graph.Graph` or
:class:`torch_geometric.data.Data`.
:Example:
>>> from pina import Condition, LabelTensor
>>> from pina.graph import Graph
>>> import torch
>>> pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> edge_index = torch.randint(0, 100, (2, 300))
>>> graph = Graph(pos=pos, edge_index=edge_index)
>>> input = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> condition = Condition(input=input, target=graph)
""" """
# Available input and target data types
__slots__ = ["input", "target"] __slots__ = ["input", "target"]
_avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
_avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
def __new__(cls, input, target): def __new__(cls, input, target):
""" """
Instantiate the appropriate subclass of InputTargetCondition based on Instantiate the appropriate subclass of :class:`InputTargetCondition`
the types of input and target data. based on the types of both ``input`` and ``target`` data.
:param input: Input data for the condition. :param input: The input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data] list[Data] | tuple[Graph] | tuple[Data]
:param target: Target data for the condition. :param target: The target data for the condition.
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data] list[Data] | tuple[Graph] | tuple[Data]
:return: Subclass of InputTargetCondition :return: The subclass of InputTargetCondition.
:rtype: pina.condition.input_target_condition. :rtype: pina.condition.input_target_condition.
TensorInputTensorTargetCondition | TensorInputTensorTargetCondition |
pina.condition.input_target_condition. pina.condition.input_target_condition.
@@ -59,11 +86,14 @@ class InputTargetCondition(ConditionInterface):
if cls != InputTargetCondition: if cls != InputTargetCondition:
return super().__new__(cls) return super().__new__(cls)
# Tensor - Tensor
if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance(
target, (torch.Tensor, LabelTensor) target, (torch.Tensor, LabelTensor)
): ):
subclass = TensorInputTensorTargetCondition subclass = TensorInputTensorTargetCondition
return subclass.__new__(subclass, input, target) return subclass.__new__(subclass, input, target)
# Tensor - Graph
if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance(
target, (Graph, Data, list, tuple) target, (Graph, Data, list, tuple)
): ):
@@ -71,6 +101,7 @@ class InputTargetCondition(ConditionInterface):
subclass = TensorInputGraphTargetCondition subclass = TensorInputGraphTargetCondition
return subclass.__new__(subclass, input, target) return subclass.__new__(subclass, input, target)
# Graph - Tensor
if isinstance(input, (Graph, Data, list, tuple)) and isinstance( if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
target, (torch.Tensor, LabelTensor) target, (torch.Tensor, LabelTensor)
): ):
@@ -78,6 +109,7 @@ class InputTargetCondition(ConditionInterface):
subclass = GraphInputTensorTargetCondition subclass = GraphInputTensorTargetCondition
return subclass.__new__(subclass, input, target) return subclass.__new__(subclass, input, target)
# Graph - Graph
if isinstance(input, (Graph, Data, list, tuple)) and isinstance( if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
target, (Graph, Data, list, tuple) target, (Graph, Data, list, tuple)
): ):
@@ -86,30 +118,31 @@ class InputTargetCondition(ConditionInterface):
subclass = GraphInputGraphTargetCondition subclass = GraphInputGraphTargetCondition
return subclass.__new__(subclass, input, target) return subclass.__new__(subclass, input, target)
# If the input and/or target are not of the correct type raise an error
raise ValueError( raise ValueError(
"Invalid input/target types. " "Invalid input | target types."
"Please provide either torch_geometric.data.Data, Graph, " "Please provide either torch_geometric.data.Data, Graph, "
"LabelTensor or torch.Tensor objects." "LabelTensor or torch.Tensor objects."
) )
def __init__(self, input, target): def __init__(self, input, target):
""" """
Initialize the object by storing the ``input`` and ``target`` data. Initialization of the :class:`InputTargetCondition` class.
:param input: Input data for the condition. :param input: The input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data] list[Data] | tuple[Graph] | tuple[Data]
:param target: Target data for the condition. :param target: The target data for the condition.
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data] list[Data] | tuple[Graph] | tuple[Data]
.. note:: .. note::
If either input or target consists of a list of
:class:~pina.graph.Graph or :class:~torch_geometric.data.Data
objects, all elements must have the same structure (matching
keys and data types).
"""
If either ``input`` or ``target`` is a list of
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
objects, all elements in the list must share the same structure,
with matching keys and consistent data types.
"""
super().__init__() super().__init__()
self._check_input_target_len(input, target) self._check_input_target_len(input, target)
self.input = input self.input = input
@@ -117,10 +150,24 @@ class InputTargetCondition(ConditionInterface):
@staticmethod @staticmethod
def _check_input_target_len(input, target): def _check_input_target_len(input, target):
"""
Check that the length of the input and target lists are the same.
:param input: The input data.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:param target: The target data.
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:raises ValueError: If the lengths of the input and target lists do not
match.
"""
if isinstance(input, (Graph, Data)) or isinstance( if isinstance(input, (Graph, Data)) or isinstance(
target, (Graph, Data) target, (Graph, Data)
): ):
return return
# Raise an error if the lengths of the input and target do not match
if len(input) != len(target): if len(input) != len(target):
raise ValueError( raise ValueError(
"The input and target lists must have the same length." "The input and target lists must have the same length."
@@ -129,30 +176,33 @@ class InputTargetCondition(ConditionInterface):
class TensorInputTensorTargetCondition(InputTargetCondition): class TensorInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`torch.Tensor` or Specialization of the :class:`InputTargetCondition` class for the case where
:class:`~pina.label_tensor.LabelTensor` ``input`` and ``target`` data. both ``input`` and ``target`` are :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` objects.
""" """
class TensorInputGraphTargetCondition(InputTargetCondition): class TensorInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`torch.Tensor` or Specialization of the :class:`InputTargetCondition` class for the case where
:class:`~pina.label_tensor.LabelTensor` ``input`` and ``input`` is either a :class:`torch.Tensor` or a
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target` :class:`~pina.label_tensor.LabelTensor` object and ``target`` is either a
data. :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object.
""" """
class GraphInputTensorTargetCondition(InputTargetCondition): class GraphInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`~pina.graph.Graph` o Specialization of the :class:`InputTargetCondition` class for the case where
:class:`~torch_geometric.data.Data` ``input`` and :class:`torch.Tensor` or ``input`` is either a :class:`~pina.graph.Graph` or
:class:`~pina.label_tensor.LabelTensor` ``target`` data. :class:`torch_geometric.data.Data` object and ``target`` is either a
:class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object.
""" """
class GraphInputGraphTargetCondition(InputTargetCondition): class GraphInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`~pina.graph.Graph`/ Specialization of the :class:`InputTargetCondition` class for the case where
:class:`~torch_geometric.data.Data` ``input`` and ``target`` data. both ``input`` and ``target`` are either :class:`~pina.graph.Graph` or
:class:`torch_geometric.data.Data` objects.
""" """