From 6b001c6c53166a3dcaac54f38f1cf9c3e312f83a Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 29 Mar 2022 18:05:26 +0200 Subject: [PATCH] use LabelTensor, fix minor, docs --- docs/source/_rst/code.rst | 3 + docs/source/_rst/deeponet.rst | 12 ++ docs/source/_rst/fnn.rst | 12 ++ docs/source/_rst/pinn.rst | 12 ++ docs/source/conf.py | 3 +- examples/problems/poisson.py | 9 +- examples/run_poisson.py | 12 +- pina/label_tensor.py | 31 ++-- pina/location.py | 5 + pina/model/__init__.py | 2 + pina/model/deeponet.py | 91 +++++++++++ pina/model/feed_forward.py | 80 +++++++--- pina/operators.py | 18 +-- pina/pinn.py | 276 ++-------------------------------- pina/plotter.py | 14 +- pina/span.py | 9 +- tests/test_deeponet.py | 26 ++++ tests/test_fnn.py | 58 +++++++ tests/test_label_tensor.py | 19 +++ 19 files changed, 370 insertions(+), 322 deletions(-) create mode 100644 docs/source/_rst/deeponet.rst create mode 100644 docs/source/_rst/fnn.rst create mode 100644 docs/source/_rst/pinn.rst create mode 100644 pina/model/deeponet.py create mode 100644 tests/test_deeponet.py create mode 100644 tests/test_fnn.py diff --git a/docs/source/_rst/code.rst b/docs/source/_rst/code.rst index 0916b80..c32039c 100644 --- a/docs/source/_rst/code.rst +++ b/docs/source/_rst/code.rst @@ -5,3 +5,6 @@ Code Documentation :maxdepth: 3 LabelTensor + FeedForward + DeepONet + PINN diff --git a/docs/source/_rst/deeponet.rst b/docs/source/_rst/deeponet.rst new file mode 100644 index 0000000..34f3661 --- /dev/null +++ b/docs/source/_rst/deeponet.rst @@ -0,0 +1,12 @@ +DeepONet +=========== +.. currentmodule:: pina.model.deeponet + +.. automodule:: pina.model.deeponet + +.. autoclass:: DeepONet + :members: + :private-members: + :undoc-members: + :show-inheritance: + :noindex: diff --git a/docs/source/_rst/fnn.rst b/docs/source/_rst/fnn.rst new file mode 100644 index 0000000..3ddfaf1 --- /dev/null +++ b/docs/source/_rst/fnn.rst @@ -0,0 +1,12 @@ +FeedForward +=========== +.. currentmodule:: pina.model.feed_forward + +.. automodule:: pina.model.feed_forward + +.. autoclass:: FeedForward + :members: + :private-members: + :undoc-members: + :show-inheritance: + :noindex: diff --git a/docs/source/_rst/pinn.rst b/docs/source/_rst/pinn.rst new file mode 100644 index 0000000..4cf03af --- /dev/null +++ b/docs/source/_rst/pinn.rst @@ -0,0 +1,12 @@ +PINN +==== +.. currentmodule:: pina.pinn + +.. automodule:: pina.pinn + +.. autoclass:: PINN + :members: + :private-members: + :undoc-members: + :show-inheritance: + :noindex: diff --git a/docs/source/conf.py b/docs/source/conf.py index b2dbc70..766b7d6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -54,7 +54,8 @@ intersphinx_mapping = { 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), 'matplotlib': ('http://matplotlib.sourceforge.net/', None), - 'torch': ('https://pytorch.org/docs/stable/', None) + 'torch': ('https://pytorch.org/docs/stable/', None), + 'pina': ('https://mathlab.github.io/PINA/', None) } # Add any paths that contain templates here, relative to this directory. diff --git a/examples/problems/poisson.py b/examples/problems/poisson.py index 1ffc564..912db38 100644 --- a/examples/problems/poisson.py +++ b/examples/problems/poisson.py @@ -13,13 +13,14 @@ class Poisson(SpatialProblem): domain = Span({'x': [0, 1], 'y': [0, 1]}) def laplace_equation(input_, output_): - force_term = (torch.sin(input_['x']*torch.pi) * - torch.sin(input_['y']*torch.pi)) - return nabla(output_['u'], input_).flatten() - force_term + force_term = (torch.sin(input_.extract(['x'])*torch.pi) * + torch.sin(input_.extract(['y'])*torch.pi)) + nabla_u = nabla(output_.extract(['u']), input_) + return nabla_u - force_term def nil_dirichlet(input_, output_): value = 0.0 - return output_['u'] - value + return output_.extract(['u']) - value conditions = { 'gamma1': Condition(Span({'x': [-1, 1], 'y': 1}), nil_dirichlet), diff --git a/examples/run_poisson.py b/examples/run_poisson.py index a6fea52..b992f30 100644 --- a/examples/run_poisson.py +++ b/examples/run_poisson.py @@ -35,7 +35,7 @@ if __name__ == "__main__": poisson_problem = Poisson() model = FeedForward( - layers=[10, 10], + layers=[20, 20], output_variables=poisson_problem.output_variables, input_variables=poisson_problem.input_variables, func=Softplus, @@ -45,17 +45,17 @@ if __name__ == "__main__": pinn = PINN( poisson_problem, model, - lr=0.003, + lr=0.03, error_norm='mse', - regularizer=1e-8, - lr_accelerate=None) + regularizer=1e-8) if args.s: - pinn.span_pts(20, 'grid', ['D']) + print(pinn) pinn.span_pts(20, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4']) + pinn.span_pts(20, 'grid', ['D']) #pinn.plot_pts() - pinn.train(1000, 100) + pinn.train(5000, 100) with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_: for i, losses in enumerate(pinn.history): file_.write('{} {}\n'.format(i, sum(losses))) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 2b3b92c..be8ced1 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -81,11 +81,10 @@ class LabelTensor(torch.Tensor): Performs Tensor dtype and/or device conversion. For more details, see :meth:`torch.Tensor.to`. """ - new_obj = LabelTensor([], self.labels) - tempTensor = super().to(*args, **kwargs) - new_obj.data = tempTensor.data - new_obj.requires_grad = tempTensor.requires_grad - return new_obj + tmp = super().to(*args, **kwargs) + new = self.__class__.clone(self) + new.data = tmp.data + return new def extract(self, label_to_extract): """ @@ -111,9 +110,23 @@ class LabelTensor(torch.Tensor): except ValueError: raise ValueError('`label_to_extract` not in the labels list') - extracted_tensor = LabelTensor( - self[:, indeces], - [self.labels[idx] for idx in indeces] - ) + new_data = self[:, indeces].float() + new_labels = [self.labels[idx] for idx in indeces] + extracted_tensor = LabelTensor(new_data, new_labels) return extracted_tensor + + def append(self, lt): + """ + Return a copy of the merged tensors. + + :param LabelTensor lt: the tensor to merge. + :return: the merged tensors + :rtype: LabelTensor + """ + if set(self.labels).intersection(lt.labels): + raise RuntimeError('The tensors to merge have common labels') + + new_labels = self.labels + lt.labels + new_tensor = torch.cat((self, lt), dim=1) + return LabelTensor(new_tensor, new_labels) diff --git a/pina/location.py b/pina/location.py index 9a7260b..cbd42e4 100644 --- a/pina/location.py +++ b/pina/location.py @@ -1,7 +1,12 @@ +"""Module for Location class.""" + from abc import ABCMeta, abstractmethod class Location(metaclass=ABCMeta): + """ + Abstract class + """ @property @abstractmethod diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 7c2f61e..a2d715a 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -1,7 +1,9 @@ __all__ = [ 'FeedForward', 'MultiFeedForward' + 'DeepONet', ] from .feed_forward import FeedForward from .multi_feed_forward import MultiFeedForward +from .deeponet import DeepONet diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py new file mode 100644 index 0000000..6914a92 --- /dev/null +++ b/pina/model/deeponet.py @@ -0,0 +1,91 @@ +"""Module for DeepONet model""" +import torch +import torch.nn as nn + +from pina.label_tensor import LabelTensor + + +class DeepONet(torch.nn.Module): + """ + The PINA implementation of DeepONet network. + + .. seealso:: + + **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning + nonlinear operators via DeepONet based on the universal approximation + theorem of operators*. Nat Mach Intell 3, 218–229 (2021). + DOI: `10.1038/s42256-021-00302-5 + `_ + + """ + def __init__(self, branch_net, trunk_net, output_variables): + """ + :param torch.nn.Module branch_net: the neural network to use as branch + model. It has to take as input a :class:`LabelTensor`. The number + of dimension of the output has to be the same of the `trunk_net`. + :param torch.nn.Module trunk_net: the neural network to use as trunk + model. It has to take as input a :class:`LabelTensor`. The number + of dimension of the output has to be the same of the `branch_net`. + :param list(str) output_variables: the list containing the labels + corresponding to the components of the output computed by the + model. + + :Example: + >>> branch = FFN(input_variables=['a', 'c'], output_variables=20) + >>> trunk = FFN(input_variables=['b'], output_variables=20) + >>> onet = DeepONet(trunk_net=trunk, branch_net=branch + >>> output_variables=output_vars) + DeepONet( + (trunk_net): FeedForward( + (extra_features): Sequential() + (model): Sequential( + (0): Linear(in_features=1, out_features=20, bias=True) + (1): Tanh() + (2): Linear(in_features=20, out_features=20, bias=True) + (3): Tanh() + (4): Linear(in_features=20, out_features=10, bias=True) + ) + ) + (branch_net): FeedForward( + (extra_features): Sequential() + (model): Sequential( + (0): Linear(in_features=2, out_features=20, bias=True) + (1): Tanh() + (2): Linear(in_features=20, out_features=20, bias=True) + (3): Tanh() + (4): Linear(in_features=20, out_features=10, bias=True) + ) + ) + ) + """ + super().__init__() + + self.trunk_net = trunk_net + self.branch_net = branch_net + + self.output_variables = output_variables + self.output_dimension = len(output_variables) + if self.output_dimension > 1: + raise NotImplementedError('Vectorial DeepONet to be implemented') + + @property + def input_variables(self): + """The input variables of the model""" + return self.trunk_net.input_variables + self.branch_net.input_variables + + def forward(self, x): + """ + Defines the computation performed at every call. + + :param LabelTensor x: the input tensor. + :return: the output computed by the model. + :rtype: LabelTensor + """ + branch_output = self.branch_net( + x.extract(self.branch_net.input_variables)) + trunk_output = self.trunk_net( + x.extract(self.trunk_net.input_variables)) + + output_ = torch.sum(branch_output * trunk_output, dim=1).reshape(-1, 1) + + return LabelTensor(output_, self.output_variables) diff --git a/pina/model/feed_forward.py b/pina/model/feed_forward.py index 96e82bb..4acb068 100644 --- a/pina/model/feed_forward.py +++ b/pina/model/feed_forward.py @@ -1,3 +1,4 @@ +"""Module for FeedForward model""" import torch import torch.nn as nn @@ -5,22 +6,50 @@ from pina.label_tensor import LabelTensor class FeedForward(torch.nn.Module): + """ + The PINA implementation of feedforward network, also refered as multilayer + perceptron. + :param list(str) input_variables: the list containing the labels + corresponding to the input components of the model. + :param list(str) output_variables: the list containing the labels + corresponding to the components of the output computed by the model. + :param int inner_size: number of neurons in the hidden layer(s). Default is + 20. + :param int n_layers: number of hidden layers. Default is 2. + :param func: the activation function to use. If a single + :class:`torch.nn.Module` is passed, this is used as activation function + after any layers, except the last one. If a list of Modules is passed, + they are used as activation functions at any layers, in order. + :param iterable(int) layers: a list containing the number of neurons for + any hidden layers. If specified, the parameters `n_layers` e + `inner_size` are not considered. + :param iterable(torch.nn.Module) extra_features: the additional input + features to use ad augmented input. + """ def __init__(self, input_variables, output_variables, inner_size=20, n_layers=2, func=nn.Tanh, layers=None, extra_features=None): - ''' - ''' + """ + """ super().__init__() if extra_features is None: extra_features = [] self.extra_features = nn.Sequential(*extra_features) - self.input_variables = input_variables - self.input_dimension = len(input_variables) + if isinstance(input_variables, int): + self.input_variables = None + self.input_dimension = input_variables + elif isinstance(input_variables, (tuple, list)): + self.input_variables = input_variables + self.input_dimension = len(input_variables) - self.output_variables = output_variables - self.output_dimension = len(output_variables) + if isinstance(output_variables, int): + self.output_variables = None + self.output_dimension = output_variables + elif isinstance(output_variables, (tuple, list)): + self.output_variables = output_variables + self.output_dimension = len(output_variables) n_features = len(extra_features) @@ -40,6 +69,9 @@ class FeedForward(torch.nn.Module): else: self.functions = [func for _ in range(len(self.layers)-1)] + if len(self.layers) != len(self.functions) + 1: + raise RuntimeError('uncosistent number of layers and functions') + unique_list = [] for layer, func in zip(self.layers[:-1], self.functions): unique_list.append(layer) @@ -51,18 +83,30 @@ class FeedForward(torch.nn.Module): def forward(self, x): """ + Defines the computation performed at every call. + + :param x: the input tensor. + :type x: :class:`pina.LabelTensor` + :return: the output computed by the model. + :rtype: LabelTensor """ + if self.input_variables: + x = x.extract(self.input_variables) - x = x[self.input_variables] - nf = len(self.extra_features) - if nf == 0: - return LabelTensor(self.model(x.tensor), self.output_variables) + labels = [] + features = [] + for i, feature in enumerate(self.extra_features): + labels.append('k{}'.format(i)) + features.append(feature(x)) - # if self.extra_features - input_ = torch.zeros(x.shape[0], nf+x.shape[1], dtype=x.dtype, - device=x.device) - input_[:, :x.shape[1]] = x.tensor - for i, feature in enumerate(self.extra_features, - start=self.input_dimension): - input_[:, i] = feature(x) - return LabelTensor(self.model(input_), self.output_variables) + if labels and features: + features = torch.cat(features, dim=1) + features_tensor = LabelTensor(features, labels) + input_ = x.append(features_tensor) # TODO error when no LabelTens + else: + input_ = x + + if self.output_variables: + return LabelTensor(self.model(input_), self.output_variables) + else: + return self.model(input_) diff --git a/pina/operators.py b/pina/operators.py index 8c284a3..5072fca 100644 --- a/pina/operators.py +++ b/pina/operators.py @@ -13,10 +13,10 @@ def grad(output_, input_): gradients = torch.autograd.grad( output_, - input_.tensor, + input_, grad_outputs=torch.ones(output_.size()).to( - dtype=input_.tensor.dtype, - device=input_.tensor.device), + dtype=input_.dtype, + device=input_.device), create_graph=True, retain_graph=True, allow_unused=True)[0] return LabelTensor(gradients, input_.labels) @@ -25,14 +25,14 @@ def div(output_, input_): """ TODO """ - if output_.tensor.shape[1] == 1: - div = grad(output_.tensor, input_).sum(axis=1) + if output_.shape[1] == 1: + div = grad(output_, input_).sum(axis=1) else: # really to improve a = [] - for o in output_.tensor.T: - a.append(grad(o, input_).tensor) - div = torch.zeros(output_.tensor.shape[0], 1) - for i in range(output_.tensor.shape[1]): + for o in output_.T: + a.append(grad(o, input_)) + div = torch.zeros(output_.shape[0], 1) + for i in range(output_.shape[1]): div += a[i][:, i].reshape(-1, 1) return div diff --git a/pina/pinn.py b/pina/pinn.py index fc3d233..2020058 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -14,36 +14,23 @@ class PINN(object): lr=0.001, regularizer=0.00001, data_weight=1., - dtype=torch.float64, + dtype=torch.float32, device='cpu', - lr_accelerate=None, error_norm='mse'): ''' :param Problem problem: the formualation of the problem. - :param dict architecture: a dictionary containing the information to - build the model. Valid options are: - - inner_size [int] the number of neurons in the hidden layers; by - default is 20. - - n_layers [int] the number of hidden layers; by default is 4. - - func [nn.Module or str] the activation function; passing a `str` - is possible to chose adaptive function (between 'adapt_tanh'); by - default is non-adaptive iperbolic tangent. - :param float lr: the learning rate; default is 0.001 - :param float regularizer: the coefficient for L2 regularizer term + :param torch.nn.Module model: the neural network model to use. + :param float lr: the learning rate; default is 0.001. + :param float regularizer: the coefficient for L2 regularizer term. :param type dtype: the data type to use for the model. Valid option are `torch.float32` and `torch.float64` (`torch.float16` only on GPU); default is `torch.float64`. - :param float lr_accelete: the coefficient that controls the learning - rate increase, such that, for all the epoches in which the loss is - decreasing, the learning_rate is update using - $learning_rate = learning_rate * lr_accelerate$. - When the loss stops to decrease, the learning rate is set to the - initial value [TODO test parameters] - ''' - self.problem = problem + if dtype == torch.float64: + raise NotImplementedError('only float for now') + self.problem = problem # self._architecture = architecture if architecture else dict() # self._architecture['input_dimension'] = self.problem.domain_bound.shape[0] @@ -51,8 +38,6 @@ class PINN(object): # if hasattr(self.problem, 'params_domain'): # self._architecture['input_dimension'] += self.problem.params_domain.shape[0] - self.accelerate = lr_accelerate - self.error_norm = error_norm if device == 'cuda' and not torch.cuda.is_available(): @@ -85,26 +70,6 @@ class PINN(object): raise TypeError self._problem = problem - def get_data_residuals(self): - - data_residuals = [] - - for output in self.data_pts: - data_values_pred = self.model(self.data_pts[output]) - data_residuals.append(data_values_pred - self.data_values[output]) - - return torch.cat(data_residuals) - - def get_phys_residuals(self): - """ - """ - - residuals = [] - for equation in self.problem.equation: - residuals.append(equation(self.phys_pts, self.model(self.phys_pts))) - return residuals - - def _compute_norm(self, vec): """ Compute the norm of the `vec` one-dimensional tensor based on the @@ -156,10 +121,6 @@ class PINN(object): def span_pts(self, n, mode='grid', locations='all'): - ''' - - ''' - if locations == 'all': locations = [condition for condition in self.problem.conditions] @@ -171,12 +132,12 @@ class PINN(object): except: pts = condition.input_points - print(location, pts) - - self.input_pts[location] = pts - self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device) - self.input_pts[location].tensor.requires_grad_(True) - self.input_pts[location].tensor.retain_grad() + self.input_pts[location] = pts#.double() # TODO + self.input_pts[location] = ( + self.input_pts[location].to(dtype=self.dtype, + device=self.device)) + self.input_pts[location].requires_grad_(True) + self.input_pts[location].retain_grad() @@ -277,214 +238,3 @@ class PINN(object): print("") print("Something went wrong...") print("Not able to compute the error. Please pass a data solution or a true solution") - - - - - def plot(self, res, filename=None, variable=None): - ''' - ''' - import matplotlib - matplotlib.use('Agg') - import matplotlib.pyplot as plt - - self._plot_2D(res, filename, variable) - print('TTTTTTTTTTTTTTTTTt') - print(self.problem.bounds) - pts_container = [] - #for mn, mx in [[-1, 1], [-1, 1]]: - for mn, mx in [[0, 1], [0, 1]]: - #for mn, mx in [[-1, 1], [0, 1]]: - pts_container.append(np.linspace(mn, mx, res)) - grids_container = np.meshgrid(*pts_container) - unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T - unrolled_pts.to(dtype=self.dtype) - Z_pred = self.model(unrolled_pts) - - ####################################################### - # poisson - # Z_truth = self.problem.truth_solution(unrolled_pts[:, 0], unrolled_pts[:, 1]) - # Z_pred = Z_pred.tensor.detach().reshape(grids_container[0].shape) - # Z_truth = Z_truth.detach().reshape(grids_container[0].shape) - # err = np.abs(Z_pred-Z_truth) - - - # with open('poisson2_nofeat_plot.txt', 'w') as f_: - # f_.write('x y truth pred e\n') - # for (x, y), tru, pre, e in zip(unrolled_pts, Z_truth.reshape(-1, 1), Z_pred.reshape(-1, 1), err.reshape(-1, 1)): - # f_.write('{} {} {} {} {}\n'.format(x.item(), y.item(), tru.item(), pre.item(), e.item())) - # n = Z_pred.shape[1] - # plt.figure(figsize=(16, 6)) - # plt.subplot(1, 3, 1) - # plt.contourf(*grids_container, Z_truth) - # plt.colorbar() - # plt.subplot(1, 3, 2) - # plt.contourf(*grids_container, Z_pred) - # plt.colorbar() - # plt.subplot(1, 3, 3) - # plt.contourf(*grids_container, err) - # plt.colorbar() - # plt.show() - - ####################################################### - # burgers - import scipy - data = scipy.io.loadmat('Data/burgers_shock.mat') - data_solution = {'grid': np.meshgrid(data['x'], data['t']), 'grid_solution': data['usol'].T} - - grids_container = data_solution['grid'] - print(data_solution['grid_solution'].shape) - unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T - unrolled_pts.to(dtype=self.dtype) - Z_pred = self.model(unrolled_pts) - Z_truth = data_solution['grid_solution'] - - Z_pred = Z_pred.tensor.detach().reshape(grids_container[0].shape) - print(Z_pred, Z_truth) - err = np.abs(Z_pred.numpy()-Z_truth) - - - with open('burgers_nofeat_plot.txt', 'w') as f_: - f_.write('x y truth pred e\n') - for (x, y), tru, pre, e in zip(unrolled_pts, Z_truth.reshape(-1, 1), Z_pred.reshape(-1, 1), err.reshape(-1, 1)): - f_.write('{} {} {} {} {}\n'.format(x.item(), y.item(), tru.item(), pre.item(), e.item())) - n = Z_pred.shape[1] - plt.figure(figsize=(16, 6)) - plt.subplot(1, 3, 1) - plt.contourf(*grids_container, Z_truth,vmin=-1, vmax=1) - plt.colorbar() - plt.subplot(1, 3, 2) - plt.contourf(*grids_container, Z_pred, vmin=-1, vmax=1) - plt.colorbar() - plt.subplot(1, 3, 3) - plt.contourf(*grids_container, err) - plt.colorbar() - plt.show() - - - # for i, output in enumerate(Z_pred.tensor.T, start=1): - # output = output.detach().numpy().reshape(grids_container[0].shape) - # plt.subplot(1, n, i) - # plt.contourf(*grids_container, output) - # plt.colorbar() - - if filename is None: - plt.show() - else: - plt.savefig(filename) - - def plot_params(self, res, param, filename=None, variable=None): - ''' - ''' - import matplotlib - matplotlib.use('GTK3Agg') - import matplotlib.pyplot as plt - - if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None: - n_plot = 2 - elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None: - n_plot = 2 - else: - n_plot = 1 - - fig, axs = plt.subplots(nrows=1, ncols=n_plot, figsize=(n_plot*6,4)) - if not isinstance(axs, np.ndarray): axs = [axs] - - if hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None: - grids_container = self.problem.data_solution['grid'] - Z_true = self.problem.data_solution['grid_solution'] - elif hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None: - - pts_container = [] - for mn, mx in self.problem.domain_bound: - pts_container.append(np.linspace(mn, mx, res)) - - grids_container = np.meshgrid(*pts_container) - Z_true = self.problem.truth_solution(*grids_container) - - pts_container = [] - for mn, mx in self.problem.domain_bound: - pts_container.append(np.linspace(mn, mx, res)) - grids_container = np.meshgrid(*pts_container) - unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(dtype=self.type) - #print(unrolled_pts) - #print(param) - param_unrolled_pts = torch.cat((unrolled_pts, param.repeat(unrolled_pts.shape[0], 1)), 1) - if variable==None: - variable = self.problem.variables[0] - Z_pred = self.evaluate(param_unrolled_pts)[variable] - variable = "Solution" - else: - Z_pred = self.evaluate(param_unrolled_pts)[variable] - - Z_pred= Z_pred.detach().numpy().reshape(grids_container[0].shape) - set_pred = axs[0].contourf(*grids_container, Z_pred) - axs[0].set_title('PINN [trained epoch = {}]'.format(self.trained_epoch) + " " + variable) #TODO add info about parameter in the title - fig.colorbar(set_pred, ax=axs[0]) - - if n_plot == 2: - - set_true = axs[1].contourf(*grids_container, Z_true) - - axs[1].set_title('Truth solution') - fig.colorbar(set_true, ax=axs[1]) - - if filename is None: - plt.show() - else: - fig.savefig(filename + " " + variable) - - def plot_error(self, res, filename=None): - import matplotlib - matplotlib.use('GTK3Agg') - import matplotlib.pyplot as plt - - - fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(6,4)) - if not isinstance(axs, np.ndarray): axs = [axs] - - if hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None: - grids_container = self.problem.data_solution['grid'] - Z_true = self.problem.data_solution['grid_solution'] - elif hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None: - pts_container = [] - for mn, mx in self.problem.domain_bound: - pts_container.append(np.linspace(mn, mx, res)) - - grids_container = np.meshgrid(*pts_container) - Z_true = self.problem.truth_solution(*grids_container) - try: - unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(dtype=self.type) - - Z_pred = self.model(unrolled_pts) - Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape) - set_pred = axs[0].contourf(*grids_container, abs(Z_pred - Z_true)) - axs[0].set_title('PINN [trained epoch = {}]'.format(self.trained_epoch) + "Pointwise Error") - fig.colorbar(set_pred, ax=axs[0]) - - if filename is None: - plt.show() - else: - fig.savefig(filename) - except: - print("") - print("Something went wrong...") - print("Not able to plot the error. Please pass a data solution or a true solution") - -''' -print(self.pred_loss.item(),loss.item(), self.old_loss.item()) -if self.accelerate is not None: - if self.pred_loss > loss and loss >= self.old_loss: - self.current_lr = self.original_lr - #print('restart') - elif (loss-self.pred_loss).item() < 0.1: - self.current_lr += .5*self.current_lr - #print('powa') - else: - self.current_lr -= .5*self.current_lr - #print(self.current_lr) - #self.current_lr = min(loss.item()*3, 0.02) - - for g in self.optimizer.param_groups: - g['lr'] = self.current_lr -''' diff --git a/pina/plotter.py b/pina/plotter.py index 3fb3b95..848af36 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -1,6 +1,6 @@ """ Module for plotting. """ import matplotlib -#matplotlib.use('Qt5Agg') +matplotlib.use('Qt5Agg') import matplotlib.pyplot as plt import numpy as np import torch @@ -119,16 +119,15 @@ class Plotter: """ res = 256 pts = obj.problem.domain.sample(res, 'grid') - print(pts) grids_container = [ - pts.tensor[:, 0].reshape(res, res), - pts.tensor[:, 1].reshape(res, res), + pts[:, 0].reshape(res, res), + pts[:, 1].reshape(res, res), ] predicted_output = obj.model(pts) - predicted_output = predicted_output['u'] + predicted_output = predicted_output.extract(['u']) if hasattr(obj.problem, 'truth_solution'): - truth_output = obj.problem.truth_solution(*pts.tensor.T).float() + truth_output = obj.problem.truth_solution(*pts.T).float() fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach()) @@ -139,7 +138,6 @@ class Plotter: fig.colorbar(cb, ax=axes[2]) else: fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) - # cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach()) fig.colorbar(cb, ax=axes) @@ -153,7 +151,7 @@ class Plotter: def plot_samples(self, obj): for location in obj.input_pts: - plt.plot(*obj.input_pts[location].tensor.T.detach(), '.', label=location) + plt.plot(*obj.input_pts[location].T.detach(), '.', label=location) plt.legend() plt.show() diff --git a/pina/span.py b/pina/span.py index f3c55e5..8eb268f 100644 --- a/pina/span.py +++ b/pina/span.py @@ -37,7 +37,6 @@ class Span(Location): for _ in range(bounds.shape[0])]) grids = np.meshgrid(*pts) pts = np.hstack([grid.reshape(-1, 1) for grid in grids]) - print(pts) elif mode == 'lh' or mode == 'latin': from scipy.stats import qmc sampler = qmc.LatinHypercube(d=bounds.shape[0]) @@ -46,15 +45,17 @@ class Span(Location): # Scale pts pts *= bounds[:, 1] - bounds[:, 0] pts += bounds[:, 0] + pts = pts.astype(np.float32) pts = torch.from_numpy(pts) - pts_range_ = LabelTensor(pts, list(self.range_.keys())) fixed = torch.Tensor(list(self.fixed_.values())) - pts_fixed_ = torch.ones(pts_range_.tensor.shape[0], len(self.fixed_)) * fixed + pts_fixed_ = torch.ones(pts.shape[0], len(self.fixed_), + dtype=pts.dtype) * fixed + pts_range_ = LabelTensor(pts, list(self.range_.keys())) pts_fixed_ = LabelTensor(pts_fixed_, list(self.fixed_.keys())) if self.fixed_: - return LabelTensor.hstack([pts_range_, pts_fixed_]) + return pts_range_.append(pts_fixed_) else: return pts_range_ diff --git a/tests/test_deeponet.py b/tests/test_deeponet.py new file mode 100644 index 0000000..517f045 --- /dev/null +++ b/tests/test_deeponet.py @@ -0,0 +1,26 @@ +import torch +import pytest + +from pina import LabelTensor +from pina.model import DeepONet, FeedForward as FFN + + +data = torch.rand((20, 3)) +input_vars = ['a', 'b', 'c'] +output_vars = ['d'] +input_ = LabelTensor(data, input_vars) + + +def test_constructor(): + branch = FFN(input_variables=['a', 'c'], output_variables=20) + trunk = FFN(input_variables=['b'], output_variables=20) + onet = DeepONet(trunk_net=trunk, branch_net=branch, + output_variables=output_vars) + +def test_forward(): + branch = FFN(input_variables=['a', 'c'], output_variables=10) + trunk = FFN(input_variables=['b'], output_variables=10) + onet = DeepONet(trunk_net=trunk, branch_net=branch, + output_variables=output_vars) + output_ = onet(input_) + assert output_.labels == output_vars diff --git a/tests/test_fnn.py b/tests/test_fnn.py new file mode 100644 index 0000000..2aa367e --- /dev/null +++ b/tests/test_fnn.py @@ -0,0 +1,58 @@ +import torch +import pytest + +from pina import LabelTensor +from pina.model import FeedForward + +class myFeature(torch.nn.Module): + """ + Feature: sin(pi*x) + """ + def __init__(self): + super(myFeature, self).__init__() + + def forward(self, x): + return torch.sin(torch.pi * x.extract('a')) + + +data = torch.rand((20, 3)) +input_vars = ['a', 'b', 'c'] +output_vars = ['d', 'e'] +input_ = LabelTensor(data, input_vars) + + +def test_constructor(): + FeedForward(input_vars, output_vars) + FeedForward(3, 4) + FeedForward(input_vars, output_vars, extra_features=[myFeature()]) + FeedForward(input_vars, output_vars, inner_size=10, n_layers=20) + FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2]) + FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2], + func=torch.nn.ReLU) + FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2], + func=[torch.nn.ReLU, torch.nn.ReLU, None, torch.nn.Tanh]) + + +def test_constructor_wrong(): + with pytest.raises(RuntimeError): + FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2], + func=[torch.nn.ReLU, torch.nn.ReLU]) + + +def test_forward(): + fnn = FeedForward(input_vars, output_vars) + output_ = fnn(input_) + assert output_.labels == output_vars + + +def test_forward2(): + dim_in, dim_out = 3, 2 + fnn = FeedForward(dim_in, dim_out) + output_ = fnn(input_) + assert output_.shape == (input_.shape[0], dim_out) + + +def test_forward_features(): + fnn = FeedForward(input_vars, output_vars, extra_features=[myFeature()]) + output_ = fnn(input_) + assert output_.labels == output_vars diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor.py index dbf0d44..b2ab58f 100644 --- a/tests/test_label_tensor.py +++ b/tests/test_label_tensor.py @@ -59,3 +59,22 @@ def test_extract_order(): assert new.labels == label_to_extract assert new.shape[1] == len(label_to_extract) assert torch.all(torch.isclose(expected, new)) + + +def test_merge(): + tensor = LabelTensor(data, labels) + tensor_a = tensor.extract('a') + tensor_b = tensor.extract('b') + tensor_c = tensor.extract('c') + + tensor_bc = tensor_b.append(tensor_c) + assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) + +def test_merge(): + tensor = LabelTensor(data, labels) + tensor_a = tensor.extract('a') + tensor_b = tensor.extract('b') + tensor_c = tensor.extract('c') + + tensor_bb = tensor_b.append(tensor_b) + assert torch.allclose(tensor_b, tensor.extract(['b', 'c']))