This commit is contained in:
Your Name
2023-04-18 10:49:57 +02:00
parent da33aeae3a
commit 736c78fd64
17 changed files with 292 additions and 172 deletions

View File

@@ -1,43 +1,86 @@
""" """
import torch
""" Condition module. """
from .label_tensor import LabelTensor
from .location import Location
def dummy(a):
"""Dummy function for testing purposes."""
return None
class Condition:
"""
The class `Condition` is used to represent the constraints (physical
equations, boundary conditions, etc.) that should be satisfied in the
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.Abstract_Problem` object.
Conditions can be specified in three ways:
1. By specifying the input and output points of the condition; in such a
case, the model is trained to produce the output points given the input
points.
2. By specifying the location and the function of the condition; in such
a case, the model is trained to minimize that function by evaluating it
at some samples of the location.
3. By specifying the input points and the function of the condition; in
such a case, the model is trained to minimize that function by
evaluating it at the input points.
Example::
>>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
>>> def example_dirichlet(input_, output_):
>>> value = 0.0
>>> return output_.extract(['u']) - value
>>> example_input_pts = LabelTensor(
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
>>>
>>> Condition(
>>> input_points=example_input_pts,
>>> output_points=example_output_pts)
>>> Condition(
>>> location=example_domain,
>>> function=example_dirichlet)
>>> Condition(
>>> input_points=example_input_pts,
>>> function=example_dirichlet)
"""
__slots__ = [
'input_points', 'output_points', 'location', 'function',
'data_weight'
]
def _dictvalue_isinstance(self, dict_, key_, class_):
"""Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
if key_ not in dict_.keys():
return True
return isinstance(dict_[key_], class_)
def __init__(self, *args, **kwargs):
"""
Constructor for the `Condition` class.
"""
self.data_weight = kwargs.pop('data_weight', 1.0)
if 'data_weight' in kwargs:
self.data_weight = kwargs['data_weight']
if not 'data_weight' in kwargs:
self.data_weight = 1.
if len(args) != 0:
raise ValueError('Condition takes only the following keyword arguments: {`input_points`, `output_points`, `location`, `function`, `data_weight`}.')
if len(args) == 2:
if (
sorted(kwargs.keys()) != sorted(['input_points', 'output_points']) and
sorted(kwargs.keys()) != sorted(['location', 'function']) and
sorted(kwargs.keys()) != sorted(['input_points', 'function'])
):
raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.')
if (isinstance(args[0], torch.Tensor) and
isinstance(args[1], torch.Tensor)):
self.input_points = args[0]
self.output_points = args[1]
elif isinstance(args[0], Location) and callable(args[1]):
self.location = args[0]
self.function = args[1]
elif isinstance(args[0], Location) and isinstance(args[1], list):
self.location = args[0]
self.function = args[1]
else:
raise ValueError
if not self._dictvalue_isinstance(kwargs, 'input_points', LabelTensor):
raise TypeError('`input_points` must be a torch.Tensor.')
if not self._dictvalue_isinstance(kwargs, 'output_points', LabelTensor):
raise TypeError('`output_points` must be a torch.Tensor.')
if not self._dictvalue_isinstance(kwargs, 'location', Location):
raise TypeError('`location` must be a Location.')
elif not args and len(kwargs) >= 2:
if 'input_points' in kwargs and 'output_points' in kwargs:
self.input_points = kwargs['input_points']
self.output_points = kwargs['output_points']
elif 'location' in kwargs and 'function' in kwargs:
self.location = kwargs['location']
self.function = kwargs['function']
else:
raise ValueError
else:
raise ValueError
if hasattr(self, 'function') and not isinstance(self.function, list):
self.function = [self.function]
for key, value in kwargs.items():
setattr(self, key, value)

View File

@@ -5,10 +5,14 @@ from abc import ABCMeta, abstractmethod
class Location(metaclass=ABCMeta):
"""
Abstract class
Abstract Location class.
Any geometry entity should inherit from this class.
"""
@property
@abstractmethod
def sample(self):
pass
"""
Abstract method for sampling a point from the location. To be
implemented in the child class.
"""
pass

View File

@@ -12,7 +12,7 @@ from pina.utils import is_function
def check_combos(combos, variables):
"""
Check that the given combinations are subsets (overlapping
Check that the given combinations are subsets (overlapping
is allowed) of the given set of variables.
:param iterable(iterable(str)) combos: Combinations of variables.
@@ -35,7 +35,7 @@ def spawn_combo_networks(
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(int) layers: Size of hidden layers.
:param int output_dimension: Size of the output layer of the networks.
:param func: Nonlinearity.
:param func: Nonlinearity.
:param extra_feature: Extra feature to be considered by the networks.
:param bool bias: Whether to consider bias or not.
"""
@@ -78,15 +78,16 @@ class DeepONet(torch.nn.Module):
:param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the
model.
:param string | callable aggregator: Aggregator to be used to aggregate
:param str | callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. See :func:`_symbol_functions` for the
aggregated component-wise. See
:func:`pina.model.deeponet.DeepONet._symbol_functions` for the
available default aggregators.
:param string | callable reduction: Reduction to be used to reduce
:param str | callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. See :func:`_symbol_functions` for the available default
reductions.
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
reductions.
:Example:
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
>>> trunk = FFN(input_variables=['b'], output_variables=20)
@@ -127,9 +128,15 @@ class DeepONet(torch.nn.Module):
raise ValueError("All networks should have the same output size")
self._nets = torch.nn.ModuleList(nets)
logging.info("Combo DeepONet children: %s", list(self.children()))
self.scale = torch.nn.Parameter(torch.tensor([1.0]))
self.trasl = torch.nn.Parameter(torch.tensor([0.0]))
@staticmethod
def _symbol_functions(**kwargs):
"""
Return a dictionary of functions that can be used as aggregators or
reductions.
"""
return {
"+": partial(torch.sum, **kwargs),
"*": partial(torch.prod, **kwargs),
@@ -215,4 +222,7 @@ class DeepONet(torch.nn.Module):
output_ = output_.as_subclass(LabelTensor)
output_.labels = self.output_variables
output_ *= self.scale
output_ += self.trasl
return output_

View File

@@ -89,8 +89,8 @@ class FeedForward(torch.nn.Module):
"""
Defines the computation performed at every call.
:param x: the input tensor.
:type x: :class:`pina.LabelTensor`
:param x: .
:type x: :class:`pina.LabelTensor`
:return: the output computed by the model.
:rtype: LabelTensor
"""

View File

@@ -9,8 +9,9 @@ class MultiFeedForward(torch.nn.Module):
:param dict dff_dict: dictionary of FeedForward networks.
"""
def __init__(self, dff_dict):
"""
"""
'''
dff_dict: dict of FeedForward objects
'''
super().__init__()
if not isinstance(dff_dict, dict):

View File

@@ -88,10 +88,13 @@ class PINN(object):
@property
def problem(self):
""" The problem formulation."""
return self._problem
@problem.setter
def problem(self, problem):
"""
Set the problem formulation."""
if not isinstance(problem, AbstractProblem):
raise TypeError
self._problem = problem
@@ -99,11 +102,11 @@ class PINN(object):
def _compute_norm(self, vec):
"""
Compute the norm of the `vec` one-dimensional tensor based on the
`self.error_norm` attribute.
`self.error_norm` attribute.
.. todo: complete
:param vec torch.tensor: the tensor
:param torch.Tensor vec: the tensor
"""
if isinstance(self.error_norm, int):
return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dytpe)
@@ -115,7 +118,11 @@ class PINN(object):
raise RuntimeError
def save_state(self, filename):
"""
Save the state of the model.
:param str filename: the filename to save the state to.
"""
checkpoint = {
'epoch': self.trained_epoch,
'model_state': self.model.state_dict(),
@@ -133,6 +140,11 @@ class PINN(object):
torch.save(checkpoint, filename)
def load_state(self, filename):
"""
Load the state of the model.
:param str filename: the filename to load the state from.
"""
checkpoint = torch.load(filename)
self.model.load_state_dict(checkpoint['model_state'])
@@ -298,32 +310,32 @@ class PINN(object):
return sum(losses).item()
def error(self, dtype='l2', res=100):
# def error(self, dtype='l2', res=100):
import numpy as np
if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
pts_container = []
for mn, mx in self.problem.domain_bound:
pts_container.append(np.linspace(mn, mx, res))
grids_container = np.meshgrid(*pts_container)
Z_true = self.problem.truth_solution(*grids_container)
# import numpy as np
# if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
# pts_container = []
# for mn, mx in self.problem.domain_bound:
# pts_container.append(np.linspace(mn, mx, res))
# grids_container = np.meshgrid(*pts_container)
# Z_true = self.problem.truth_solution(*grids_container)
elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
grids_container = self.problem.data_solution['grid']
Z_true = self.problem.data_solution['grid_solution']
try:
unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
dtype=self.dtype, device=self.device)
Z_pred = self.model(unrolled_pts)
Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
# elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
# grids_container = self.problem.data_solution['grid']
# Z_true = self.problem.data_solution['grid_solution']
# try:
# unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
# dtype=self.dtype, device=self.device)
# Z_pred = self.model(unrolled_pts)
# Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
if dtype == 'l2':
return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true)
else:
# TODO H1
pass
except:
print("")
print("Something went wrong...")
print(
"Not able to compute the error. Please pass a data solution or a true solution")
# if dtype == 'l2':
# return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true)
# else:
# # TODO H1
# pass
# except:
# print("")
# print("Something went wrong...")
# print(
# "Not able to compute the error. Please pass a data solution or a true solution")