98 lines
3.5 KiB
Python
98 lines
3.5 KiB
Python
""" Condition module. """
|
|
|
|
from .label_tensor import LabelTensor
|
|
from .geometry import Location
|
|
from .equation.equation import Equation
|
|
|
|
|
|
def dummy(a):
    """Placeholder callable for tests: ignore ``a`` and return ``None``."""
|
|
|
|
|
|
class Condition:
    """
    Constraint (physical equation, boundary condition, data fit, etc.) that
    should be satisfied in the problem at hand. Condition objects are used to
    formulate the PINA
    :obj:`pina.problem.abstract_problem.AbstractProblem` object.

    A condition is built from keyword arguments only, using exactly one of
    the following three combinations:

    1. ``input_points`` and ``output_points``: the model is trained to
       produce the output points given the input points.

    2. ``location`` and ``equation``: the model is trained to minimize the
       equation residual by evaluating it at some samples of the location.

    3. ``input_points`` and ``equation``: the model is trained to minimize
       the equation residual by evaluating it at the passed input points.

    An optional ``data_weight`` keyword (default ``1.0``) may accompany any
    of the combinations and is stored on the instance.

    Type requirements enforced by the constructor: ``input_points`` and
    ``output_points`` must be ``LabelTensor`` instances, ``location`` must be
    a ``Location`` and ``equation`` must be an ``Equation`` (a bare callable
    is rejected with ``TypeError``).

    Only the attributes of the chosen combination (plus ``data_weight``) are
    assigned. The remaining slots stay unset, so reading them raises
    ``AttributeError``.

    Example::

        >>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
        >>> def example_dirichlet(input_, output_):
        ...     value = 0.0
        ...     return output_.extract(['u']) - value
        >>> # NOTE: assumes ``Equation`` wraps a residual callable.
        >>> example_equation = Equation(example_dirichlet)
        >>> example_input_pts = LabelTensor(
        ...     torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
        >>> example_output_pts = LabelTensor(
        ...     torch.tensor([[1, 2]]), ['a', 'b'])
        >>>
        >>> Condition(
        ...     input_points=example_input_pts,
        ...     output_points=example_output_pts)
        >>> Condition(
        ...     location=example_domain,
        ...     equation=example_equation)
        >>> Condition(
        ...     input_points=example_input_pts,
        ...     equation=example_equation)

    """

    __slots__ = [
        "input_points",
        "output_points",
        "location",
        "equation",
        "data_weight",
    ]

    def _dictvalue_isinstance(self, dict_, key_, class_):
        """
        Check whether ``dict_[key_]`` is an instance of ``class_``.

        :param dict dict_: mapping to inspect.
        :param key_: key whose value is checked.
        :param class_: class (or tuple of classes) passed to ``isinstance``.
        :return: ``True`` if ``key_`` is absent from ``dict_`` (nothing to
            check) or its value is an instance of ``class_``; ``False``
            otherwise.
        :rtype: bool
        """
        return key_ not in dict_ or isinstance(dict_[key_], class_)

    def __init__(self, *args, **kwargs):
        """
        Constructor for the `Condition` class.

        :param args: must be empty; positional arguments are not accepted.
        :param kwargs: exactly one of the key combinations
            ``{input_points, output_points}``, ``{location, equation}`` or
            ``{input_points, equation}``, optionally together with
            ``data_weight`` (default ``1.0``).
        :raises ValueError: if any positional argument is passed, or if the
            keyword arguments do not match one of the valid combinations.
        :raises TypeError: if a passed value is not of the required type.
        """
        # `data_weight` is optional, so take it out before checking that the
        # remaining keys form one of the valid combinations.
        self.data_weight = kwargs.pop("data_weight", 1.0)

        if args:
            raise ValueError(
                f"Condition takes only the following keyword arguments: {Condition.__slots__}."
            )

        valid_combinations = (
            {"input_points", "output_points"},
            {"location", "equation"},
            {"input_points", "equation"},
        )
        if set(kwargs) not in valid_combinations:
            raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

        # Checked in this order so the first offending argument is reported.
        # Keys absent from `kwargs` pass the check trivially.
        expected_types = (
            ("input_points", LabelTensor, "LabelTensor"),
            ("output_points", LabelTensor, "LabelTensor"),
            ("location", Location, "Location"),
            ("equation", Equation, "Equation"),
        )
        for key, class_, class_name in expected_types:
            if not self._dictvalue_isinstance(kwargs, key, class_):
                raise TypeError(f"`{key}` must be a {class_name}.")

        # Only the provided attributes are set; the other slots are left
        # unset on purpose, so `hasattr` distinguishes the condition kinds.
        for key, value in kwargs.items():
            setattr(self, key, value)
|