new conditions

This commit is contained in:
Dario Coscia
2024-10-03 21:33:37 +02:00
committed by Nicola Demo
parent a888141707
commit fd16fcf9b4
8 changed files with 210 additions and 171 deletions

View File

@@ -1,27 +1,21 @@
""" Condition module. """
from ..label_tensor import LabelTensor
from ..domain import DomainInterface
from ..equation.equation import Equation
from . import DomainOutputCondition, DomainEquationCondition
def dummy(a):
"""Dummy function for testing purposes."""
return None
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface
class Condition:
"""
The class ``Condition`` is used to represent the constraints (physical
equations, boundary conditions, etc.) that should be satisfied in the
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
Conditions can be specified in three ways:
problem at hand. Condition objects are used to formulate the
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
Conditions can be specified in four ways:
1. By specifying the input and output points of the condition; in such a
case, the model is trained to produce the output points given the input
points.
points. Those points can either be torch.Tensor, LabelTensors, Graph
2. By specifying the location and the equation of the condition; in such
a case, the model is trained to minimize the equation residual by
@@ -29,79 +23,48 @@ class Condition:
3. By specifying the input points and the equation of the condition; in
such a case, the model is trained to minimize the equation residual by
evaluating it at the passed input points.
evaluating it at the passed input points. The input points must be
a LabelTensor.
4. By specifying only the data matrix; in such a case the model is
trained with an unsupervised costum loss and uses the data in training.
Additionaly conditioning variables can be passed, whenever the model
has extra conditioning variable it depends on.
Example::
>>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
>>> def example_dirichlet(input_, output_):
>>> value = 0.0
>>> return output_.extract(['u']) - value
>>> example_input_pts = LabelTensor(
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
>>>
>>> Condition(
>>> input_points=example_input_pts,
>>> output_points=example_output_pts)
>>> Condition(
>>> location=example_domain,
>>> equation=example_dirichlet)
>>> Condition(
>>> input_points=example_input_pts,
>>> equation=example_dirichlet)
>>> TODO
"""
# def _dictvalue_isinstance(self, dict_, key_, class_):
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
# if key_ not in dict_.keys():
# return True
__slots__ = list(
set(
InputOutputPointsCondition.__slots__,
InputPointsEquationCondition.__slots__,
DomainEquationCondition.__slots__,
DataConditionInterface.__slots__
# return isinstance(dict_[key_], class_)
# def __init__(self, *args, **kwargs):
# """
# Constructor for the `Condition` class.
# """
# self.data_weight = kwargs.pop("data_weight", 1.0)
# if len(args) != 0:
# raise ValueError(
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
# )
)
)
def __new__(cls, *args, **kwargs):
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
return DomainOutputCondition(
domain=kwargs["input_points"],
output_points=kwargs["output_points"]
if len(args) != 0:
raise ValueError(
f"Condition takes only the following keyword '
'arguments: {Condition.__slots__}."
)
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
return DomainOutputCondition(**kwargs)
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]):
sorted_keys = sorted(kwargs.keys())
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
return InputOutputPointsCondition(**kwargs)
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
return InputPointsEquationCondition(**kwargs)
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
return DomainEquationCondition(**kwargs)
elif sorted_keys == sorted(DataConditionInterface.__slots__):
return DataConditionInterface(**kwargs)
elif sorted_keys == DataConditionInterface.__slots__[0]:
return DataConditionInterface(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
'''
if (
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
and sorted(kwargs.keys()) != sorted(["location", "equation"])
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
):
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
raise TypeError("`input_points` must be a torch.Tensor.")
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
raise TypeError("`output_points` must be a torch.Tensor.")
if not self._dictvalue_isinstance(kwargs, "location", Location):
raise TypeError("`location` must be a Location.")
if not self._dictvalue_isinstance(kwargs, "equation", Equation):
raise TypeError("`equation` must be a Equation.")
for key, value in kwargs.items():
setattr(self, key, value)
'''
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")