new conditions
This commit is contained in:
committed by
Nicola Demo
parent
a888141707
commit
fd16fcf9b4
@@ -1,27 +1,21 @@
|
||||
""" Condition module. """
|
||||
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..domain import DomainInterface
|
||||
from ..equation.equation import Equation
|
||||
|
||||
from . import DomainOutputCondition, DomainEquationCondition
|
||||
|
||||
|
||||
def dummy(a):
|
||||
"""Dummy function for testing purposes."""
|
||||
return None
|
||||
|
||||
from .domain_equation_condition import DomainEquationCondition
|
||||
from .input_equation_condition import InputPointsEquationCondition
|
||||
from .input_output_condition import InputOutputPointsCondition
|
||||
from .data_condition import DataConditionInterface
|
||||
|
||||
class Condition:
|
||||
"""
|
||||
The class ``Condition`` is used to represent the constraints (physical
|
||||
equations, boundary conditions, etc.) that should be satisfied in the
|
||||
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
|
||||
Conditions can be specified in three ways:
|
||||
problem at hand. Condition objects are used to formulate the
|
||||
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
|
||||
Conditions can be specified in four ways:
|
||||
|
||||
1. By specifying the input and output points of the condition; in such a
|
||||
case, the model is trained to produce the output points given the input
|
||||
points.
|
||||
points. Those points can either be torch.Tensor, LabelTensors, Graph
|
||||
|
||||
2. By specifying the location and the equation of the condition; in such
|
||||
a case, the model is trained to minimize the equation residual by
|
||||
@@ -29,79 +23,48 @@ class Condition:
|
||||
|
||||
3. By specifying the input points and the equation of the condition; in
|
||||
such a case, the model is trained to minimize the equation residual by
|
||||
evaluating it at the passed input points.
|
||||
evaluating it at the passed input points. The input points must be
|
||||
a LabelTensor.
|
||||
|
||||
4. By specifying only the data matrix; in such a case the model is
|
||||
trained with an unsupervised costum loss and uses the data in training.
|
||||
Additionaly conditioning variables can be passed, whenever the model
|
||||
has extra conditioning variable it depends on.
|
||||
|
||||
Example::
|
||||
|
||||
>>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
|
||||
>>> def example_dirichlet(input_, output_):
|
||||
>>> value = 0.0
|
||||
>>> return output_.extract(['u']) - value
|
||||
>>> example_input_pts = LabelTensor(
|
||||
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
|
||||
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
|
||||
>>>
|
||||
>>> Condition(
|
||||
>>> input_points=example_input_pts,
|
||||
>>> output_points=example_output_pts)
|
||||
>>> Condition(
|
||||
>>> location=example_domain,
|
||||
>>> equation=example_dirichlet)
|
||||
>>> Condition(
|
||||
>>> input_points=example_input_pts,
|
||||
>>> equation=example_dirichlet)
|
||||
>>> TODO
|
||||
|
||||
"""
|
||||
|
||||
# def _dictvalue_isinstance(self, dict_, key_, class_):
|
||||
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
|
||||
# if key_ not in dict_.keys():
|
||||
# return True
|
||||
__slots__ = list(
|
||||
set(
|
||||
InputOutputPointsCondition.__slots__,
|
||||
InputPointsEquationCondition.__slots__,
|
||||
DomainEquationCondition.__slots__,
|
||||
DataConditionInterface.__slots__
|
||||
|
||||
# return isinstance(dict_[key_], class_)
|
||||
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# """
|
||||
# Constructor for the `Condition` class.
|
||||
# """
|
||||
# self.data_weight = kwargs.pop("data_weight", 1.0)
|
||||
|
||||
# if len(args) != 0:
|
||||
# raise ValueError(
|
||||
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
|
||||
# )
|
||||
)
|
||||
)
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
|
||||
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
|
||||
return DomainOutputCondition(
|
||||
domain=kwargs["input_points"],
|
||||
output_points=kwargs["output_points"]
|
||||
|
||||
if len(args) != 0:
|
||||
raise ValueError(
|
||||
f"Condition takes only the following keyword '
|
||||
'arguments: {Condition.__slots__}."
|
||||
)
|
||||
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
|
||||
return DomainOutputCondition(**kwargs)
|
||||
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]):
|
||||
|
||||
sorted_keys = sorted(kwargs.keys())
|
||||
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
|
||||
return InputOutputPointsCondition(**kwargs)
|
||||
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
|
||||
return InputPointsEquationCondition(**kwargs)
|
||||
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
|
||||
return DomainEquationCondition(**kwargs)
|
||||
elif sorted_keys == sorted(DataConditionInterface.__slots__):
|
||||
return DataConditionInterface(**kwargs)
|
||||
elif sorted_keys == DataConditionInterface.__slots__[0]:
|
||||
return DataConditionInterface(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
# TODO: remove, not used anymore
|
||||
'''
|
||||
if (
|
||||
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
|
||||
and sorted(kwargs.keys()) != sorted(["location", "equation"])
|
||||
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
|
||||
):
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
# TODO: remove, not used anymore
|
||||
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
|
||||
raise TypeError("`input_points` must be a torch.Tensor.")
|
||||
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
|
||||
raise TypeError("`output_points` must be a torch.Tensor.")
|
||||
if not self._dictvalue_isinstance(kwargs, "location", Location):
|
||||
raise TypeError("`location` must be a Location.")
|
||||
if not self._dictvalue_isinstance(kwargs, "equation", Equation):
|
||||
raise TypeError("`equation` must be a Equation.")
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
'''
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
Reference in New Issue
Block a user