Merge branch 'label_tensor' of https://github.com/ndem0/PINA into label_tensor

This commit is contained in:
aivagnes
2022-03-30 10:34:30 +02:00
21 changed files with 558 additions and 372 deletions

10
docs/source/_rst/code.rst Normal file
View File

@@ -0,0 +1,10 @@
Code Documentation
==================
.. toctree::
:maxdepth: 3
LabelTensor <label_tensor.rst>
FeedForward <fnn.rst>
DeepONet <deeponet.rst>
PINN <pinn.rst>

View File

@@ -0,0 +1,12 @@
DeepONet
===========
.. currentmodule:: pina.model.deeponet
.. automodule:: pina.model.deeponet
.. autoclass:: DeepONet
:members:
:private-members:
:undoc-members:
:show-inheritance:
:noindex:

12
docs/source/_rst/fnn.rst Normal file
View File

@@ -0,0 +1,12 @@
FeedForward
===========
.. currentmodule:: pina.model.feed_forward
.. automodule:: pina.model.feed_forward
.. autoclass:: FeedForward
:members:
:private-members:
:undoc-members:
:show-inheritance:
:noindex:

View File

@@ -0,0 +1,12 @@
LabelTensor
===========
.. currentmodule:: pina.label_tensor
.. automodule:: pina.label_tensor
.. autoclass:: LabelTensor
:members:
:private-members:
:undoc-members:
:show-inheritance:
:noindex:

12
docs/source/_rst/pinn.rst Normal file
View File

@@ -0,0 +1,12 @@
PINN
====
.. currentmodule:: pina.pinn
.. automodule:: pina.pinn
.. autoclass:: PINN
:members:
:private-members:
:undoc-members:
:show-inheritance:
:noindex:

View File

@@ -47,12 +47,15 @@ extensions = [
'sphinx.ext.ifconfig',
'sphinx.ext.mathjax',
]
autosummary_generate = True
intersphinx_mapping = {
'python': ('http://docs.python.org/2', None), 'numpy':
('http://docs.scipy.org/doc/numpy/', None), 'scipy':
('http://docs.scipy.org/doc/scipy/reference/', None), 'matplotlib':
('http://matplotlib.sourceforge.net/', None)
'python': ('http://docs.python.org/2', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
'matplotlib': ('http://matplotlib.sourceforge.net/', None),
'torch': ('https://pytorch.org/docs/stable/', None),
'pina': ('https://mathlab.github.io/PINA/', None)
}
# Add any paths that contain templates here, relative to this directory.

View File

@@ -22,6 +22,7 @@ solve problems in a continuous and nonlinear settings.
:caption: Package Documentation:
Installation <_rst/installation>
API <_rst/code>
Contributing <_rst/contributing>
License <LICENSE.rst>
@@ -30,7 +31,7 @@ solve problems in a continuous and nonlinear settings.
.. ........................................................................................
.. toctree::
:maxdepth: 2
:maxdepth: 1
:numbered:
:caption: Tutorials:

View File

@@ -13,13 +13,14 @@ class Poisson(SpatialProblem):
domain = Span({'x': [0, 1], 'y': [0, 1]})
def laplace_equation(input_, output_):
force_term = (torch.sin(input_['x']*torch.pi) *
torch.sin(input_['y']*torch.pi))
return nabla(output_['u'], input_).flatten() - force_term
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
nabla_u = nabla(output_.extract(['u']), input_)
return nabla_u - force_term
def nil_dirichlet(input_, output_):
value = 0.0
return output_['u'] - value
return output_.extract(['u']) - value
conditions = {
'gamma1': Condition(Span({'x': [-1, 1], 'y': 1}), nil_dirichlet),

View File

@@ -35,7 +35,7 @@ if __name__ == "__main__":
poisson_problem = Poisson()
model = FeedForward(
layers=[10, 10],
layers=[20, 20],
output_variables=poisson_problem.output_variables,
input_variables=poisson_problem.input_variables,
func=Softplus,
@@ -45,17 +45,17 @@ if __name__ == "__main__":
pinn = PINN(
poisson_problem,
model,
lr=0.003,
lr=0.03,
error_norm='mse',
regularizer=1e-8,
lr_accelerate=None)
regularizer=1e-8)
if args.s:
pinn.span_pts(20, 'grid', ['D'])
print(pinn)
pinn.span_pts(20, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.span_pts(20, 'grid', ['D'])
#pinn.plot_pts()
pinn.train(1000, 100)
pinn.train(5000, 100)
with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
for i, losses in enumerate(pinn.history):
file_.write('{} {}\n'.format(i, sum(losses)))

View File

@@ -1,64 +1,132 @@
""" Module for LabelTensor """
import torch
class LabelTensor():
def __init__(self, x, labels):
if len(labels) != x.shape[1]:
print(len(labels), x.shape[1])
raise ValueError
self.__labels = labels
self.tensor = x
def __getitem__(self, key):
if isinstance(key, (tuple, list)):
indeces = [self.labels.index(k) for k in key]
return LabelTensor(self.tensor[:, indeces], [self.labels[idx] for idx in indeces])
if key in self.labels:
return self.tensor[:, self.labels.index(key)]
else:
return self.tensor.__getitem__(key)
def __repr__(self):
return self.tensor
def __str__(self):
return '{}\n {}\n'.format(self.labels, self.tensor)
@property
def shape(self):
return self.tensor.shape
@property
def dtype(self):
return self.tensor.dtype
@property
def device(self):
return self.tensor.device
@property
def labels(self):
return self.__labels
class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column."""
@staticmethod
def hstack(labeltensor_list):
concatenated_tensor = torch.cat([lt.tensor for lt in labeltensor_list], axis=1)
concatenated_label = sum([lt.labels for lt in labeltensor_list], [])
return LabelTensor(concatenated_tensor, concatenated_label)
def __new__(cls, x, labels, *args, **kwargs):
return super().__new__(cls, x, *args, **kwargs)
def __init__(self, x, labels):
'''
Construct a `LabelTensor` by passing a tensor and a list of column
labels. Such labels uniquely identify the columns of the tensor,
allowing for an easier manipulation.
:param torch.Tensor x: the data tensor.
:param labels: the labels of the columns.
:type labels: str or iterable(str)
if __name__ == "__main__":
import numpy as np
a = np.random.uniform(size=(20, 3))
a = np.random.uniform(size=(20, 3))
p = torch.from_numpy(a)
t = LabelTensor(p, labels=['u', 'p', 't'])
print(t)
print(t['u'])
t *= 2
print(t['u'])
print(t[:, 0])
:Example:
>>> from pina import LabelTensor
>>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c'])
>>> tensor
tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01],
[9.2392e-01, 8.2065e-01, 4.1986e-04],
[8.9266e-01, 5.5446e-01, 6.3500e-01],
...,
[5.8194e-01, 9.4268e-01, 4.1841e-01],
[1.0246e-01, 9.5179e-01, 3.7043e-02],
[9.6150e-01, 8.0656e-01, 8.3824e-01]])
>>> tensor.extract('a')
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor.extract(['a', 'b'])
tensor([[0.0671, 0.4889],
[0.9239, 0.8207],
[0.8927, 0.5545],
...,
[0.5819, 0.9427],
[0.1025, 0.9518],
[0.9615, 0.8066]])
>>> tensor.extract(['b', 'a'])
tensor([[0.4889, 0.0671],
[0.8207, 0.9239],
[0.5545, 0.8927],
...,
[0.9427, 0.5819],
[0.9518, 0.1025],
[0.8066, 0.9615]])
'''
if isinstance(labels, str):
labels = [labels]
if len(labels) != x.shape[1]:
raise ValueError(
'the tensor has not the same number of columns of '
'the passed labels.'
)
self.labels = labels
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see
:meth:`torch.Tensor.clone`.
:return: a copy of the tensor
:rtype: LabelTensor
"""
return LabelTensor(super().clone(*args, **kwargs), self.labels)
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
"""
tmp = super().to(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def extract(self, label_to_extract):
"""
Extract the subset of the original tensor by returning all the columns
corresponding to the passed `label_to_extract`.
:param label_to_extract: the label(s) to extract.
:type label_to_extract: str or iterable(str)
:raises TypeError: labels are not str
:raises ValueError: label to extract is not in the labels list
"""
if isinstance(label_to_extract, str):
label_to_extract = [label_to_extract]
elif isinstance(label_to_extract, (tuple, list)): # TODO
pass
else:
raise TypeError(
'`label_to_extract` should be a str, or a str iterator')
try:
indeces = [self.labels.index(f) for f in label_to_extract]
except ValueError:
raise ValueError('`label_to_extract` not in the labels list')
new_data = self[:, indeces].float()
new_labels = [self.labels[idx] for idx in indeces]
extracted_tensor = LabelTensor(new_data, new_labels)
return extracted_tensor
def append(self, lt):
"""
Return a copy of the merged tensors.
:param LabelTensor lt: the tensor to merge.
:return: the merged tensors
:rtype: LabelTensor
"""
if set(self.labels).intersection(lt.labels):
raise RuntimeError('The tensors to merge have common labels')
new_labels = self.labels + lt.labels
new_tensor = torch.cat((self, lt), dim=1)
return LabelTensor(new_tensor, new_labels)

View File

@@ -1,7 +1,12 @@
"""Module for Location class."""
from abc import ABCMeta, abstractmethod
class Location(metaclass=ABCMeta):
"""
Abstract class
"""
@property
@abstractmethod

View File

@@ -1,7 +1,9 @@
__all__ = [
'FeedForward',
'MultiFeedForward'
'DeepONet',
]
from .feed_forward import FeedForward
from .multi_feed_forward import MultiFeedForward
from .deeponet import DeepONet

91
pina/model/deeponet.py Normal file
View File

@@ -0,0 +1,91 @@
"""Module for DeepONet model"""
import torch
import torch.nn as nn
from pina.label_tensor import LabelTensor
class DeepONet(torch.nn.Module):
"""
The PINA implementation of DeepONet network.
.. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
nonlinear operators via DeepONet based on the universal approximation
theorem of operators*. Nat Mach Intell 3, 218229 (2021).
DOI: `10.1038/s42256-021-00302-5
<https://doi.org/10.1038/s42256-021-00302-5>`_
"""
def __init__(self, branch_net, trunk_net, output_variables):
"""
:param torch.nn.Module branch_net: the neural network to use as branch
model. It has to take as input a :class:`LabelTensor`. The number
of dimension of the output has to be the same of the `trunk_net`.
:param torch.nn.Module trunk_net: the neural network to use as trunk
model. It has to take as input a :class:`LabelTensor`. The number
of dimension of the output has to be the same of the `branch_net`.
:param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the
model.
:Example:
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
>>> trunk = FFN(input_variables=['b'], output_variables=20)
>>> onet = DeepONet(trunk_net=trunk, branch_net=branch
>>> output_variables=output_vars)
DeepONet(
(trunk_net): FeedForward(
(extra_features): Sequential()
(model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
(branch_net): FeedForward(
(extra_features): Sequential()
(model): Sequential(
(0): Linear(in_features=2, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
)
"""
super().__init__()
self.trunk_net = trunk_net
self.branch_net = branch_net
self.output_variables = output_variables
self.output_dimension = len(output_variables)
if self.output_dimension > 1:
raise NotImplementedError('Vectorial DeepONet to be implemented')
@property
def input_variables(self):
"""The input variables of the model"""
return self.trunk_net.input_variables + self.branch_net.input_variables
def forward(self, x):
"""
Defines the computation performed at every call.
:param LabelTensor x: the input tensor.
:return: the output computed by the model.
:rtype: LabelTensor
"""
branch_output = self.branch_net(
x.extract(self.branch_net.input_variables))
trunk_output = self.trunk_net(
x.extract(self.trunk_net.input_variables))
output_ = torch.sum(branch_output * trunk_output, dim=1).reshape(-1, 1)
return LabelTensor(output_, self.output_variables)

View File

@@ -1,3 +1,4 @@
"""Module for FeedForward model"""
import torch
import torch.nn as nn
@@ -5,22 +6,50 @@ from pina.label_tensor import LabelTensor
class FeedForward(torch.nn.Module):
"""
The PINA implementation of feedforward network, also refered as multilayer
perceptron.
:param list(str) input_variables: the list containing the labels
corresponding to the input components of the model.
:param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the model.
:param int inner_size: number of neurons in the hidden layer(s). Default is
20.
:param int n_layers: number of hidden layers. Default is 2.
:param func: the activation function to use. If a single
:class:`torch.nn.Module` is passed, this is used as activation function
after any layers, except the last one. If a list of Modules is passed,
they are used as activation functions at any layers, in order.
:param iterable(int) layers: a list containing the number of neurons for
any hidden layers. If specified, the parameters `n_layers` e
`inner_size` are not considered.
:param iterable(torch.nn.Module) extra_features: the additional input
features to use ad augmented input.
"""
def __init__(self, input_variables, output_variables, inner_size=20,
n_layers=2, func=nn.Tanh, layers=None, extra_features=None):
'''
'''
"""
"""
super().__init__()
if extra_features is None:
extra_features = []
self.extra_features = nn.Sequential(*extra_features)
self.input_variables = input_variables
self.input_dimension = len(input_variables)
if isinstance(input_variables, int):
self.input_variables = None
self.input_dimension = input_variables
elif isinstance(input_variables, (tuple, list)):
self.input_variables = input_variables
self.input_dimension = len(input_variables)
self.output_variables = output_variables
self.output_dimension = len(output_variables)
if isinstance(output_variables, int):
self.output_variables = None
self.output_dimension = output_variables
elif isinstance(output_variables, (tuple, list)):
self.output_variables = output_variables
self.output_dimension = len(output_variables)
n_features = len(extra_features)
@@ -40,6 +69,9 @@ class FeedForward(torch.nn.Module):
else:
self.functions = [func for _ in range(len(self.layers)-1)]
if len(self.layers) != len(self.functions) + 1:
raise RuntimeError('uncosistent number of layers and functions')
unique_list = []
for layer, func in zip(self.layers[:-1], self.functions):
unique_list.append(layer)
@@ -51,18 +83,30 @@ class FeedForward(torch.nn.Module):
def forward(self, x):
"""
Defines the computation performed at every call.
:param x: the input tensor.
:type x: :class:`pina.LabelTensor`
:return: the output computed by the model.
:rtype: LabelTensor
"""
if self.input_variables:
x = x.extract(self.input_variables)
x = x[self.input_variables]
nf = len(self.extra_features)
if nf == 0:
return LabelTensor(self.model(x.tensor), self.output_variables)
labels = []
features = []
for i, feature in enumerate(self.extra_features):
labels.append('k{}'.format(i))
features.append(feature(x))
# if self.extra_features
input_ = torch.zeros(x.shape[0], nf+x.shape[1], dtype=x.dtype,
device=x.device)
input_[:, :x.shape[1]] = x.tensor
for i, feature in enumerate(self.extra_features,
start=self.input_dimension):
input_[:, i] = feature(x)
return LabelTensor(self.model(input_), self.output_variables)
if labels and features:
features = torch.cat(features, dim=1)
features_tensor = LabelTensor(features, labels)
input_ = x.append(features_tensor) # TODO error when no LabelTens
else:
input_ = x
if self.output_variables:
return LabelTensor(self.model(input_), self.output_variables)
else:
return self.model(input_)

View File

@@ -13,10 +13,10 @@ def grad(output_, input_):
gradients = torch.autograd.grad(
output_,
input_.tensor,
input_,
grad_outputs=torch.ones(output_.size()).to(
dtype=input_.tensor.dtype,
device=input_.tensor.device),
dtype=input_.dtype,
device=input_.device),
create_graph=True, retain_graph=True, allow_unused=True)[0]
return LabelTensor(gradients, input_.labels)
@@ -25,14 +25,14 @@ def div(output_, input_):
"""
TODO
"""
if output_.tensor.shape[1] == 1:
div = grad(output_.tensor, input_).sum(axis=1)
if output_.shape[1] == 1:
div = grad(output_, input_).sum(axis=1)
else: # really to improve
a = []
for o in output_.tensor.T:
a.append(grad(o, input_).tensor)
div = torch.zeros(output_.tensor.shape[0], 1)
for i in range(output_.tensor.shape[1]):
for o in output_.T:
a.append(grad(o, input_))
div = torch.zeros(output_.shape[0], 1)
for i in range(output_.shape[1]):
div += a[i][:, i].reshape(-1, 1)
return div

View File

@@ -14,36 +14,23 @@ class PINN(object):
lr=0.001,
regularizer=0.00001,
data_weight=1.,
dtype=torch.float64,
dtype=torch.float32,
device='cpu',
lr_accelerate=None,
error_norm='mse'):
'''
:param Problem problem: the formualation of the problem.
:param dict architecture: a dictionary containing the information to
build the model. Valid options are:
- inner_size [int] the number of neurons in the hidden layers; by
default is 20.
- n_layers [int] the number of hidden layers; by default is 4.
- func [nn.Module or str] the activation function; passing a `str`
is possible to chose adaptive function (between 'adapt_tanh'); by
default is non-adaptive iperbolic tangent.
:param float lr: the learning rate; default is 0.001
:param float regularizer: the coefficient for L2 regularizer term
:param torch.nn.Module model: the neural network model to use.
:param float lr: the learning rate; default is 0.001.
:param float regularizer: the coefficient for L2 regularizer term.
:param type dtype: the data type to use for the model. Valid option are
`torch.float32` and `torch.float64` (`torch.float16` only on GPU);
default is `torch.float64`.
:param float lr_accelete: the coefficient that controls the learning
rate increase, such that, for all the epoches in which the loss is
decreasing, the learning_rate is update using
$learning_rate = learning_rate * lr_accelerate$.
When the loss stops to decrease, the learning rate is set to the
initial value [TODO test parameters]
'''
self.problem = problem
if dtype == torch.float64:
raise NotImplementedError('only float for now')
self.problem = problem
# self._architecture = architecture if architecture else dict()
# self._architecture['input_dimension'] = self.problem.domain_bound.shape[0]
@@ -51,8 +38,6 @@ class PINN(object):
# if hasattr(self.problem, 'params_domain'):
# self._architecture['input_dimension'] += self.problem.params_domain.shape[0]
self.accelerate = lr_accelerate
self.error_norm = error_norm
if device == 'cuda' and not torch.cuda.is_available():
@@ -85,26 +70,6 @@ class PINN(object):
raise TypeError
self._problem = problem
def get_data_residuals(self):
data_residuals = []
for output in self.data_pts:
data_values_pred = self.model(self.data_pts[output])
data_residuals.append(data_values_pred - self.data_values[output])
return torch.cat(data_residuals)
def get_phys_residuals(self):
"""
"""
residuals = []
for equation in self.problem.equation:
residuals.append(equation(self.phys_pts, self.model(self.phys_pts)))
return residuals
def _compute_norm(self, vec):
"""
Compute the norm of the `vec` one-dimensional tensor based on the
@@ -156,10 +121,6 @@ class PINN(object):
def span_pts(self, n, mode='grid', locations='all'):
'''
'''
if locations == 'all':
locations = [condition for condition in self.problem.conditions]
@@ -171,12 +132,12 @@ class PINN(object):
except:
pts = condition.input_points
print(location, pts)
self.input_pts[location] = pts
self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device)
self.input_pts[location].tensor.requires_grad_(True)
self.input_pts[location].tensor.retain_grad()
self.input_pts[location] = pts#.double() # TODO
self.input_pts[location] = (
self.input_pts[location].to(dtype=self.dtype,
device=self.device))
self.input_pts[location].requires_grad_(True)
self.input_pts[location].retain_grad()
@@ -277,214 +238,3 @@ class PINN(object):
print("")
print("Something went wrong...")
print("Not able to compute the error. Please pass a data solution or a true solution")
def plot(self, res, filename=None, variable=None):
'''
'''
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
self._plot_2D(res, filename, variable)
print('TTTTTTTTTTTTTTTTTt')
print(self.problem.bounds)
pts_container = []
#for mn, mx in [[-1, 1], [-1, 1]]:
for mn, mx in [[0, 1], [0, 1]]:
#for mn, mx in [[-1, 1], [0, 1]]:
pts_container.append(np.linspace(mn, mx, res))
grids_container = np.meshgrid(*pts_container)
unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T
unrolled_pts.to(dtype=self.dtype)
Z_pred = self.model(unrolled_pts)
#######################################################
# poisson
# Z_truth = self.problem.truth_solution(unrolled_pts[:, 0], unrolled_pts[:, 1])
# Z_pred = Z_pred.tensor.detach().reshape(grids_container[0].shape)
# Z_truth = Z_truth.detach().reshape(grids_container[0].shape)
# err = np.abs(Z_pred-Z_truth)
# with open('poisson2_nofeat_plot.txt', 'w') as f_:
# f_.write('x y truth pred e\n')
# for (x, y), tru, pre, e in zip(unrolled_pts, Z_truth.reshape(-1, 1), Z_pred.reshape(-1, 1), err.reshape(-1, 1)):
# f_.write('{} {} {} {} {}\n'.format(x.item(), y.item(), tru.item(), pre.item(), e.item()))
# n = Z_pred.shape[1]
# plt.figure(figsize=(16, 6))
# plt.subplot(1, 3, 1)
# plt.contourf(*grids_container, Z_truth)
# plt.colorbar()
# plt.subplot(1, 3, 2)
# plt.contourf(*grids_container, Z_pred)
# plt.colorbar()
# plt.subplot(1, 3, 3)
# plt.contourf(*grids_container, err)
# plt.colorbar()
# plt.show()
#######################################################
# burgers
import scipy
data = scipy.io.loadmat('Data/burgers_shock.mat')
data_solution = {'grid': np.meshgrid(data['x'], data['t']), 'grid_solution': data['usol'].T}
grids_container = data_solution['grid']
print(data_solution['grid_solution'].shape)
unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T
unrolled_pts.to(dtype=self.dtype)
Z_pred = self.model(unrolled_pts)
Z_truth = data_solution['grid_solution']
Z_pred = Z_pred.tensor.detach().reshape(grids_container[0].shape)
print(Z_pred, Z_truth)
err = np.abs(Z_pred.numpy()-Z_truth)
with open('burgers_nofeat_plot.txt', 'w') as f_:
f_.write('x y truth pred e\n')
for (x, y), tru, pre, e in zip(unrolled_pts, Z_truth.reshape(-1, 1), Z_pred.reshape(-1, 1), err.reshape(-1, 1)):
f_.write('{} {} {} {} {}\n'.format(x.item(), y.item(), tru.item(), pre.item(), e.item()))
n = Z_pred.shape[1]
plt.figure(figsize=(16, 6))
plt.subplot(1, 3, 1)
plt.contourf(*grids_container, Z_truth,vmin=-1, vmax=1)
plt.colorbar()
plt.subplot(1, 3, 2)
plt.contourf(*grids_container, Z_pred, vmin=-1, vmax=1)
plt.colorbar()
plt.subplot(1, 3, 3)
plt.contourf(*grids_container, err)
plt.colorbar()
plt.show()
# for i, output in enumerate(Z_pred.tensor.T, start=1):
# output = output.detach().numpy().reshape(grids_container[0].shape)
# plt.subplot(1, n, i)
# plt.contourf(*grids_container, output)
# plt.colorbar()
if filename is None:
plt.show()
else:
plt.savefig(filename)
def plot_params(self, res, param, filename=None, variable=None):
'''
'''
import matplotlib
matplotlib.use('GTK3Agg')
import matplotlib.pyplot as plt
if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
n_plot = 2
elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
n_plot = 2
else:
n_plot = 1
fig, axs = plt.subplots(nrows=1, ncols=n_plot, figsize=(n_plot*6,4))
if not isinstance(axs, np.ndarray): axs = [axs]
if hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
grids_container = self.problem.data_solution['grid']
Z_true = self.problem.data_solution['grid_solution']
elif hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
pts_container = []
for mn, mx in self.problem.domain_bound:
pts_container.append(np.linspace(mn, mx, res))
grids_container = np.meshgrid(*pts_container)
Z_true = self.problem.truth_solution(*grids_container)
pts_container = []
for mn, mx in self.problem.domain_bound:
pts_container.append(np.linspace(mn, mx, res))
grids_container = np.meshgrid(*pts_container)
unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(dtype=self.type)
#print(unrolled_pts)
#print(param)
param_unrolled_pts = torch.cat((unrolled_pts, param.repeat(unrolled_pts.shape[0], 1)), 1)
if variable==None:
variable = self.problem.variables[0]
Z_pred = self.evaluate(param_unrolled_pts)[variable]
variable = "Solution"
else:
Z_pred = self.evaluate(param_unrolled_pts)[variable]
Z_pred= Z_pred.detach().numpy().reshape(grids_container[0].shape)
set_pred = axs[0].contourf(*grids_container, Z_pred)
axs[0].set_title('PINN [trained epoch = {}]'.format(self.trained_epoch) + " " + variable) #TODO add info about parameter in the title
fig.colorbar(set_pred, ax=axs[0])
if n_plot == 2:
set_true = axs[1].contourf(*grids_container, Z_true)
axs[1].set_title('Truth solution')
fig.colorbar(set_true, ax=axs[1])
if filename is None:
plt.show()
else:
fig.savefig(filename + " " + variable)
def plot_error(self, res, filename=None):
import matplotlib
matplotlib.use('GTK3Agg')
import matplotlib.pyplot as plt
fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(6,4))
if not isinstance(axs, np.ndarray): axs = [axs]
if hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
grids_container = self.problem.data_solution['grid']
Z_true = self.problem.data_solution['grid_solution']
elif hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
pts_container = []
for mn, mx in self.problem.domain_bound:
pts_container.append(np.linspace(mn, mx, res))
grids_container = np.meshgrid(*pts_container)
Z_true = self.problem.truth_solution(*grids_container)
try:
unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(dtype=self.type)
Z_pred = self.model(unrolled_pts)
Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
set_pred = axs[0].contourf(*grids_container, abs(Z_pred - Z_true))
axs[0].set_title('PINN [trained epoch = {}]'.format(self.trained_epoch) + "Pointwise Error")
fig.colorbar(set_pred, ax=axs[0])
if filename is None:
plt.show()
else:
fig.savefig(filename)
except:
print("")
print("Something went wrong...")
print("Not able to plot the error. Please pass a data solution or a true solution")
'''
print(self.pred_loss.item(),loss.item(), self.old_loss.item())
if self.accelerate is not None:
if self.pred_loss > loss and loss >= self.old_loss:
self.current_lr = self.original_lr
#print('restart')
elif (loss-self.pred_loss).item() < 0.1:
self.current_lr += .5*self.current_lr
#print('powa')
else:
self.current_lr -= .5*self.current_lr
#print(self.current_lr)
#self.current_lr = min(loss.item()*3, 0.02)
for g in self.optimizer.param_groups:
g['lr'] = self.current_lr
'''

View File

@@ -1,6 +1,6 @@
""" Module for plotting. """
import matplotlib
#matplotlib.use('Qt5Agg')
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
@@ -119,16 +119,15 @@ class Plotter:
"""
res = 256
pts = obj.problem.domain.sample(res, 'grid')
print(pts)
grids_container = [
pts.tensor[:, 0].reshape(res, res),
pts.tensor[:, 1].reshape(res, res),
pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res),
]
predicted_output = obj.model(pts)
predicted_output = predicted_output['u']
predicted_output = predicted_output.extract(['u'])
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
truth_output = obj.problem.truth_solution(*pts.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
@@ -139,7 +138,6 @@ class Plotter:
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
@@ -153,7 +151,7 @@ class Plotter:
def plot_samples(self, obj):
for location in obj.input_pts:
plt.plot(*obj.input_pts[location].tensor.T.detach(), '.', label=location)
plt.plot(*obj.input_pts[location].T.detach(), '.', label=location)
plt.legend()
plt.show()

View File

@@ -37,7 +37,6 @@ class Span(Location):
for _ in range(bounds.shape[0])])
grids = np.meshgrid(*pts)
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
print(pts)
elif mode == 'lh' or mode == 'latin':
from scipy.stats import qmc
sampler = qmc.LatinHypercube(d=bounds.shape[0])
@@ -46,15 +45,17 @@ class Span(Location):
# Scale pts
pts *= bounds[:, 1] - bounds[:, 0]
pts += bounds[:, 0]
pts = pts.astype(np.float32)
pts = torch.from_numpy(pts)
pts_range_ = LabelTensor(pts, list(self.range_.keys()))
fixed = torch.Tensor(list(self.fixed_.values()))
pts_fixed_ = torch.ones(pts_range_.tensor.shape[0], len(self.fixed_)) * fixed
pts_fixed_ = torch.ones(pts.shape[0], len(self.fixed_),
dtype=pts.dtype) * fixed
pts_range_ = LabelTensor(pts, list(self.range_.keys()))
pts_fixed_ = LabelTensor(pts_fixed_, list(self.fixed_.keys()))
if self.fixed_:
return LabelTensor.hstack([pts_range_, pts_fixed_])
return pts_range_.append(pts_fixed_)
else:
return pts_range_

26
tests/test_deeponet.py Normal file
View File

@@ -0,0 +1,26 @@
import torch
import pytest
from pina import LabelTensor
from pina.model import DeepONet, FeedForward as FFN
data = torch.rand((20, 3))
input_vars = ['a', 'b', 'c']
output_vars = ['d']
input_ = LabelTensor(data, input_vars)
def test_constructor():
branch = FFN(input_variables=['a', 'c'], output_variables=20)
trunk = FFN(input_variables=['b'], output_variables=20)
onet = DeepONet(trunk_net=trunk, branch_net=branch,
output_variables=output_vars)
def test_forward():
branch = FFN(input_variables=['a', 'c'], output_variables=10)
trunk = FFN(input_variables=['b'], output_variables=10)
onet = DeepONet(trunk_net=trunk, branch_net=branch,
output_variables=output_vars)
output_ = onet(input_)
assert output_.labels == output_vars

58
tests/test_fnn.py Normal file
View File

@@ -0,0 +1,58 @@
import torch
import pytest
from pina import LabelTensor
from pina.model import FeedForward
class myFeature(torch.nn.Module):
"""
Feature: sin(pi*x)
"""
def __init__(self):
super(myFeature, self).__init__()
def forward(self, x):
return torch.sin(torch.pi * x.extract('a'))
data = torch.rand((20, 3))
input_vars = ['a', 'b', 'c']
output_vars = ['d', 'e']
input_ = LabelTensor(data, input_vars)
def test_constructor():
FeedForward(input_vars, output_vars)
FeedForward(3, 4)
FeedForward(input_vars, output_vars, extra_features=[myFeature()])
FeedForward(input_vars, output_vars, inner_size=10, n_layers=20)
FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2])
FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2],
func=torch.nn.ReLU)
FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2],
func=[torch.nn.ReLU, torch.nn.ReLU, None, torch.nn.Tanh])
def test_constructor_wrong():
with pytest.raises(RuntimeError):
FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2],
func=[torch.nn.ReLU, torch.nn.ReLU])
def test_forward():
fnn = FeedForward(input_vars, output_vars)
output_ = fnn(input_)
assert output_.labels == output_vars
def test_forward2():
dim_in, dim_out = 3, 2
fnn = FeedForward(dim_in, dim_out)
output_ = fnn(input_)
assert output_.shape == (input_.shape[0], dim_out)
def test_forward_features():
fnn = FeedForward(input_vars, output_vars, extra_features=[myFeature()])
output_ = fnn(input_)
assert output_.labels == output_vars

View File

@@ -0,0 +1,80 @@
import torch
import pytest
from pina import LabelTensor
data = torch.rand((20, 3))
labels = ['a', 'b', 'c']
def test_constructor():
LabelTensor(data, labels)
def test_wrong_constructor():
with pytest.raises(ValueError):
LabelTensor(data, ['a', 'b'])
def test_labels():
tensor = LabelTensor(data, labels)
assert isinstance(tensor, torch.Tensor)
assert tensor.labels == labels
def test_extract():
label_to_extract = ['a', 'c']
tensor = LabelTensor(data, labels)
new = tensor.extract(label_to_extract)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(data[:, 0::2], new))
def test_extract_onelabel():
label_to_extract = ['a']
tensor = LabelTensor(data, labels)
new = tensor.extract(label_to_extract)
assert new.ndim == 2
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))
def test_wrong_extract():
label_to_extract = ['a', 'cc']
tensor = LabelTensor(data, labels)
with pytest.raises(ValueError):
tensor.extract(label_to_extract)
def test_extract_order():
label_to_extract = ['c', 'a']
tensor = LabelTensor(data, labels)
new = tensor.extract(label_to_extract)
expected = torch.cat(
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
dim=1)
print(expected)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(expected, new))
def test_merge():
tensor = LabelTensor(data, labels)
tensor_a = tensor.extract('a')
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bc = tensor_b.append(tensor_c)
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
def test_merge():
tensor = LabelTensor(data, labels)
tensor_a = tensor.extract('a')
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bb = tensor_b.append(tensor_b)
assert torch.allclose(tensor_b, tensor.extract(['b', 'c']))