new conditions
This commit is contained in:
committed by
Nicola Demo
parent
a888141707
commit
fd16fcf9b4
@@ -1,10 +1,12 @@
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
'Condition',
|
'Condition',
|
||||||
'ConditionInterface',
|
'ConditionInterface',
|
||||||
'DomainOutputCondition',
|
'DomainEquationCondition',
|
||||||
'DomainEquationCondition'
|
'InputPointsEquationCondition',
|
||||||
|
'InputOutputPointsCondition',
|
||||||
]
|
]
|
||||||
|
|
||||||
from .condition_interface import ConditionInterface
|
from .condition_interface import ConditionInterface
|
||||||
from .domain_output_condition import DomainOutputCondition
|
from .domain_equation_condition import DomainEquationCondition
|
||||||
from .domain_equation_condition import DomainEquationCondition
|
from .input_equation_condition import InputPointsEquationCondition
|
||||||
|
from .input_output_condition import InputOutputPointsCondition
|
||||||
@@ -1,27 +1,21 @@
|
|||||||
""" Condition module. """
|
""" Condition module. """
|
||||||
|
|
||||||
from ..label_tensor import LabelTensor
|
from .domain_equation_condition import DomainEquationCondition
|
||||||
from ..domain import DomainInterface
|
from .input_equation_condition import InputPointsEquationCondition
|
||||||
from ..equation.equation import Equation
|
from .input_output_condition import InputOutputPointsCondition
|
||||||
|
from .data_condition import DataConditionInterface
|
||||||
from . import DomainOutputCondition, DomainEquationCondition
|
|
||||||
|
|
||||||
|
|
||||||
def dummy(a):
|
|
||||||
"""Dummy function for testing purposes."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class Condition:
|
class Condition:
|
||||||
"""
|
"""
|
||||||
The class ``Condition`` is used to represent the constraints (physical
|
The class ``Condition`` is used to represent the constraints (physical
|
||||||
equations, boundary conditions, etc.) that should be satisfied in the
|
equations, boundary conditions, etc.) that should be satisfied in the
|
||||||
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
|
problem at hand. Condition objects are used to formulate the
|
||||||
Conditions can be specified in three ways:
|
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
|
||||||
|
Conditions can be specified in four ways:
|
||||||
|
|
||||||
1. By specifying the input and output points of the condition; in such a
|
1. By specifying the input and output points of the condition; in such a
|
||||||
case, the model is trained to produce the output points given the input
|
case, the model is trained to produce the output points given the input
|
||||||
points.
|
points. Those points can either be torch.Tensor, LabelTensors, Graph
|
||||||
|
|
||||||
2. By specifying the location and the equation of the condition; in such
|
2. By specifying the location and the equation of the condition; in such
|
||||||
a case, the model is trained to minimize the equation residual by
|
a case, the model is trained to minimize the equation residual by
|
||||||
@@ -29,79 +23,48 @@ class Condition:
|
|||||||
|
|
||||||
3. By specifying the input points and the equation of the condition; in
|
3. By specifying the input points and the equation of the condition; in
|
||||||
such a case, the model is trained to minimize the equation residual by
|
such a case, the model is trained to minimize the equation residual by
|
||||||
evaluating it at the passed input points.
|
evaluating it at the passed input points. The input points must be
|
||||||
|
a LabelTensor.
|
||||||
|
|
||||||
|
4. By specifying only the data matrix; in such a case the model is
|
||||||
|
trained with an unsupervised costum loss and uses the data in training.
|
||||||
|
Additionaly conditioning variables can be passed, whenever the model
|
||||||
|
has extra conditioning variable it depends on.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
|
>>> TODO
|
||||||
>>> def example_dirichlet(input_, output_):
|
|
||||||
>>> value = 0.0
|
|
||||||
>>> return output_.extract(['u']) - value
|
|
||||||
>>> example_input_pts = LabelTensor(
|
|
||||||
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
|
|
||||||
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
|
|
||||||
>>>
|
|
||||||
>>> Condition(
|
|
||||||
>>> input_points=example_input_pts,
|
|
||||||
>>> output_points=example_output_pts)
|
|
||||||
>>> Condition(
|
|
||||||
>>> location=example_domain,
|
|
||||||
>>> equation=example_dirichlet)
|
|
||||||
>>> Condition(
|
|
||||||
>>> input_points=example_input_pts,
|
|
||||||
>>> equation=example_dirichlet)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# def _dictvalue_isinstance(self, dict_, key_, class_):
|
__slots__ = list(
|
||||||
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
|
set(
|
||||||
# if key_ not in dict_.keys():
|
InputOutputPointsCondition.__slots__,
|
||||||
# return True
|
InputPointsEquationCondition.__slots__,
|
||||||
|
DomainEquationCondition.__slots__,
|
||||||
|
DataConditionInterface.__slots__
|
||||||
|
|
||||||
# return isinstance(dict_[key_], class_)
|
)
|
||||||
|
)
|
||||||
# def __init__(self, *args, **kwargs):
|
|
||||||
# """
|
|
||||||
# Constructor for the `Condition` class.
|
|
||||||
# """
|
|
||||||
# self.data_weight = kwargs.pop("data_weight", 1.0)
|
|
||||||
|
|
||||||
# if len(args) != 0:
|
|
||||||
# raise ValueError(
|
|
||||||
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
|
|
||||||
# )
|
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
|
|
||||||
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
|
if len(args) != 0:
|
||||||
return DomainOutputCondition(
|
raise ValueError(
|
||||||
domain=kwargs["input_points"],
|
f"Condition takes only the following keyword '
|
||||||
output_points=kwargs["output_points"]
|
'arguments: {Condition.__slots__}."
|
||||||
)
|
)
|
||||||
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
|
|
||||||
return DomainOutputCondition(**kwargs)
|
sorted_keys = sorted(kwargs.keys())
|
||||||
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]):
|
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
|
||||||
|
return InputOutputPointsCondition(**kwargs)
|
||||||
|
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
|
||||||
|
return InputPointsEquationCondition(**kwargs)
|
||||||
|
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
|
||||||
return DomainEquationCondition(**kwargs)
|
return DomainEquationCondition(**kwargs)
|
||||||
|
elif sorted_keys == sorted(DataConditionInterface.__slots__):
|
||||||
|
return DataConditionInterface(**kwargs)
|
||||||
|
elif sorted_keys == DataConditionInterface.__slots__[0]:
|
||||||
|
return DataConditionInterface(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||||
# TODO: remove, not used anymore
|
|
||||||
'''
|
|
||||||
if (
|
|
||||||
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
|
|
||||||
and sorted(kwargs.keys()) != sorted(["location", "equation"])
|
|
||||||
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
|
|
||||||
):
|
|
||||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
|
||||||
# TODO: remove, not used anymore
|
|
||||||
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
|
|
||||||
raise TypeError("`input_points` must be a torch.Tensor.")
|
|
||||||
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
|
|
||||||
raise TypeError("`output_points` must be a torch.Tensor.")
|
|
||||||
if not self._dictvalue_isinstance(kwargs, "location", Location):
|
|
||||||
raise TypeError("`location` must be a Location.")
|
|
||||||
if not self._dictvalue_isinstance(kwargs, "equation", Equation):
|
|
||||||
raise TypeError("`equation` must be a Equation.")
|
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
setattr(self, key, value)
|
|
||||||
'''
|
|
||||||
@@ -1,21 +1,25 @@
|
|||||||
|
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta
|
||||||
|
|
||||||
|
|
||||||
class ConditionInterface(metaclass=ABCMeta):
|
class ConditionInterface(metaclass=ABCMeta):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
condition_types = ['physical', 'supervised', 'unsupervised']
|
||||||
self._problem = None
|
def __init__(self):
|
||||||
|
self._condition_type = None
|
||||||
|
|
||||||
@abstractmethod
|
@property
|
||||||
def residual(self, model):
|
def condition_type(self):
|
||||||
"""
|
return self._condition_type
|
||||||
Compute the residual of the condition.
|
|
||||||
|
@condition_type.setattr
|
||||||
:param model: The model to evaluate the condition.
|
def condition_type(self, values):
|
||||||
:return: The residual of the condition.
|
if not isinstance(values, (list, tuple)):
|
||||||
"""
|
values = [values]
|
||||||
pass
|
for value in values:
|
||||||
|
if value not in ConditionInterface.condition_types:
|
||||||
def set_problem(self, problem):
|
raise ValueError(
|
||||||
self._problem = problem
|
'Unavailable type of condition, expected one of'
|
||||||
|
f' {ConditionInterface.condition_types}.'
|
||||||
|
)
|
||||||
|
self._condition_type = values
|
||||||
44
pina/condition/data_condition.py
Normal file
44
pina/condition/data_condition.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from . import ConditionInterface
|
||||||
|
from ..label_tensor import LabelTensor
|
||||||
|
from ..graph import Graph
|
||||||
|
from ..utils import check_consistency
|
||||||
|
|
||||||
|
class DataConditionInterface(ConditionInterface):
|
||||||
|
"""
|
||||||
|
Condition for data. This condition must be used every
|
||||||
|
time a Unsupervised Loss is needed in the Solver. The conditionalvariable
|
||||||
|
can be passed as extra-input when the model learns a conditional
|
||||||
|
distribution
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ["data", "conditionalvariable"]
|
||||||
|
|
||||||
|
def __init__(self, data, conditionalvariable=None):
|
||||||
|
"""
|
||||||
|
TODO
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.data = data
|
||||||
|
self.conditionalvariable = conditionalvariable
|
||||||
|
self.condition_type = 'unsupervised'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
@data.setter
|
||||||
|
def data(self, value):
|
||||||
|
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||||
|
self._data = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def conditionalvariable(self):
|
||||||
|
return self._conditionalvariable
|
||||||
|
|
||||||
|
@data.setter
|
||||||
|
def conditionalvariable(self, value):
|
||||||
|
if value is not None:
|
||||||
|
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||||
|
self._data = value
|
||||||
@@ -1,34 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
from .condition_interface import ConditionInterface
|
from .condition_interface import ConditionInterface
|
||||||
|
from ..label_tensor import LabelTensor
|
||||||
|
from ..graph import Graph
|
||||||
|
from ..utils import check_consistency
|
||||||
|
from ..domain import DomainInterface
|
||||||
|
from ..equation.equation_interface import EquationInterface
|
||||||
|
|
||||||
class DomainEquationCondition(ConditionInterface):
|
class DomainEquationCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition for input/output data.
|
Condition for domain/equation data. This condition must be used every
|
||||||
|
time a Physics Informed Loss is needed in the Solver.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["domain", "equation"]
|
__slots__ = ["domain", "equation"]
|
||||||
|
|
||||||
def __init__(self, domain, equation):
|
def __init__(self, domain, equation):
|
||||||
"""
|
"""
|
||||||
Constructor for the `InputOutputCondition` class.
|
TODO
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.domain = domain
|
self.domain = domain
|
||||||
self.equation = equation
|
self.equation = equation
|
||||||
|
self.condition_type = 'physics'
|
||||||
|
|
||||||
def residual(self, model):
|
@property
|
||||||
"""
|
def domain(self):
|
||||||
Compute the residual of the condition.
|
return self._domain
|
||||||
"""
|
|
||||||
self.batch_residual(model, self.domain, self.equation)
|
@domain.setter
|
||||||
|
def domain(self, value):
|
||||||
|
check_consistency(value, (DomainInterface))
|
||||||
|
self._domain = value
|
||||||
|
|
||||||
@staticmethod
|
@property
|
||||||
def batch_residual(model, input_pts, equation):
|
def equation(self):
|
||||||
"""
|
return self._equation
|
||||||
Compute the residual of the condition for a single batch. Input and
|
|
||||||
output points are provided as arguments.
|
@equation.setter
|
||||||
|
def equation(self, value):
|
||||||
:param torch.nn.Module model: The model to evaluate the condition.
|
check_consistency(value, (EquationInterface))
|
||||||
:param torch.Tensor input_pts: The input points.
|
self._equation = value
|
||||||
:param torch.Tensor equation: The output points.
|
|
||||||
"""
|
|
||||||
return equation.residual(input_pts, model(input_pts))
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
|
|
||||||
from . import ConditionInterface
|
|
||||||
|
|
||||||
class DomainOutputCondition(ConditionInterface):
|
|
||||||
"""
|
|
||||||
Condition for input/output data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ["domain", "output_points"]
|
|
||||||
|
|
||||||
def __init__(self, domain, output_points):
|
|
||||||
"""
|
|
||||||
Constructor for the `InputOutputCondition` class.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
print(self)
|
|
||||||
self.domain = domain
|
|
||||||
self.output_points = output_points
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_points(self):
|
|
||||||
"""
|
|
||||||
Get the input points of the condition.
|
|
||||||
"""
|
|
||||||
return self._problem.domains[self.domain]
|
|
||||||
|
|
||||||
def residual(self, model):
|
|
||||||
"""
|
|
||||||
Compute the residual of the condition.
|
|
||||||
"""
|
|
||||||
return self.batch_residual(model, self.domain, self.output_points)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def batch_residual(model, input_points, output_points):
|
|
||||||
"""
|
|
||||||
Compute the residual of the condition for a single batch. Input and
|
|
||||||
output points are provided as arguments.
|
|
||||||
|
|
||||||
:param torch.nn.Module model: The model to evaluate the condition.
|
|
||||||
:param torch.Tensor input_points: The input points.
|
|
||||||
:param torch.Tensor output_points: The output points.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return output_points - model(input_points)
|
|
||||||
@@ -1,23 +1,42 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
from . import ConditionInterface
|
from .condition_interface import ConditionInterface
|
||||||
|
from ..label_tensor import LabelTensor
|
||||||
|
from ..graph import Graph
|
||||||
|
from ..utils import check_consistency
|
||||||
|
from ..equation.equation_interface import EquationInterface
|
||||||
|
|
||||||
class InputEquationCondition(ConditionInterface):
|
class InputPointsEquationCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition for input/output data.
|
Condition for input_points/equation data. This condition must be used every
|
||||||
|
time a Physics Informed Loss is needed in the Solver.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["input_points", "output_points"]
|
__slots__ = ["input_points", "equation"]
|
||||||
|
|
||||||
def __init__(self, input_points, output_points):
|
def __init__(self, input_points, equation):
|
||||||
"""
|
"""
|
||||||
Constructor for the `InputOutputCondition` class.
|
TODO
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_points = input_points
|
self.input_points = input_points
|
||||||
self.output_points = output_points
|
self.equation = equation
|
||||||
|
self.condition_type = 'physics'
|
||||||
|
|
||||||
def residual(self, model):
|
@property
|
||||||
"""
|
def input_points(self):
|
||||||
Compute the residual of the condition.
|
return self._input_points
|
||||||
"""
|
|
||||||
return self.output_points - model(self.input_points)
|
@input_points.setter
|
||||||
|
def input_points(self, value):
|
||||||
|
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
|
||||||
|
self._input_points = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def equation(self):
|
||||||
|
return self._equation
|
||||||
|
|
||||||
|
@equation.setter
|
||||||
|
def equation(self, value):
|
||||||
|
check_consistency(value, (EquationInterface))
|
||||||
|
self._equation = value
|
||||||
42
pina/condition/input_output_condition.py
Normal file
42
pina/condition/input_output_condition.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .condition_interface import ConditionInterface
|
||||||
|
from ..label_tensor import LabelTensor
|
||||||
|
from ..graph import Graph
|
||||||
|
from ..utils import check_consistency
|
||||||
|
|
||||||
|
class InputOutputPointsCondition(ConditionInterface):
|
||||||
|
"""
|
||||||
|
Condition for domain/equation data. This condition must be used every
|
||||||
|
time a Physics Informed or a Supervised Loss is needed in the Solver.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ["input_points", "output_points"]
|
||||||
|
|
||||||
|
def __init__(self, input_points, output_points):
|
||||||
|
"""
|
||||||
|
TODO
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.input_points = input_points
|
||||||
|
self.output_points = output_points
|
||||||
|
self.condition_type = ['supervised', 'physics']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_points(self):
|
||||||
|
return self._input_points
|
||||||
|
|
||||||
|
@input_points.setter
|
||||||
|
def input_points(self, value):
|
||||||
|
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||||
|
self._input_points = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_points(self):
|
||||||
|
return self._output_points
|
||||||
|
|
||||||
|
@output_points.setter
|
||||||
|
def output_points(self, value):
|
||||||
|
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||||
|
self._output_points = value
|
||||||
Reference in New Issue
Block a user