equation class, fix minor bugs, diff domain (#89)
* equation class * difference domain * dummy dataloader * writer class * refactoring and minor fix
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
""" Condition module. """
|
||||
from .label_tensor import LabelTensor
|
||||
from .location import Location
|
||||
from .geometry import Location
|
||||
from .equation.equation import Equation
|
||||
|
||||
def dummy(a):
|
||||
"""Dummy function for testing purposes."""
|
||||
@@ -17,13 +18,13 @@ class Condition:
|
||||
case, the model is trained to produce the output points given the input
|
||||
points.
|
||||
|
||||
2. By specifying the location and the function of the condition; in such
|
||||
a case, the model is trained to minimize that function by evaluating it
|
||||
at some samples of the location.
|
||||
2. By specifying the location and the equation of the condition; in such
|
||||
a case, the model is trained to minimize the equation residual by
|
||||
evaluating it at some samples of the location.
|
||||
|
||||
3. By specifying the input points and the function of the condition; in
|
||||
such a case, the model is trained to minimize that function by
|
||||
evaluating it at the input points.
|
||||
3. By specifying the input points and the equation of the condition; in
|
||||
such a case, the model is trained to minimize the equation residual by
|
||||
evaluating it at the passed input points.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -40,15 +41,15 @@ class Condition:
|
||||
>>> output_points=example_output_pts)
|
||||
>>> Condition(
|
||||
>>> location=example_domain,
|
||||
>>> function=example_dirichlet)
|
||||
>>> equation=example_dirichlet)
|
||||
>>> Condition(
|
||||
>>> input_points=example_input_pts,
|
||||
>>> function=example_dirichlet)
|
||||
>>> equation=example_dirichlet)
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
'input_points', 'output_points', 'location', 'function',
|
||||
'input_points', 'output_points', 'location', 'equation',
|
||||
'data_weight'
|
||||
]
|
||||
|
||||
@@ -70,8 +71,8 @@ class Condition:
|
||||
|
||||
if (
|
||||
sorted(kwargs.keys()) != sorted(['input_points', 'output_points']) and
|
||||
sorted(kwargs.keys()) != sorted(['location', 'function']) and
|
||||
sorted(kwargs.keys()) != sorted(['input_points', 'function'])
|
||||
sorted(kwargs.keys()) != sorted(['location', 'equation']) and
|
||||
sorted(kwargs.keys()) != sorted(['input_points', 'equation'])
|
||||
):
|
||||
raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.')
|
||||
|
||||
@@ -81,16 +82,8 @@ class Condition:
|
||||
raise TypeError('`output_points` must be a torch.Tensor.')
|
||||
if not self._dictvalue_isinstance(kwargs, 'location', Location):
|
||||
raise TypeError('`location` must be a Location.')
|
||||
|
||||
if 'function' in kwargs:
|
||||
if not isinstance(kwargs['function'], list):
|
||||
kwargs['function'] = [kwargs['function']]
|
||||
|
||||
|
||||
for i, func in enumerate(kwargs['function']):
|
||||
if not callable(func):
|
||||
raise TypeError(
|
||||
f'`function[{i}]` must be a callable function.')
|
||||
if not self._dictvalue_isinstance(kwargs, 'equation', Equation):
|
||||
raise TypeError('`equation` must be a Equation.')
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
Reference in New Issue
Block a user