From fd16fcf9b408e335a1bcdb5254fe5d75c8bedda1 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Thu, 3 Oct 2024 21:33:37 +0200 Subject: [PATCH] new conditions --- pina/condition/__init__.py | 10 +- pina/condition/condition.py | 119 +++++++------------- pina/condition/condition_interface.py | 34 +++--- pina/condition/data_condition.py | 44 ++++++++ pina/condition/domain_equation_condition.py | 45 +++++--- pina/condition/domain_output_condition.py | 44 -------- pina/condition/input_equation_condition.py | 43 +++++-- pina/condition/input_output_condition.py | 42 +++++++ 8 files changed, 210 insertions(+), 171 deletions(-) create mode 100644 pina/condition/data_condition.py delete mode 100644 pina/condition/domain_output_condition.py create mode 100644 pina/condition/input_output_condition.py diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index ff329c1..4c89b75 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -1,10 +1,12 @@ __all__ = [ 'Condition', 'ConditionInterface', - 'DomainOutputCondition', - 'DomainEquationCondition' + 'DomainEquationCondition', + 'InputPointsEquationCondition', + 'InputOutputPointsCondition', ] from .condition_interface import ConditionInterface -from .domain_output_condition import DomainOutputCondition -from .domain_equation_condition import DomainEquationCondition \ No newline at end of file +from .domain_equation_condition import DomainEquationCondition +from .input_equation_condition import InputPointsEquationCondition +from .input_output_condition import InputOutputPointsCondition \ No newline at end of file diff --git a/pina/condition/condition.py b/pina/condition/condition.py index d815838..ddc722f 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -1,27 +1,21 @@ """ Condition module. """ -from ..label_tensor import LabelTensor -from ..domain import DomainInterface -from ..equation.equation import Equation - -from . import DomainOutputCondition, DomainEquationCondition - - -def dummy(a): - """Dummy function for testing purposes.""" - return None - +from .domain_equation_condition import DomainEquationCondition +from .input_equation_condition import InputPointsEquationCondition +from .input_output_condition import InputOutputPointsCondition +from .data_condition import DataConditionInterface class Condition: """ The class ``Condition`` is used to represent the constraints (physical equations, boundary conditions, etc.) that should be satisfied in the - problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object. - Conditions can be specified in three ways: + problem at hand. Condition objects are used to formulate the + PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object. + Conditions can be specified in four ways: 1. By specifying the input and output points of the condition; in such a case, the model is trained to produce the output points given the input - points. + points. Those points can either be torch.Tensor, LabelTensors, Graph 2. By specifying the location and the equation of the condition; in such a case, the model is trained to minimize the equation residual by @@ -29,79 +23,48 @@ class Condition: 3. By specifying the input points and the equation of the condition; in such a case, the model is trained to minimize the equation residual by - evaluating it at the passed input points. + evaluating it at the passed input points. The input points must be + a LabelTensor. + + 4. By specifying only the data matrix; in such a case the model is + trained with an unsupervised costum loss and uses the data in training. + Additionaly conditioning variables can be passed, whenever the model + has extra conditioning variable it depends on. Example:: - >>> example_domain = Span({'x': [0, 1], 'y': [0, 1]}) - >>> def example_dirichlet(input_, output_): - >>> value = 0.0 - >>> return output_.extract(['u']) - value - >>> example_input_pts = LabelTensor( - >>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z']) - >>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b']) - >>> - >>> Condition( - >>> input_points=example_input_pts, - >>> output_points=example_output_pts) - >>> Condition( - >>> location=example_domain, - >>> equation=example_dirichlet) - >>> Condition( - >>> input_points=example_input_pts, - >>> equation=example_dirichlet) + >>> TODO """ - # def _dictvalue_isinstance(self, dict_, key_, class_): - # """Check if the value of a dictionary corresponding to `key` is an instance of `class_`.""" - # if key_ not in dict_.keys(): - # return True + __slots__ = list( + set( + InputOutputPointsCondition.__slots__, + InputPointsEquationCondition.__slots__, + DomainEquationCondition.__slots__, + DataConditionInterface.__slots__ - # return isinstance(dict_[key_], class_) - - # def __init__(self, *args, **kwargs): - # """ - # Constructor for the `Condition` class. - # """ - # self.data_weight = kwargs.pop("data_weight", 1.0) - - # if len(args) != 0: - # raise ValueError( - # f"Condition takes only the following keyword arguments: {Condition.__slots__}." - # ) + ) + ) def __new__(cls, *args, **kwargs): - - if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]): - return DomainOutputCondition( - domain=kwargs["input_points"], - output_points=kwargs["output_points"] + + if len(args) != 0: + raise ValueError( + f"Condition takes only the following keyword ' + 'arguments: {Condition.__slots__}." ) - elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]): - return DomainOutputCondition(**kwargs) - elif sorted(kwargs.keys()) == sorted(["domain", "equation"]): + + sorted_keys = sorted(kwargs.keys()) + if sorted_keys == sorted(InputOutputPointsCondition.__slots__): + return InputOutputPointsCondition(**kwargs) + elif sorted_keys == sorted(InputPointsEquationCondition.__slots__): + return InputPointsEquationCondition(**kwargs) + elif sorted_keys == sorted(DomainEquationCondition.__slots__): return DomainEquationCondition(**kwargs) + elif sorted_keys == sorted(DataConditionInterface.__slots__): + return DataConditionInterface(**kwargs) + elif sorted_keys == DataConditionInterface.__slots__[0]: + return DataConditionInterface(**kwargs) else: - raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") - # TODO: remove, not used anymore - ''' - if ( - sorted(kwargs.keys()) != sorted(["input_points", "output_points"]) - and sorted(kwargs.keys()) != sorted(["location", "equation"]) - and sorted(kwargs.keys()) != sorted(["input_points", "equation"]) - ): - raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") - # TODO: remove, not used anymore - if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor): - raise TypeError("`input_points` must be a torch.Tensor.") - if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor): - raise TypeError("`output_points` must be a torch.Tensor.") - if not self._dictvalue_isinstance(kwargs, "location", Location): - raise TypeError("`location` must be a Location.") - if not self._dictvalue_isinstance(kwargs, "equation", Equation): - raise TypeError("`equation` must be a Equation.") - - for key, value in kwargs.items(): - setattr(self, key, value) - ''' \ No newline at end of file + raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") \ No newline at end of file diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index 0626a6d..f380dcf 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -1,21 +1,25 @@ -from abc import ABCMeta, abstractmethod +from abc import ABCMeta class ConditionInterface(metaclass=ABCMeta): - def __init__(self) -> None: - self._problem = None + condition_types = ['physical', 'supervised', 'unsupervised'] + def __init__(self): + self._condition_type = None - @abstractmethod - def residual(self, model): - """ - Compute the residual of the condition. - - :param model: The model to evaluate the condition. - :return: The residual of the condition. - """ - pass - - def set_problem(self, problem): - self._problem = problem + @property + def condition_type(self): + return self._condition_type + + @condition_type.setattr + def condition_type(self, values): + if not isinstance(values, (list, tuple)): + values = [values] + for value in values: + if value not in ConditionInterface.condition_types: + raise ValueError( + 'Unavailable type of condition, expected one of' + f' {ConditionInterface.condition_types}.' + ) + self._condition_type = values \ No newline at end of file diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py new file mode 100644 index 0000000..259eb56 --- /dev/null +++ b/pina/condition/data_condition.py @@ -0,0 +1,44 @@ +import torch + +from . import ConditionInterface +from ..label_tensor import LabelTensor +from ..graph import Graph +from ..utils import check_consistency + +class DataConditionInterface(ConditionInterface): + """ + Condition for data. This condition must be used every + time a Unsupervised Loss is needed in the Solver. The conditionalvariable + can be passed as extra-input when the model learns a conditional + distribution + """ + + __slots__ = ["data", "conditionalvariable"] + + def __init__(self, data, conditionalvariable=None): + """ + TODO + """ + super().__init__() + self.data = data + self.conditionalvariable = conditionalvariable + self.condition_type = 'unsupervised' + + @property + def data(self): + return self._data + + @data.setter + def data(self, value): + check_consistency(value, (LabelTensor, Graph, torch.Tensor)) + self._data = value + + @property + def conditionalvariable(self): + return self._conditionalvariable + + @data.setter + def conditionalvariable(self, value): + if value is not None: + check_consistency(value, (LabelTensor, Graph, torch.Tensor)) + self._data = value \ No newline at end of file diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 15df3f8..9838ad7 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -1,34 +1,43 @@ +import torch + from .condition_interface import ConditionInterface +from ..label_tensor import LabelTensor +from ..graph import Graph +from ..utils import check_consistency +from ..domain import DomainInterface +from ..equation.equation_interface import EquationInterface class DomainEquationCondition(ConditionInterface): """ - Condition for input/output data. + Condition for domain/equation data. This condition must be used every + time a Physics Informed Loss is needed in the Solver. """ __slots__ = ["domain", "equation"] def __init__(self, domain, equation): """ - Constructor for the `InputOutputCondition` class. + TODO """ super().__init__() self.domain = domain self.equation = equation + self.condition_type = 'physics' - def residual(self, model): - """ - Compute the residual of the condition. - """ - self.batch_residual(model, self.domain, self.equation) + @property + def domain(self): + return self._domain + + @domain.setter + def domain(self, value): + check_consistency(value, (DomainInterface)) + self._domain = value - @staticmethod - def batch_residual(model, input_pts, equation): - """ - Compute the residual of the condition for a single batch. Input and - output points are provided as arguments. - - :param torch.nn.Module model: The model to evaluate the condition. - :param torch.Tensor input_pts: The input points. - :param torch.Tensor equation: The output points. - """ - return equation.residual(input_pts, model(input_pts)) \ No newline at end of file + @property + def equation(self): + return self._equation + + @equation.setter + def equation(self, value): + check_consistency(value, (EquationInterface)) + self._equation = value \ No newline at end of file diff --git a/pina/condition/domain_output_condition.py b/pina/condition/domain_output_condition.py deleted file mode 100644 index f847720..0000000 --- a/pina/condition/domain_output_condition.py +++ /dev/null @@ -1,44 +0,0 @@ - -from . import ConditionInterface - -class DomainOutputCondition(ConditionInterface): - """ - Condition for input/output data. - """ - - __slots__ = ["domain", "output_points"] - - def __init__(self, domain, output_points): - """ - Constructor for the `InputOutputCondition` class. - """ - super().__init__() - print(self) - self.domain = domain - self.output_points = output_points - - @property - def input_points(self): - """ - Get the input points of the condition. - """ - return self._problem.domains[self.domain] - - def residual(self, model): - """ - Compute the residual of the condition. - """ - return self.batch_residual(model, self.domain, self.output_points) - - @staticmethod - def batch_residual(model, input_points, output_points): - """ - Compute the residual of the condition for a single batch. Input and - output points are provided as arguments. - - :param torch.nn.Module model: The model to evaluate the condition. - :param torch.Tensor input_points: The input points. - :param torch.Tensor output_points: The output points. - """ - - return output_points - model(input_points) \ No newline at end of file diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 288022c..c4b9f8d 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,23 +1,42 @@ +import torch -from . import ConditionInterface +from .condition_interface import ConditionInterface +from ..label_tensor import LabelTensor +from ..graph import Graph +from ..utils import check_consistency +from ..equation.equation_interface import EquationInterface -class InputEquationCondition(ConditionInterface): +class InputPointsEquationCondition(ConditionInterface): """ - Condition for input/output data. + Condition for input_points/equation data. This condition must be used every + time a Physics Informed Loss is needed in the Solver. """ - __slots__ = ["input_points", "output_points"] + __slots__ = ["input_points", "equation"] - def __init__(self, input_points, output_points): + def __init__(self, input_points, equation): """ - Constructor for the `InputOutputCondition` class. + TODO """ super().__init__() self.input_points = input_points - self.output_points = output_points + self.equation = equation + self.condition_type = 'physics' - def residual(self, model): - """ - Compute the residual of the condition. - """ - return self.output_points - model(self.input_points) \ No newline at end of file + @property + def input_points(self): + return self._input_points + + @input_points.setter + def input_points(self, value): + check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators! + self._input_points = value + + @property + def equation(self): + return self._equation + + @equation.setter + def equation(self, value): + check_consistency(value, (EquationInterface)) + self._equation = value \ No newline at end of file diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py new file mode 100644 index 0000000..fd6b7a0 --- /dev/null +++ b/pina/condition/input_output_condition.py @@ -0,0 +1,42 @@ + +import torch + +from .condition_interface import ConditionInterface +from ..label_tensor import LabelTensor +from ..graph import Graph +from ..utils import check_consistency + +class InputOutputPointsCondition(ConditionInterface): + """ + Condition for domain/equation data. This condition must be used every + time a Physics Informed or a Supervised Loss is needed in the Solver. + """ + + __slots__ = ["input_points", "output_points"] + + def __init__(self, input_points, output_points): + """ + TODO + """ + super().__init__() + self.input_points = input_points + self.output_points = output_points + self.condition_type = ['supervised', 'physics'] + + @property + def input_points(self): + return self._input_points + + @input_points.setter + def input_points(self, value): + check_consistency(value, (LabelTensor, Graph, torch.Tensor)) + self._input_points = value + + @property + def output_points(self): + return self._output_points + + @output_points.setter + def output_points(self, value): + check_consistency(value, (LabelTensor, Graph, torch.Tensor)) + self._output_points = value \ No newline at end of file