add docs
@@ -3,7 +3,7 @@

# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SPHINXBUILD = python -msphinx #sphinx-build
PAPER =
BUILDDIR = build
@@ -15,4 +15,6 @@ Code Documentation
    SpatialProblem <spatialproblem.rst>
    TimeDependentProblem <timedepproblem.rst>
    Operators <operators.rst>
    Plotter <plotter.rst>
    Plotter <plotter.rst>
    PINN <pinn.rst>
    Condition <condition.rst>

docs/source/_rst/condition.rst | 12 (new file)
@@ -0,0 +1,12 @@
Condition
=========
.. currentmodule:: pina.condition

.. automodule:: pina.condition

.. autoclass:: Condition
    :members:
    :private-members:
    :undoc-members:
    :show-inheritance:
    :noindex:
@@ -5,8 +5,6 @@ DeepONet
.. automodule:: pina.model.deeponet

.. autoclass:: DeepONet
    :members:
    :private-members:
    :undoc-members:
    :show-inheritance:
    :noindex:
    :members:
    :private-members:
    :show-inheritance:
@@ -5,8 +5,6 @@ FeedForward
.. automodule:: pina.model.feed_forward

.. autoclass:: FeedForward
    :members:
    :private-members:
    :undoc-members:
    :show-inheritance:
    :noindex:
    :members:
    :private-members:
    :show-inheritance:
@@ -5,8 +5,6 @@ LabelTensor
.. automodule:: pina.label_tensor

.. autoclass:: LabelTensor
    :members:
    :private-members:
    :undoc-members:
    :show-inheritance:
    :noindex:
    :members:
    :private-members:
    :show-inheritance:
@@ -5,8 +5,6 @@ PINN
.. automodule:: pina.pinn

.. autoclass:: PINN
    :members:
    :private-members:
    :undoc-members:
    :show-inheritance:
    :noindex:
    :members:
    :private-members:
    :show-inheritance:
@@ -36,28 +36,28 @@ import pina
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.coverage',
    'sphinx.ext.graphviz',
    'sphinx.ext.doctest',
    #'sphinx.ext.autosummary',
    #'sphinx.ext.coverage',
    #'sphinx.ext.graphviz',
    #'sphinx.ext.doctest',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    #'sphinx.ext.coverage',
    'sphinx.ext.viewcode',
    'sphinx.ext.ifconfig',
    #'sphinx.ext.ifconfig',
    'sphinx.ext.mathjax',
]
autosummary_generate = True
#autosummary_generate = True

intersphinx_mapping = {
    'python': ('http://docs.python.org/2', None),
    'python': ('http://docs.python.org/3', None),
    'numpy': ('http://docs.scipy.org/doc/numpy/', None),
    'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
    'matplotlib': ('http://matplotlib.sourceforge.net/', None),
    'torch': ('https://pytorch.org/docs/stable/', None),
    'pina': ('https://mathlab.github.io/PINA/', None)
}

nitpicky = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
@@ -14,11 +14,11 @@ learning tasks while respecting any given law of physics described by general
nonlinear differential equations. Proposed in "Physics-informed neural
networks: A deep learning framework for solving forward and inverse problems
involving nonlinear partial differential equations", such framework aims to
solve problems in a continuous and nonlinear settings.
solve problems in a continuous and nonlinear settings. :py:class:`pina.pinn.PINN`


.. toctree::
    :maxdepth: 1
    :maxdepth: 2
    :caption: Package Documentation:

    Installation <_rst/installation>
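For orientation, here is a minimal end-to-end sketch (not part of this commit) of how the `PINN` solver described above is driven. It is assembled from the Poisson example further down in this diff; the `FeedForward` keyword arguments and the `layers=[20, 20]` value are illustrative assumptions.

```python
import torch
from pina import PINN, Plotter
from pina.model import FeedForward
from problems.poisson import Poisson  # problem class from the examples folder

problem = Poisson()
model = FeedForward(
    input_variables=problem.input_variables,    # e.g. ['x', 'y']
    output_variables=problem.output_variables,  # e.g. ['u']
    layers=[20, 20],                            # assumed hidden layout
    func=torch.nn.Softplus,
)
pinn = PINN(problem, model, lr=0.0001)

# sample boundary and interior collocation points, then train
pinn.span_pts(20, "grid", locations=["gamma1", "gamma2", "gamma3", "gamma4"])
pinn.span_pts(20, "grid", locations=["D"])
pinn.train(1.0e-10, 100)

Plotter().plot(pinn)
```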
@@ -5,11 +5,7 @@ import torch
from problems.poisson import Poisson
from pina import PINN, LabelTensor, Plotter

from pina.model.deeponet import DeepONet, check_combos, spawn_combo_networks

logging.basicConfig(
    filename="poisson_deeponet.log", filemode="w", level=logging.INFO
)
from pina.model import DeepONet, FeedForward


class SinFeature(torch.nn.Module):
@@ -36,27 +32,33 @@ class SinFeature(torch.nn.Module):
        return LabelTensor(t, [f"sin({self._label})"])


def prepare_deeponet_model(args, problem, extra_feature_combo_func=None):
    combos = tuple(map(lambda combo: combo.split("-"), args.combos.split(",")))
    check_combos(combos, problem.input_variables)
class myRBF(torch.nn.Module):
    def __init__(self, input_):

    extra_feature = extra_feature_combo_func if args.extra else None
    networks = spawn_combo_networks(
        combos=combos,
        layers=list(map(int, args.layers.split(","))) if args.layers else [],
        output_dimension=args.hidden * len(problem.output_variables),
        func=torch.nn.Softplus,
        extra_feature=extra_feature,
        bias=not args.nobias,
    )
        super().__init__()

    return DeepONet(
        networks,
        problem.output_variables,
        aggregator=args.aggregator,
        reduction=args.reduction,
    )
        self.input_variables = [input_]
        self.a = torch.nn.Parameter(torch.tensor([-.3]))
        # self.b = torch.nn.Parameter(torch.tensor([0.5]))
        self.b = torch.tensor([0.5])
        self.c = torch.nn.Parameter(torch.tensor([.5]))

    def forward(self, x):
        x = x.extract(self.input_variables)
        result = self.a * torch.exp(-(x - self.b)**2/(self.c**2))
        return result

class myModel(torch.nn.Module):
    def __init__(self):

        super().__init__()
        self.ffn_x = myRBF('x')
        self.ffn_y = myRBF('y')

    def forward(self, x):
        result = self.ffn_x(x) * self.ffn_y(x)
        result.labels = ['u']
        return result

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run PINA")
@@ -65,49 +67,49 @@ if __name__ == "__main__":
    parser.add_argument("id_run", help="Run ID", type=int)

    parser.add_argument("--extra", help="Extra features", action="store_true")
    parser.add_argument("--nobias", action="store_true")
    parser.add_argument(
        "--combos",
        help="DeepONet internal network combinations",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--aggregator", help="Aggregator for DeepONet", type=str, default="*"
    )
    parser.add_argument(
        "--reduction", help="Reduction for DeepONet", type=str, default="+"
    )
    parser.add_argument(
        "--hidden",
        help="Number of variables in the hidden DeepONet layer",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--layers",
        help="Structure of the DeepONet partial layers",
        type=str,
        required=True,
    )
    cli_args = parser.parse_args()
    args = parser.parse_args()

    poisson_problem = Poisson()
    problem = Poisson()

    model = prepare_deeponet_model(
        cli_args,
        poisson_problem,
        extra_feature_combo_func=lambda combo: [SinFeature(combo)],
    )
    pinn = PINN(poisson_problem, model, lr=0.01, regularizer=1e-8)
    if cli_args.save:
    # ffn_x = FeedForward(
    #     input_variables=['x'], layers=[], output_variables=1,
    #     func=torch.nn.Softplus,
    #     extra_features=[SinFeature('x')]
    # )
    # ffn_y = FeedForward
    #     input_variables=['y'], layers=[], output_variables=1,
    #     func=torch.nn.Softplus,
    #     extra_features=[SinFeature('y')]
    # )
    model = myModel()
    test = torch.tensor([[0.0, 0.5]])
    test.labels = ['x', 'y']
    pinn = PINN(problem, model, lr=0.0001)

    if args.save:
        pinn.span_pts(
            20, "grid", locations=["gamma1", "gamma2", "gamma3", "gamma4"]
        )
        pinn.span_pts(20, "grid", locations=["D"])
        pinn.train(1.0e-10, 100)
        pinn.save_state(f"pina.poisson_{cli_args.id_run}")
    if cli_args.load:
        pinn.load_state(f"pina.poisson_{cli_args.id_run}")
        while True:
            pinn.train(500, 50)
            print(model.ffn_x.a)
            print(model.ffn_x.b)
            print(model.ffn_x.c)

            xi = torch.linspace(0, 1, 64).reshape(-1, 1).as_subclass(LabelTensor)
            xi.labels = ['x']
            yi = model.ffn_x(xi)
            y_truth = -torch.sin(xi*torch.pi)

            import matplotlib.pyplot as plt
            plt.plot(xi.detach().flatten(), yi.detach().flatten(), 'r-')
            plt.plot(xi.detach().flatten(), y_truth.detach().flatten(), 'b-')
            plt.plot(xi.detach().flatten(), -y_truth.detach().flatten(), 'b-')
            plt.show()
        pinn.save_state(f"pina.poisson_{args.id_run}")

    if args.load:
        pinn.load_state(f"pina.poisson_{args.id_run}")
        plotter = Plotter()
        plotter.plot(pinn)
@@ -1,43 +1,86 @@
""" """
import torch

""" Condition module. """
from .label_tensor import LabelTensor
from .location import Location

def dummy(a):
    """Dummy function for testing purposes."""
    return None

class Condition:
    """
    The class `Condition` is used to represent the constraints (physical
    equations, boundary conditions, etc.) that should be satisfied in the
    problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.Abstract_Problem` object.
    Conditions can be specified in three ways:

    1. By specifying the input and output points of the condition; in such a
       case, the model is trained to produce the output points given the input
       points.

    2. By specifying the location and the function of the condition; in such
       a case, the model is trained to minimize that function by evaluating it
       at some samples of the location.

    3. By specifying the input points and the function of the condition; in
       such a case, the model is trained to minimize that function by
       evaluating it at the input points.

    Example::

        >>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
        >>> def example_dirichlet(input_, output_):
        >>>     value = 0.0
        >>>     return output_.extract(['u']) - value
        >>> example_input_pts = LabelTensor(
        >>>     torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
        >>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
        >>>
        >>> Condition(
        >>>     input_points=example_input_pts,
        >>>     output_points=example_output_pts)
        >>> Condition(
        >>>     location=example_domain,
        >>>     function=example_dirichlet)
        >>> Condition(
        >>>     input_points=example_input_pts,
        >>>     function=example_dirichlet)

    """
    __slots__ = [
        'input_points', 'output_points', 'location', 'function',
        'data_weight'
    ]

    def _dictvalue_isinstance(self, dict_, key_, class_):
        """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
        if key_ not in dict_.keys():
            return True

        return isinstance(dict_[key_], class_)

    def __init__(self, *args, **kwargs):
        """
        Constructor for the `Condition` class.
        """
        self.data_weight = kwargs.pop('data_weight', 1.0)

        if 'data_weight' in kwargs:
            self.data_weight = kwargs['data_weight']
        if not 'data_weight' in kwargs:
            self.data_weight = 1.
        if len(args) != 0:
            raise ValueError('Condition takes only the following keyword arguments: {`input_points`, `output_points`, `location`, `function`, `data_weight`}.')

        if len(args) == 2:
        if (
            sorted(kwargs.keys()) != sorted(['input_points', 'output_points']) and
            sorted(kwargs.keys()) != sorted(['location', 'function']) and
            sorted(kwargs.keys()) != sorted(['input_points', 'function'])
        ):
            raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.')

            if (isinstance(args[0], torch.Tensor) and
                    isinstance(args[1], torch.Tensor)):
                self.input_points = args[0]
                self.output_points = args[1]
            elif isinstance(args[0], Location) and callable(args[1]):
                self.location = args[0]
                self.function = args[1]
            elif isinstance(args[0], Location) and isinstance(args[1], list):
                self.location = args[0]
                self.function = args[1]
            else:
                raise ValueError
        if not self._dictvalue_isinstance(kwargs, 'input_points', LabelTensor):
            raise TypeError('`input_points` must be a torch.Tensor.')
        if not self._dictvalue_isinstance(kwargs, 'output_points', LabelTensor):
            raise TypeError('`output_points` must be a torch.Tensor.')
        if not self._dictvalue_isinstance(kwargs, 'location', Location):
            raise TypeError('`location` must be a Location.')

        elif not args and len(kwargs) >= 2:

            if 'input_points' in kwargs and 'output_points' in kwargs:
                self.input_points = kwargs['input_points']
                self.output_points = kwargs['output_points']
            elif 'location' in kwargs and 'function' in kwargs:
                self.location = kwargs['location']
                self.function = kwargs['function']
            else:
                raise ValueError
        else:
            raise ValueError

        if hasattr(self, 'function') and not isinstance(self.function, list):
            self.function = [self.function]
        for key, value in kwargs.items():
            setattr(self, key, value)
@@ -5,10 +5,14 @@ from abc import ABCMeta, abstractmethod

class Location(metaclass=ABCMeta):
    """
    Abstract class
    Abstract Location class.
    Any geometry entity should inherit from this class.
    """

    @property
    @abstractmethod
    def sample(self):
        pass
        """
        Abstract method for sampling a point from the location. To be
        implemented in the child class.
        """
        pass
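Since the hunk above only shows the abstract interface, here is a hypothetical subclass (not part of this commit) illustrating the stated contract that any geometry entity inherits from `Location` and provides `sample`; the property form mirrors the decorators shown above, and the real concrete geometry in PINA is `Span`.

```python
import torch
from pina.location import Location


class UnitSquare(Location):
    """Toy geometry for illustration: the unit square [0, 1] x [0, 1]."""

    @property
    def sample(self):
        # draw one uniformly distributed point from the square
        return torch.rand(1, 2)
```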
@@ -12,7 +12,7 @@ from pina.utils import is_function

def check_combos(combos, variables):
    """
    Check that the given combinations are subsets (overlapping
    Check that the given combinations are subsets (overlapping
    is allowed) of the given set of variables.

    :param iterable(iterable(str)) combos: Combinations of variables.
@@ -35,7 +35,7 @@ def spawn_combo_networks(
    :param iterable(iterable(str)) combos: Combinations of variables.
    :param iterable(int) layers: Size of hidden layers.
    :param int output_dimension: Size of the output layer of the networks.
    :param func: Nonlinearity.
    :param func: Nonlinearity.
    :param extra_feature: Extra feature to be considered by the networks.
    :param bool bias: Whether to consider bias or not.
    """
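As a reminder of the combo format that `check_combos` validates, the sketch below (not part of this commit) builds it the same way the Poisson DeepONet example earlier in this diff parses its `--combos` option; how `check_combos` reports a violation is not shown in this hunk.

```python
from pina.model.deeponet import check_combos

variables = ["x", "y"]  # e.g. problem.input_variables
combos = tuple(c.split("-") for c in "x-y,y".split(","))
# combos == (['x', 'y'], ['y']); overlapping subsets are allowed
check_combos(combos, variables)             # valid: every combo is a subset
# check_combos([["x", "z"]], variables)     # 'z' is not a variable, so rejected
```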
@@ -78,15 +78,16 @@ class DeepONet(torch.nn.Module):
    :param list(str) output_variables: the list containing the labels
        corresponding to the components of the output computed by the
        model.
    :param string | callable aggregator: Aggregator to be used to aggregate
    :param str | callable aggregator: Aggregator to be used to aggregate
        partial results from the modules in `nets`. Partial results are
        aggregated component-wise. See :func:`_symbol_functions` for the
        aggregated component-wise. See
        :func:`pina.model.deeponet.DeepONet._symbol_functions` for the
        available default aggregators.
    :param string | callable reduction: Reduction to be used to reduce
    :param str | callable reduction: Reduction to be used to reduce
        the aggregated result of the modules in `nets` to the desired output
        dimension. See :func:`_symbol_functions` for the available default
        reductions.

        dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
        reductions.

    :Example:
    >>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
    >>> trunk = FFN(input_variables=['b'], output_variables=20)
@@ -127,9 +128,15 @@ class DeepONet(torch.nn.Module):
            raise ValueError("All networks should have the same output size")
        self._nets = torch.nn.ModuleList(nets)
        logging.info("Combo DeepONet children: %s", list(self.children()))
        self.scale = torch.nn.Parameter(torch.tensor([1.0]))
        self.trasl = torch.nn.Parameter(torch.tensor([0.0]))

    @staticmethod
    def _symbol_functions(**kwargs):
        """
        Return a dictionary of functions that can be used as aggregators or
        reductions.
        """
        return {
            "+": partial(torch.sum, **kwargs),
            "*": partial(torch.prod, **kwargs),
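The pattern behind `_symbol_functions` is a plain symbol-to-callable lookup; the standalone sketch below (not part of this commit) reproduces it with the two entries visible above, with `dim=-1` as an illustrative choice of the keyword arguments the class binds.

```python
from functools import partial

import torch

symbol_functions = {
    "+": partial(torch.sum, dim=-1),
    "*": partial(torch.prod, dim=-1),
}

partial_outputs = torch.ones(4, 3)              # stacked per-network outputs
print(symbol_functions["+"](partial_outputs))   # tensor([3., 3., 3., 3.])
print(symbol_functions["*"](partial_outputs))   # tensor([1., 1., 1., 1.])
```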
@@ -215,4 +222,7 @@ class DeepONet(torch.nn.Module):

        output_ = output_.as_subclass(LabelTensor)
        output_.labels = self.output_variables

        output_ *= self.scale
        output_ += self.trasl
        return output_
@@ -89,8 +89,8 @@ class FeedForward(torch.nn.Module):
        """
        Defines the computation performed at every call.

        :param x: the input tensor.
        :type x: :class:`pina.LabelTensor`
        :param x: .
        :type x: :class:`pina.LabelTensor`
        :return: the output computed by the model.
        :rtype: LabelTensor
        """
@@ -9,8 +9,9 @@ class MultiFeedForward(torch.nn.Module):
    :param dict dff_dict: dictionary of FeedForward networks.
    """
    def __init__(self, dff_dict):
        """
        """
        '''
        dff_dict: dict of FeedForward objects
        '''
        super().__init__()

        if not isinstance(dff_dict, dict):

pina/pinn.py | 68
@@ -88,10 +88,13 @@ class PINN(object):

    @property
    def problem(self):
        """ The problem formulation."""
        return self._problem

    @problem.setter
    def problem(self, problem):
        """
        Set the problem formulation."""
        if not isinstance(problem, AbstractProblem):
            raise TypeError
        self._problem = problem
@@ -99,11 +102,11 @@ class PINN(object):
    def _compute_norm(self, vec):
        """
        Compute the norm of the `vec` one-dimensional tensor based on the
        `self.error_norm` attribute.
        `self.error_norm` attribute.

        .. todo: complete

        :param vec torch.tensor: the tensor
        :param torch.Tensor vec: the tensor
        """
        if isinstance(self.error_norm, int):
            return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dytpe)
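For the integer branch shown above, the whole computation reduces to a single `torch.linalg.vector_norm` call; a tiny standalone illustration (not part of this commit):

```python
import torch

residuals = torch.tensor([3.0, 4.0])  # e.g. a flattened residual vector
error_norm = 2                        # the PINN error_norm setting
print(torch.linalg.vector_norm(residuals, ord=error_norm))  # tensor(5.)
```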
@@ -115,7 +118,11 @@ class PINN(object):
            raise RuntimeError

    def save_state(self, filename):
        """
        Save the state of the model.

        :param str filename: the filename to save the state to.
        """
        checkpoint = {
            'epoch': self.trained_epoch,
            'model_state': self.model.state_dict(),
@@ -133,6 +140,11 @@ class PINN(object):
        torch.save(checkpoint, filename)

    def load_state(self, filename):
        """
        Load the state of the model.

        :param str filename: the filename to load the state from.
        """

        checkpoint = torch.load(filename)
        self.model.load_state_dict(checkpoint['model_state'])
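Because `save_state` writes an ordinary torch checkpoint, the file can also be inspected directly; the sketch below (not part of this commit) assumes a checkpoint produced as in the Poisson example and uses only the keys visible in the hunks above.

```python
import torch

# e.g. created earlier with pinn.save_state("pina.poisson_0")
checkpoint = torch.load("pina.poisson_0")
print(checkpoint["epoch"])               # the trained_epoch counter
print(checkpoint["model_state"].keys())  # the model state_dict entries
# pinn.load_state("pina.poisson_0") feeds model_state back via load_state_dict
```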
@@ -298,32 +310,32 @@ class PINN(object):

        return sum(losses).item()

    def error(self, dtype='l2', res=100):
    # def error(self, dtype='l2', res=100):

        import numpy as np
        if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
            pts_container = []
            for mn, mx in self.problem.domain_bound:
                pts_container.append(np.linspace(mn, mx, res))
            grids_container = np.meshgrid(*pts_container)
            Z_true = self.problem.truth_solution(*grids_container)
    # import numpy as np
    # if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
    #     pts_container = []
    #     for mn, mx in self.problem.domain_bound:
    #         pts_container.append(np.linspace(mn, mx, res))
    #     grids_container = np.meshgrid(*pts_container)
    #     Z_true = self.problem.truth_solution(*grids_container)

        elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
            grids_container = self.problem.data_solution['grid']
            Z_true = self.problem.data_solution['grid_solution']
        try:
            unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
                dtype=self.dtype, device=self.device)
            Z_pred = self.model(unrolled_pts)
            Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
    # elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
    #     grids_container = self.problem.data_solution['grid']
    #     Z_true = self.problem.data_solution['grid_solution']
    # try:
    #     unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
    #         dtype=self.dtype, device=self.device)
    #     Z_pred = self.model(unrolled_pts)
    #     Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)

            if dtype == 'l2':
                return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true)
            else:
                # TODO H1
                pass
        except:
            print("")
            print("Something went wrong...")
            print(
                "Not able to compute the error. Please pass a data solution or a true solution")
    #     if dtype == 'l2':
    #         return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true)
    #     else:
    #         # TODO H1
    #         pass
    # except:
    #     print("")
    #     print("Something went wrong...")
    #     print(
    #         "Not able to compute the error. Please pass a data solution or a true solution")

tests/test_condition.py | 42 (new file)
@@ -0,0 +1,42 @@
import torch
import pytest

from pina import LabelTensor, Condition, Span, PINN
from pina.problem import SpatialProblem
from pina.model import FeedForward
from pina.operators import nabla


example_domain = Span({'x': [0, 1], 'y': [0, 1]})
def example_dirichlet(input_, output_):
    value = 0.0
    return output_.extract(['u']) - value
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])


def test_init_inputoutput():
    Condition(input_points=example_input_pts, output_points=example_output_pts)
    with pytest.raises(ValueError):
        Condition(example_input_pts, example_output_pts)
    with pytest.raises(TypeError):
        Condition(input_points=3., output_points='example')
    with pytest.raises(TypeError):
        Condition(input_points=example_domain, output_points=example_dirichlet)


def test_init_locfunc():
    Condition(location=example_domain, function=example_dirichlet)
    with pytest.raises(ValueError):
        Condition(example_domain, example_dirichlet)
    with pytest.raises(TypeError):
        Condition(location=3., function='example')
    with pytest.raises(TypeError):
        Condition(location=example_input_pts, function=example_output_pts)


def test_init_inputfunc():
    Condition(input_points=example_input_pts, function=example_dirichlet)
    with pytest.raises(ValueError):
        Condition(example_domain, example_dirichlet)
    with pytest.raises(TypeError):
        Condition(input_points=3., function='example')
    with pytest.raises(TypeError):
        Condition(input_points=example_domain, funtion=example_output_pts)