equation class, fix minor bugs, diff domain (#89)

* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix
This commit is contained in:
Nicola Demo
2023-05-15 16:06:01 +02:00
parent be11110bb2
commit 0e3625de80
25 changed files with 691 additions and 246 deletions

View File

@@ -1,6 +1,7 @@
""" Condition module. """
from .label_tensor import LabelTensor
from .location import Location
from .geometry import Location
from .equation.equation import Equation
def dummy(a):
"""Dummy function for testing purposes."""
@@ -17,13 +18,13 @@ class Condition:
case, the model is trained to produce the output points given the input
points.
2. By specifying the location and the function of the condition; in such
a case, the model is trained to minimize that function by evaluating it
at some samples of the location.
2. By specifying the location and the equation of the condition; in such
a case, the model is trained to minimize the equation residual by
evaluating it at some samples of the location.
3. By specifying the input points and the function of the condition; in
such a case, the model is trained to minimize that function by
evaluating it at the input points.
3. By specifying the input points and the equation of the condition; in
such a case, the model is trained to minimize the equation residual by
evaluating it at the passed input points.
Example::
@@ -40,15 +41,15 @@ class Condition:
>>> output_points=example_output_pts)
>>> Condition(
>>> location=example_domain,
>>> function=example_dirichlet)
>>> equation=example_dirichlet)
>>> Condition(
>>> input_points=example_input_pts,
>>> function=example_dirichlet)
>>> equation=example_dirichlet)
"""
__slots__ = [
'input_points', 'output_points', 'location', 'function',
'input_points', 'output_points', 'location', 'equation',
'data_weight'
]
@@ -70,8 +71,8 @@ class Condition:
if (
sorted(kwargs.keys()) != sorted(['input_points', 'output_points']) and
sorted(kwargs.keys()) != sorted(['location', 'function']) and
sorted(kwargs.keys()) != sorted(['input_points', 'function'])
sorted(kwargs.keys()) != sorted(['location', 'equation']) and
sorted(kwargs.keys()) != sorted(['input_points', 'equation'])
):
raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.')
@@ -81,16 +82,8 @@ class Condition:
raise TypeError('`output_points` must be a torch.Tensor.')
if not self._dictvalue_isinstance(kwargs, 'location', Location):
raise TypeError('`location` must be a Location.')
if 'function' in kwargs:
if not isinstance(kwargs['function'], list):
kwargs['function'] = [kwargs['function']]
for i, func in enumerate(kwargs['function']):
if not callable(func):
raise TypeError(
f'`function[{i}]` must be a callable function.')
if not self._dictvalue_isinstance(kwargs, 'equation', Equation):
raise TypeError('`equation` must be a Equation.')
for key, value in kwargs.items():
setattr(self, key, value)