new conditions

This commit is contained in:
Dario Coscia
2024-10-03 21:33:37 +02:00
committed by Nicola Demo
parent a888141707
commit fd16fcf9b4
8 changed files with 210 additions and 171 deletions

View File

@@ -1,10 +1,12 @@
__all__ = [ __all__ = [
'Condition', 'Condition',
'ConditionInterface', 'ConditionInterface',
'DomainOutputCondition', 'DomainEquationCondition',
'DomainEquationCondition' 'InputPointsEquationCondition',
'InputOutputPointsCondition',
] ]
from .condition_interface import ConditionInterface from .condition_interface import ConditionInterface
from .domain_output_condition import DomainOutputCondition from .domain_equation_condition import DomainEquationCondition
from .domain_equation_condition import DomainEquationCondition from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition

View File

@@ -1,27 +1,21 @@
""" Condition module. """ """ Condition module. """
from ..label_tensor import LabelTensor from .domain_equation_condition import DomainEquationCondition
from ..domain import DomainInterface from .input_equation_condition import InputPointsEquationCondition
from ..equation.equation import Equation from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface
from . import DomainOutputCondition, DomainEquationCondition
def dummy(a):
"""Dummy function for testing purposes."""
return None
class Condition: class Condition:
""" """
The class ``Condition`` is used to represent the constraints (physical The class ``Condition`` is used to represent the constraints (physical
equations, boundary conditions, etc.) that should be satisfied in the equations, boundary conditions, etc.) that should be satisfied in the
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object. problem at hand. Condition objects are used to formulate the
Conditions can be specified in three ways: PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
Conditions can be specified in four ways:
1. By specifying the input and output points of the condition; in such a 1. By specifying the input and output points of the condition; in such a
case, the model is trained to produce the output points given the input case, the model is trained to produce the output points given the input
points. points. Those points can either be torch.Tensor, LabelTensors, Graph
2. By specifying the location and the equation of the condition; in such 2. By specifying the location and the equation of the condition; in such
a case, the model is trained to minimize the equation residual by a case, the model is trained to minimize the equation residual by
@@ -29,79 +23,48 @@ class Condition:
3. By specifying the input points and the equation of the condition; in 3. By specifying the input points and the equation of the condition; in
such a case, the model is trained to minimize the equation residual by such a case, the model is trained to minimize the equation residual by
evaluating it at the passed input points. evaluating it at the passed input points. The input points must be
a LabelTensor.
4. By specifying only the data matrix; in such a case the model is
trained with an unsupervised costum loss and uses the data in training.
Additionaly conditioning variables can be passed, whenever the model
has extra conditioning variable it depends on.
Example:: Example::
>>> example_domain = Span({'x': [0, 1], 'y': [0, 1]}) >>> TODO
>>> def example_dirichlet(input_, output_):
>>> value = 0.0
>>> return output_.extract(['u']) - value
>>> example_input_pts = LabelTensor(
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
>>>
>>> Condition(
>>> input_points=example_input_pts,
>>> output_points=example_output_pts)
>>> Condition(
>>> location=example_domain,
>>> equation=example_dirichlet)
>>> Condition(
>>> input_points=example_input_pts,
>>> equation=example_dirichlet)
""" """
# def _dictvalue_isinstance(self, dict_, key_, class_): __slots__ = list(
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`.""" set(
# if key_ not in dict_.keys(): InputOutputPointsCondition.__slots__,
# return True InputPointsEquationCondition.__slots__,
DomainEquationCondition.__slots__,
DataConditionInterface.__slots__
# return isinstance(dict_[key_], class_) )
)
# def __init__(self, *args, **kwargs):
# """
# Constructor for the `Condition` class.
# """
# self.data_weight = kwargs.pop("data_weight", 1.0)
# if len(args) != 0:
# raise ValueError(
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
# )
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]): if len(args) != 0:
return DomainOutputCondition( raise ValueError(
domain=kwargs["input_points"], f"Condition takes only the following keyword '
output_points=kwargs["output_points"] 'arguments: {Condition.__slots__}."
) )
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
return DomainOutputCondition(**kwargs) sorted_keys = sorted(kwargs.keys())
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]): if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
return InputOutputPointsCondition(**kwargs)
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
return InputPointsEquationCondition(**kwargs)
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
return DomainEquationCondition(**kwargs) return DomainEquationCondition(**kwargs)
elif sorted_keys == sorted(DataConditionInterface.__slots__):
return DataConditionInterface(**kwargs)
elif sorted_keys == DataConditionInterface.__slots__[0]:
return DataConditionInterface(**kwargs)
else: else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
'''
if (
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
and sorted(kwargs.keys()) != sorted(["location", "equation"])
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
):
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
raise TypeError("`input_points` must be a torch.Tensor.")
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
raise TypeError("`output_points` must be a torch.Tensor.")
if not self._dictvalue_isinstance(kwargs, "location", Location):
raise TypeError("`location` must be a Location.")
if not self._dictvalue_isinstance(kwargs, "equation", Equation):
raise TypeError("`equation` must be a Equation.")
for key, value in kwargs.items():
setattr(self, key, value)
'''

View File

@@ -1,21 +1,25 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta
class ConditionInterface(metaclass=ABCMeta): class ConditionInterface(metaclass=ABCMeta):
def __init__(self) -> None: condition_types = ['physical', 'supervised', 'unsupervised']
self._problem = None def __init__(self):
self._condition_type = None
@abstractmethod @property
def residual(self, model): def condition_type(self):
""" return self._condition_type
Compute the residual of the condition.
@condition_type.setattr
:param model: The model to evaluate the condition. def condition_type(self, values):
:return: The residual of the condition. if not isinstance(values, (list, tuple)):
""" values = [values]
pass for value in values:
if value not in ConditionInterface.condition_types:
def set_problem(self, problem): raise ValueError(
self._problem = problem 'Unavailable type of condition, expected one of'
f' {ConditionInterface.condition_types}.'
)
self._condition_type = values

View File

@@ -0,0 +1,44 @@
import torch
from . import ConditionInterface
from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency
class DataConditionInterface(ConditionInterface):
"""
Condition for data. This condition must be used every
time a Unsupervised Loss is needed in the Solver. The conditionalvariable
can be passed as extra-input when the model learns a conditional
distribution
"""
__slots__ = ["data", "conditionalvariable"]
def __init__(self, data, conditionalvariable=None):
"""
TODO
"""
super().__init__()
self.data = data
self.conditionalvariable = conditionalvariable
self.condition_type = 'unsupervised'
@property
def data(self):
return self._data
@data.setter
def data(self, value):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
self._data = value
@property
def conditionalvariable(self):
return self._conditionalvariable
@data.setter
def conditionalvariable(self, value):
if value is not None:
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
self._data = value

View File

@@ -1,34 +1,43 @@
import torch
from .condition_interface import ConditionInterface from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency
from ..domain import DomainInterface
from ..equation.equation_interface import EquationInterface
class DomainEquationCondition(ConditionInterface): class DomainEquationCondition(ConditionInterface):
""" """
Condition for input/output data. Condition for domain/equation data. This condition must be used every
time a Physics Informed Loss is needed in the Solver.
""" """
__slots__ = ["domain", "equation"] __slots__ = ["domain", "equation"]
def __init__(self, domain, equation): def __init__(self, domain, equation):
""" """
Constructor for the `InputOutputCondition` class. TODO
""" """
super().__init__() super().__init__()
self.domain = domain self.domain = domain
self.equation = equation self.equation = equation
self.condition_type = 'physics'
def residual(self, model): @property
""" def domain(self):
Compute the residual of the condition. return self._domain
"""
self.batch_residual(model, self.domain, self.equation) @domain.setter
def domain(self, value):
check_consistency(value, (DomainInterface))
self._domain = value
@staticmethod @property
def batch_residual(model, input_pts, equation): def equation(self):
""" return self._equation
Compute the residual of the condition for a single batch. Input and
output points are provided as arguments. @equation.setter
def equation(self, value):
:param torch.nn.Module model: The model to evaluate the condition. check_consistency(value, (EquationInterface))
:param torch.Tensor input_pts: The input points. self._equation = value
:param torch.Tensor equation: The output points.
"""
return equation.residual(input_pts, model(input_pts))

View File

@@ -1,44 +0,0 @@
from . import ConditionInterface
class DomainOutputCondition(ConditionInterface):
"""
Condition for input/output data.
"""
__slots__ = ["domain", "output_points"]
def __init__(self, domain, output_points):
"""
Constructor for the `InputOutputCondition` class.
"""
super().__init__()
print(self)
self.domain = domain
self.output_points = output_points
@property
def input_points(self):
"""
Get the input points of the condition.
"""
return self._problem.domains[self.domain]
def residual(self, model):
"""
Compute the residual of the condition.
"""
return self.batch_residual(model, self.domain, self.output_points)
@staticmethod
def batch_residual(model, input_points, output_points):
"""
Compute the residual of the condition for a single batch. Input and
output points are provided as arguments.
:param torch.nn.Module model: The model to evaluate the condition.
:param torch.Tensor input_points: The input points.
:param torch.Tensor output_points: The output points.
"""
return output_points - model(input_points)

View File

@@ -1,23 +1,42 @@
import torch
from . import ConditionInterface from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency
from ..equation.equation_interface import EquationInterface
class InputEquationCondition(ConditionInterface): class InputPointsEquationCondition(ConditionInterface):
""" """
Condition for input/output data. Condition for input_points/equation data. This condition must be used every
time a Physics Informed Loss is needed in the Solver.
""" """
__slots__ = ["input_points", "output_points"] __slots__ = ["input_points", "equation"]
def __init__(self, input_points, output_points): def __init__(self, input_points, equation):
""" """
Constructor for the `InputOutputCondition` class. TODO
""" """
super().__init__() super().__init__()
self.input_points = input_points self.input_points = input_points
self.output_points = output_points self.equation = equation
self.condition_type = 'physics'
def residual(self, model): @property
""" def input_points(self):
Compute the residual of the condition. return self._input_points
"""
return self.output_points - model(self.input_points) @input_points.setter
def input_points(self, value):
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
self._input_points = value
@property
def equation(self):
return self._equation
@equation.setter
def equation(self, value):
check_consistency(value, (EquationInterface))
self._equation = value

View File

@@ -0,0 +1,42 @@
import torch
from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency
class InputOutputPointsCondition(ConditionInterface):
"""
Condition for domain/equation data. This condition must be used every
time a Physics Informed or a Supervised Loss is needed in the Solver.
"""
__slots__ = ["input_points", "output_points"]
def __init__(self, input_points, output_points):
"""
TODO
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points
self.condition_type = ['supervised', 'physics']
@property
def input_points(self):
return self._input_points
@input_points.setter
def input_points(self, value):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
self._input_points = value
@property
def output_points(self):
return self._output_points
@output_points.setter
def output_points(self, value):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
self._output_points = value