add docs
This commit is contained in:
@@ -1,43 +1,86 @@
|
||||
""" """
|
||||
import torch
|
||||
|
||||
""" Condition module. """
|
||||
from .label_tensor import LabelTensor
|
||||
from .location import Location
|
||||
|
||||
def dummy(a):
|
||||
"""Dummy function for testing purposes."""
|
||||
return None
|
||||
|
||||
class Condition:
|
||||
"""
|
||||
The class `Condition` is used to represent the constraints (physical
|
||||
equations, boundary conditions, etc.) that should be satisfied in the
|
||||
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.Abstract_Problem` object.
|
||||
Conditions can be specified in three ways:
|
||||
|
||||
1. By specifying the input and output points of the condition; in such a
|
||||
case, the model is trained to produce the output points given the input
|
||||
points.
|
||||
|
||||
2. By specifying the location and the function of the condition; in such
|
||||
a case, the model is trained to minimize that function by evaluating it
|
||||
at some samples of the location.
|
||||
|
||||
3. By specifying the input points and the function of the condition; in
|
||||
such a case, the model is trained to minimize that function by
|
||||
evaluating it at the input points.
|
||||
|
||||
Example::
|
||||
|
||||
>>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
|
||||
>>> def example_dirichlet(input_, output_):
|
||||
>>> value = 0.0
|
||||
>>> return output_.extract(['u']) - value
|
||||
>>> example_input_pts = LabelTensor(
|
||||
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
|
||||
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
|
||||
>>>
|
||||
>>> Condition(
|
||||
>>> input_points=example_input_pts,
|
||||
>>> output_points=example_output_pts)
|
||||
>>> Condition(
|
||||
>>> location=example_domain,
|
||||
>>> function=example_dirichlet)
|
||||
>>> Condition(
|
||||
>>> input_points=example_input_pts,
|
||||
>>> function=example_dirichlet)
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
'input_points', 'output_points', 'location', 'function',
|
||||
'data_weight'
|
||||
]
|
||||
|
||||
def _dictvalue_isinstance(self, dict_, key_, class_):
|
||||
"""Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
|
||||
if key_ not in dict_.keys():
|
||||
return True
|
||||
|
||||
return isinstance(dict_[key_], class_)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Constructor for the `Condition` class.
|
||||
"""
|
||||
self.data_weight = kwargs.pop('data_weight', 1.0)
|
||||
|
||||
if 'data_weight' in kwargs:
|
||||
self.data_weight = kwargs['data_weight']
|
||||
if not 'data_weight' in kwargs:
|
||||
self.data_weight = 1.
|
||||
if len(args) != 0:
|
||||
raise ValueError('Condition takes only the following keyword arguments: {`input_points`, `output_points`, `location`, `function`, `data_weight`}.')
|
||||
|
||||
if len(args) == 2:
|
||||
if (
|
||||
sorted(kwargs.keys()) != sorted(['input_points', 'output_points']) and
|
||||
sorted(kwargs.keys()) != sorted(['location', 'function']) and
|
||||
sorted(kwargs.keys()) != sorted(['input_points', 'function'])
|
||||
):
|
||||
raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.')
|
||||
|
||||
if (isinstance(args[0], torch.Tensor) and
|
||||
isinstance(args[1], torch.Tensor)):
|
||||
self.input_points = args[0]
|
||||
self.output_points = args[1]
|
||||
elif isinstance(args[0], Location) and callable(args[1]):
|
||||
self.location = args[0]
|
||||
self.function = args[1]
|
||||
elif isinstance(args[0], Location) and isinstance(args[1], list):
|
||||
self.location = args[0]
|
||||
self.function = args[1]
|
||||
else:
|
||||
raise ValueError
|
||||
if not self._dictvalue_isinstance(kwargs, 'input_points', LabelTensor):
|
||||
raise TypeError('`input_points` must be a torch.Tensor.')
|
||||
if not self._dictvalue_isinstance(kwargs, 'output_points', LabelTensor):
|
||||
raise TypeError('`output_points` must be a torch.Tensor.')
|
||||
if not self._dictvalue_isinstance(kwargs, 'location', Location):
|
||||
raise TypeError('`location` must be a Location.')
|
||||
|
||||
elif not args and len(kwargs) >= 2:
|
||||
|
||||
if 'input_points' in kwargs and 'output_points' in kwargs:
|
||||
self.input_points = kwargs['input_points']
|
||||
self.output_points = kwargs['output_points']
|
||||
elif 'location' in kwargs and 'function' in kwargs:
|
||||
self.location = kwargs['location']
|
||||
self.function = kwargs['function']
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
if hasattr(self, 'function') and not isinstance(self.function, list):
|
||||
self.function = [self.function]
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
Reference in New Issue
Block a user