This commit is contained in:
Your Name
2023-04-18 10:49:57 +02:00
parent da33aeae3a
commit 736c78fd64
17 changed files with 292 additions and 172 deletions

View File

@@ -1,43 +1,86 @@
""" """
import torch
""" Condition module. """
from .label_tensor import LabelTensor
from .location import Location
def dummy(a):
"""Dummy function for testing purposes."""
return None
class Condition:
"""
The class `Condition` is used to represent the constraints (physical
equations, boundary conditions, etc.) that should be satisfied in the
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.Abstract_Problem` object.
Conditions can be specified in three ways:
1. By specifying the input and output points of the condition; in such a
case, the model is trained to produce the output points given the input
points.
2. By specifying the location and the function of the condition; in such
a case, the model is trained to minimize that function by evaluating it
at some samples of the location.
3. By specifying the input points and the function of the condition; in
such a case, the model is trained to minimize that function by
evaluating it at the input points.
Example::
>>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
>>> def example_dirichlet(input_, output_):
>>> value = 0.0
>>> return output_.extract(['u']) - value
>>> example_input_pts = LabelTensor(
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
>>>
>>> Condition(
>>> input_points=example_input_pts,
>>> output_points=example_output_pts)
>>> Condition(
>>> location=example_domain,
>>> function=example_dirichlet)
>>> Condition(
>>> input_points=example_input_pts,
>>> function=example_dirichlet)
"""
__slots__ = [
'input_points', 'output_points', 'location', 'function',
'data_weight'
]
def _dictvalue_isinstance(self, dict_, key_, class_):
"""Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
if key_ not in dict_.keys():
return True
return isinstance(dict_[key_], class_)
def __init__(self, *args, **kwargs):
"""
Constructor for the `Condition` class.
"""
self.data_weight = kwargs.pop('data_weight', 1.0)
if 'data_weight' in kwargs:
self.data_weight = kwargs['data_weight']
if not 'data_weight' in kwargs:
self.data_weight = 1.
if len(args) != 0:
raise ValueError('Condition takes only the following keyword arguments: {`input_points`, `output_points`, `location`, `function`, `data_weight`}.')
if len(args) == 2:
if (
sorted(kwargs.keys()) != sorted(['input_points', 'output_points']) and
sorted(kwargs.keys()) != sorted(['location', 'function']) and
sorted(kwargs.keys()) != sorted(['input_points', 'function'])
):
raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.')
if (isinstance(args[0], torch.Tensor) and
isinstance(args[1], torch.Tensor)):
self.input_points = args[0]
self.output_points = args[1]
elif isinstance(args[0], Location) and callable(args[1]):
self.location = args[0]
self.function = args[1]
elif isinstance(args[0], Location) and isinstance(args[1], list):
self.location = args[0]
self.function = args[1]
else:
raise ValueError
if not self._dictvalue_isinstance(kwargs, 'input_points', LabelTensor):
raise TypeError('`input_points` must be a torch.Tensor.')
if not self._dictvalue_isinstance(kwargs, 'output_points', LabelTensor):
raise TypeError('`output_points` must be a torch.Tensor.')
if not self._dictvalue_isinstance(kwargs, 'location', Location):
raise TypeError('`location` must be a Location.')
elif not args and len(kwargs) >= 2:
if 'input_points' in kwargs and 'output_points' in kwargs:
self.input_points = kwargs['input_points']
self.output_points = kwargs['output_points']
elif 'location' in kwargs and 'function' in kwargs:
self.location = kwargs['location']
self.function = kwargs['function']
else:
raise ValueError
else:
raise ValueError
if hasattr(self, 'function') and not isinstance(self.function, list):
self.function = [self.function]
for key, value in kwargs.items():
setattr(self, key, value)